simplified aiohttp version diffs, reverted worker import policy
This commit is contained in:
parent
7a9e100b0f
commit
b7d74c82ba
@ -1,30 +1,12 @@
|
||||
import sys
|
||||
import traceback
|
||||
from json import JSONDecodeError
|
||||
from sanic.log import logger
|
||||
from sanic.exceptions import MethodNotSupported
|
||||
from sanic.response import text
|
||||
|
||||
try:
|
||||
try:
|
||||
# direct use
|
||||
import packaging as _packaging
|
||||
version = _packaging.version
|
||||
except (ImportError, AttributeError):
|
||||
# setuptools v39.0 and above.
|
||||
try:
|
||||
from setuptools.extern import packaging as _packaging
|
||||
except ImportError:
|
||||
# Before setuptools v39.0
|
||||
from pkg_resources.extern import packaging as _packaging
|
||||
version = _packaging.version
|
||||
except ImportError:
|
||||
raise RuntimeError("The 'packaging' library is missing.")
|
||||
|
||||
|
||||
HOST = '127.0.0.1'
|
||||
PORT = 42101
|
||||
is_windows = sys.platform in ['win32', 'cygwin']
|
||||
|
||||
|
||||
class SanicTestClient:
|
||||
|
@ -12,18 +12,15 @@ except ImportError:
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
import gunicorn.workers.base as base
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except ImportError:
|
||||
base = None
|
||||
pass
|
||||
import gunicorn.workers.base as base
|
||||
|
||||
from sanic.server import trigger_events, serve, HttpProtocol, Signal
|
||||
from sanic.websocket import WebSocketProtocol
|
||||
|
||||
|
||||
# gunicorn is not available on windows
|
||||
if base is not None:
|
||||
class GunicornWorker(base.Worker):
|
||||
|
||||
http_protocol = HttpProtocol
|
||||
@ -147,23 +144,18 @@ if base is not None:
|
||||
self.notify()
|
||||
|
||||
req_count = sum(
|
||||
srv['requests_count'] for srv in self.servers.values()
|
||||
self.servers[srv]["requests_count"] for srv in self.servers
|
||||
)
|
||||
if self.max_requests and req_count > self.max_requests:
|
||||
self.alive = False
|
||||
self.log.info(
|
||||
"Max requests exceeded, shutting down: %s",
|
||||
self
|
||||
)
|
||||
self.log.info("Max requests exceeded, shutting down: %s",
|
||||
self)
|
||||
elif pid == os.getpid() and self.ppid != os.getppid():
|
||||
self.alive = False
|
||||
self.log.info(
|
||||
"Parent changed, shutting down: %s",
|
||||
self
|
||||
)
|
||||
self.log.info("Parent changed, shutting down: %s", self)
|
||||
else:
|
||||
await asyncio.sleep(1.0, loop=self.loop)
|
||||
except (BaseException, GeneratorExit, KeyboardInterrupt):
|
||||
except (Exception, BaseException, GeneratorExit, KeyboardInterrupt):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,7 +1,11 @@
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
|
||||
if sys.platform in ['win32', 'cygwin']:
|
||||
collect_ignore = ["test_worker.py"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(request):
|
||||
|
@ -7,10 +7,7 @@ from sanic.config import Config
|
||||
from sanic import server
|
||||
import aiohttp
|
||||
from aiohttp import TCPConnector
|
||||
from sanic.testing import SanicTestClient, HOST, PORT, version
|
||||
|
||||
|
||||
aiohttp_version = version.parse(aiohttp.__version__)
|
||||
from sanic.testing import SanicTestClient, HOST, PORT
|
||||
|
||||
|
||||
class ReuseableTCPConnector(TCPConnector):
|
||||
@ -18,32 +15,9 @@ class ReuseableTCPConnector(TCPConnector):
|
||||
super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
|
||||
self.old_proto = None
|
||||
|
||||
if aiohttp_version >= version.parse('3.3.0'):
|
||||
async def connect(self, req, traces, timeout):
|
||||
async def connect(self, req, *args, **kwargs):
|
||||
new_conn = await super(ReuseableTCPConnector, self)\
|
||||
.connect(req, traces, timeout)
|
||||
if self.old_proto is not None:
|
||||
if self.old_proto != new_conn._protocol:
|
||||
raise RuntimeError(
|
||||
"We got a new connection, wanted the same one!")
|
||||
print(new_conn.__dict__)
|
||||
self.old_proto = new_conn._protocol
|
||||
return new_conn
|
||||
elif aiohttp_version >= version.parse('3.0.0'):
|
||||
async def connect(self, req, traces=None):
|
||||
new_conn = await super(ReuseableTCPConnector, self)\
|
||||
.connect(req, traces=traces)
|
||||
if self.old_proto is not None:
|
||||
if self.old_proto != new_conn._protocol:
|
||||
raise RuntimeError(
|
||||
"We got a new connection, wanted the same one!")
|
||||
print(new_conn.__dict__)
|
||||
self.old_proto = new_conn._protocol
|
||||
return new_conn
|
||||
else:
|
||||
async def connect(self, req):
|
||||
new_conn = await super(ReuseableTCPConnector, self)\
|
||||
.connect(req)
|
||||
.connect(req, *args, **kwargs)
|
||||
if self.old_proto is not None:
|
||||
if self.old_proto != new_conn._protocol:
|
||||
raise RuntimeError(
|
||||
@ -135,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient):
|
||||
try:
|
||||
request, response = results
|
||||
return request, response
|
||||
except:
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Request and response object expected, got ({})".format(
|
||||
results))
|
||||
else:
|
||||
try:
|
||||
return results[-1]
|
||||
except:
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Request object expected, got ({})".format(results))
|
||||
|
||||
|
@ -6,7 +6,23 @@ from sanic.response import text
|
||||
from sanic.config import Config
|
||||
import aiohttp
|
||||
from aiohttp import TCPConnector
|
||||
from sanic.testing import SanicTestClient, HOST, version
|
||||
from sanic.testing import SanicTestClient, HOST
|
||||
|
||||
try:
|
||||
try:
|
||||
# direct use
|
||||
import packaging
|
||||
version = packaging.version
|
||||
except (ImportError, AttributeError):
|
||||
# setuptools v39.0 and above.
|
||||
try:
|
||||
from setuptools.extern import packaging
|
||||
except ImportError:
|
||||
# Before setuptools v39.0
|
||||
from pkg_resources.extern import packaging
|
||||
version = packaging.version
|
||||
except ImportError:
|
||||
raise RuntimeError("The 'packaging' library is missing.")
|
||||
|
||||
|
||||
aiohttp_version = version.parse(aiohttp.__version__)
|
||||
@ -44,8 +60,7 @@ class DelayableTCPConnector(TCPConnector):
|
||||
if aiohttp_version >= version.parse("3.3.0"):
|
||||
ret = await self.orig_start(connection)
|
||||
else:
|
||||
ret = await self.orig_start(connection,
|
||||
read_until_eof)
|
||||
ret = await self.orig_start(connection, read_until_eof)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return ret
|
||||
@ -65,51 +80,37 @@ class DelayableTCPConnector(TCPConnector):
|
||||
print("sending at {}".format(t), flush=True)
|
||||
conn = next(iter(args)) # first arg is connection
|
||||
|
||||
if aiohttp_version >= version.parse("3.1.0"):
|
||||
try:
|
||||
delayed_resp = await self.orig_send(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if aiohttp_version >= version.parse("3.3.0"):
|
||||
return aiohttp.ClientResponse(req.method, req.url,
|
||||
writer=None,
|
||||
continue100=None,
|
||||
timer=None,
|
||||
request_info=None,
|
||||
traces=[],
|
||||
loop=req.loop,
|
||||
session=None)
|
||||
else:
|
||||
return aiohttp.ClientResponse(req.method, req.url,
|
||||
writer=None,
|
||||
continue100=None,
|
||||
timer=None,
|
||||
request_info=None,
|
||||
auto_decompress=None,
|
||||
traces=[],
|
||||
loop=req.loop,
|
||||
session=None)
|
||||
else:
|
||||
try:
|
||||
delayed_resp = self.orig_send(*args, **kwargs)
|
||||
return await self.orig_send(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if aiohttp_version < version.parse("3.1.0"):
|
||||
return aiohttp.ClientResponse(req.method, req.url)
|
||||
return delayed_resp
|
||||
kw = dict(
|
||||
writer=None,
|
||||
continue100=None,
|
||||
timer=None,
|
||||
request_info=None,
|
||||
traces=[],
|
||||
loop=req.loop,
|
||||
session=None
|
||||
)
|
||||
if aiohttp_version < version.parse("3.3.0"):
|
||||
kw['auto_decompress'] = None
|
||||
return aiohttp.ClientResponse(req.method, req.url, **kw)
|
||||
|
||||
def _send(self, *args, **kwargs):
|
||||
gen = self.delayed_send(*args, **kwargs)
|
||||
task = self.req.loop.create_task(gen)
|
||||
self.send_task = task
|
||||
self._acting_as = task
|
||||
return self
|
||||
|
||||
if aiohttp_version >= version.parse("3.1.0"):
|
||||
# aiohttp changed the request.send method to async
|
||||
async def send(self, *args, **kwargs):
|
||||
gen = self.delayed_send(*args, **kwargs)
|
||||
task = self.req.loop.create_task(gen)
|
||||
self.send_task = task
|
||||
self._acting_as = task
|
||||
return self
|
||||
return self._send(*args, **kwargs)
|
||||
else:
|
||||
def send(self, *args, **kwargs):
|
||||
gen = self.delayed_send(*args, **kwargs)
|
||||
task = self.req.loop.create_task(gen)
|
||||
self.send_task = task
|
||||
self._acting_as = task
|
||||
return self
|
||||
send = _send
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
_post_connect_delay = kwargs.pop('post_connect_delay', 0)
|
||||
@ -118,38 +119,11 @@ class DelayableTCPConnector(TCPConnector):
|
||||
self._post_connect_delay = _post_connect_delay
|
||||
self._pre_request_delay = _pre_request_delay
|
||||
|
||||
if aiohttp_version >= version.parse("3.3.0"):
|
||||
async def connect(self, req, traces, timeout):
|
||||
async def connect(self, req, *args, **kwargs):
|
||||
d_req = DelayableTCPConnector.\
|
||||
RequestContextManager(req, self._pre_request_delay)
|
||||
conn = await super(DelayableTCPConnector, self).\
|
||||
connect(req, traces, timeout)
|
||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
||||
await asyncio.sleep(self._post_connect_delay,
|
||||
loop=self._loop)
|
||||
req.send = d_req.send
|
||||
t = req.loop.time()
|
||||
print("Connected at {}".format(t), flush=True)
|
||||
return conn
|
||||
elif aiohttp_version >= version.parse("3.0.0"):
|
||||
async def connect(self, req, traces=None):
|
||||
d_req = DelayableTCPConnector.\
|
||||
RequestContextManager(req, self._pre_request_delay)
|
||||
conn = await super(DelayableTCPConnector, self).\
|
||||
connect(req, traces=traces)
|
||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
||||
await asyncio.sleep(self._post_connect_delay,
|
||||
loop=self._loop)
|
||||
req.send = d_req.send
|
||||
t = req.loop.time()
|
||||
print("Connected at {}".format(t), flush=True)
|
||||
return conn
|
||||
else:
|
||||
|
||||
async def connect(self, req):
|
||||
d_req = DelayableTCPConnector.\
|
||||
RequestContextManager(req, self._pre_request_delay)
|
||||
conn = await super(DelayableTCPConnector, self).connect(req)
|
||||
connect(req, *args, **kwargs)
|
||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
||||
await asyncio.sleep(self._post_connect_delay,
|
||||
loop=self._loop)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
@ -12,7 +13,7 @@ from sanic.response import (
|
||||
HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
|
||||
)
|
||||
from sanic.server import HttpProtocol
|
||||
from sanic.testing import HOST, PORT, is_windows
|
||||
from sanic.testing import HOST, PORT
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
JSON_DATA = {'ok': True}
|
||||
@ -74,6 +75,7 @@ def test_response_header(app):
|
||||
'CONTENT-TYPE': 'application/json'
|
||||
})
|
||||
|
||||
is_windows = sys.platform in ['win32', 'cygwin']
|
||||
request, response = app.test_client.get('/')
|
||||
assert dict(response.headers) == {
|
||||
'Connection': 'keep-alive',
|
||||
|
@ -4,17 +4,10 @@ import shlex
|
||||
import subprocess
|
||||
import urllib.request
|
||||
from unittest import mock
|
||||
from sanic.worker import GunicornWorker
|
||||
from sanic.app import Sanic
|
||||
import asyncio
|
||||
import pytest
|
||||
from sanic.app import Sanic
|
||||
try:
|
||||
from sanic.worker import GunicornWorker
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason="GunicornWorker Not supported on this platform"
|
||||
)
|
||||
# this has to be defined or pytest will err on import
|
||||
GunicornWorker = object
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@ -107,11 +100,12 @@ def test_run_max_requests_exceeded(worker):
|
||||
_runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
|
||||
loop.run_until_complete(_runner)
|
||||
|
||||
assert worker.alive == False
|
||||
assert not worker.alive
|
||||
worker.notify.assert_called_with()
|
||||
worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
|
||||
worker)
|
||||
|
||||
|
||||
def test_worker_close(worker):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
|
||||
@ -141,6 +135,6 @@ def test_worker_close(worker):
|
||||
_close = asyncio.ensure_future(worker.close(), loop=loop)
|
||||
loop.run_until_complete(_close)
|
||||
|
||||
assert worker.signal.stopped == True
|
||||
assert conn.websocket.close_connection.called == True
|
||||
assert worker.signal.stopped
|
||||
assert conn.websocket.close_connection.called
|
||||
assert len(worker.servers) == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user