Merge pull request #1284 from ashleysommer/aiohttp_update
Fix broken tests when aiohttp >= 3.3.0
This commit is contained in:
commit
d52498b787
|
@ -1,5 +1,5 @@
|
||||||
aiofiles
|
aiofiles
|
||||||
aiohttp>=2.3.0
|
aiohttp>=2.3.0,<=3.2.1
|
||||||
chardet<=2.3.0
|
chardet<=2.3.0
|
||||||
beautifulsoup4
|
beautifulsoup4
|
||||||
coverage
|
coverage
|
||||||
|
|
|
@ -9,14 +9,39 @@ import aiohttp
|
||||||
from aiohttp import TCPConnector
|
from aiohttp import TCPConnector
|
||||||
from sanic.testing import SanicTestClient, HOST, PORT
|
from sanic.testing import SanicTestClient, HOST, PORT
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
import packaging # direct use
|
||||||
|
except ImportError:
|
||||||
|
# setuptools v39.0 and above.
|
||||||
|
try:
|
||||||
|
from setuptools.extern import packaging
|
||||||
|
except ImportError:
|
||||||
|
# Before setuptools v39.0
|
||||||
|
from pkg_resources.extern import packaging
|
||||||
|
version = packaging.version
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError("The 'packaging' library is missing.")
|
||||||
|
|
||||||
|
aiohttp_version = version.parse(aiohttp.__version__)
|
||||||
|
|
||||||
class ReuseableTCPConnector(TCPConnector):
|
class ReuseableTCPConnector(TCPConnector):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
|
super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
|
||||||
self.old_proto = None
|
self.old_proto = None
|
||||||
|
|
||||||
if aiohttp.__version__ >= '3.0':
|
if aiohttp_version >= version.parse('3.3.0'):
|
||||||
|
async def connect(self, req, traces, timeout):
|
||||||
|
new_conn = await super(ReuseableTCPConnector, self)\
|
||||||
|
.connect(req, traces, timeout)
|
||||||
|
if self.old_proto is not None:
|
||||||
|
if self.old_proto != new_conn._protocol:
|
||||||
|
raise RuntimeError(
|
||||||
|
"We got a new connection, wanted the same one!")
|
||||||
|
print(new_conn.__dict__)
|
||||||
|
self.old_proto = new_conn._protocol
|
||||||
|
return new_conn
|
||||||
|
elif aiohttp_version >= version.parse('3.0.0'):
|
||||||
async def connect(self, req, traces=None):
|
async def connect(self, req, traces=None):
|
||||||
new_conn = await super(ReuseableTCPConnector, self)\
|
new_conn = await super(ReuseableTCPConnector, self)\
|
||||||
.connect(req, traces=traces)
|
.connect(req, traces=traces)
|
||||||
|
@ -28,7 +53,6 @@ class ReuseableTCPConnector(TCPConnector):
|
||||||
self.old_proto = new_conn._protocol
|
self.old_proto = new_conn._protocol
|
||||||
return new_conn
|
return new_conn
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def connect(self, req):
|
async def connect(self, req):
|
||||||
new_conn = await super(ReuseableTCPConnector, self)\
|
new_conn = await super(ReuseableTCPConnector, self)\
|
||||||
.connect(req)
|
.connect(req)
|
||||||
|
|
|
@ -5,9 +5,24 @@ import asyncio
|
||||||
from sanic.response import text
|
from sanic.response import text
|
||||||
from sanic.config import Config
|
from sanic.config import Config
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import TCPConnector
|
from aiohttp import TCPConnector, ClientResponse
|
||||||
from sanic.testing import SanicTestClient, HOST, PORT
|
from sanic.testing import SanicTestClient, HOST, PORT
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
import packaging # direct use
|
||||||
|
except ImportError:
|
||||||
|
# setuptools v39.0 and above.
|
||||||
|
try:
|
||||||
|
from setuptools.extern import packaging
|
||||||
|
except ImportError:
|
||||||
|
# Before setuptools v39.0
|
||||||
|
from pkg_resources.extern import packaging
|
||||||
|
version = packaging.version
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError("The 'packaging' library is missing.")
|
||||||
|
|
||||||
|
aiohttp_version = version.parse(aiohttp.__version__)
|
||||||
|
|
||||||
class DelayableTCPConnector(TCPConnector):
|
class DelayableTCPConnector(TCPConnector):
|
||||||
|
|
||||||
|
@ -38,8 +53,11 @@ class DelayableTCPConnector(TCPConnector):
|
||||||
self.orig_start = getattr(resp, 'start')
|
self.orig_start = getattr(resp, 'start')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret = await self.orig_start(connection,
|
if aiohttp_version >= version.parse("3.3.0"):
|
||||||
read_until_eof)
|
ret = await self.orig_start(connection)
|
||||||
|
else:
|
||||||
|
ret = await self.orig_start(connection,
|
||||||
|
read_until_eof)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
return ret
|
return ret
|
||||||
|
@ -57,15 +75,31 @@ class DelayableTCPConnector(TCPConnector):
|
||||||
await asyncio.sleep(self.delay)
|
await asyncio.sleep(self.delay)
|
||||||
t = req.loop.time()
|
t = req.loop.time()
|
||||||
print("sending at {}".format(t), flush=True)
|
print("sending at {}".format(t), flush=True)
|
||||||
conn = next(iter(args)) # first arg is connection
|
conn = next(iter(args)) # first arg is connection
|
||||||
if aiohttp.__version__ >= "3.1.0":
|
|
||||||
|
if aiohttp_version >= version.parse("3.1.0"):
|
||||||
try:
|
try:
|
||||||
delayed_resp = await self.orig_send(*args, **kwargs)
|
delayed_resp = await self.orig_send(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return aiohttp.ClientResponse(req.method, req.url,
|
if aiohttp_version >= version.parse("3.3.0"):
|
||||||
writer=None, continue100=None, timer=None,
|
return aiohttp.ClientResponse(req.method, req.url,
|
||||||
request_info=None, auto_decompress=None, traces=[],
|
writer=None,
|
||||||
loop=req.loop, session=None)
|
continue100=None,
|
||||||
|
timer=None,
|
||||||
|
request_info=None,
|
||||||
|
traces=[],
|
||||||
|
loop=req.loop,
|
||||||
|
session=None)
|
||||||
|
else:
|
||||||
|
return aiohttp.ClientResponse(req.method, req.url,
|
||||||
|
writer=None,
|
||||||
|
continue100=None,
|
||||||
|
timer=None,
|
||||||
|
request_info=None,
|
||||||
|
auto_decompress=None,
|
||||||
|
traces=[],
|
||||||
|
loop=req.loop,
|
||||||
|
session=None)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
delayed_resp = self.orig_send(*args, **kwargs)
|
delayed_resp = self.orig_send(*args, **kwargs)
|
||||||
|
@ -73,7 +107,7 @@ class DelayableTCPConnector(TCPConnector):
|
||||||
return aiohttp.ClientResponse(req.method, req.url)
|
return aiohttp.ClientResponse(req.method, req.url)
|
||||||
return delayed_resp
|
return delayed_resp
|
||||||
|
|
||||||
if aiohttp.__version__ >= "3.1.0":
|
if aiohttp_version >= version.parse("3.1.0"):
|
||||||
# aiohttp changed the request.send method to async
|
# aiohttp changed the request.send method to async
|
||||||
async def send(self, *args, **kwargs):
|
async def send(self, *args, **kwargs):
|
||||||
gen = self.delayed_send(*args, **kwargs)
|
gen = self.delayed_send(*args, **kwargs)
|
||||||
|
@ -96,12 +130,25 @@ class DelayableTCPConnector(TCPConnector):
|
||||||
self._post_connect_delay = _post_connect_delay
|
self._post_connect_delay = _post_connect_delay
|
||||||
self._pre_request_delay = _pre_request_delay
|
self._pre_request_delay = _pre_request_delay
|
||||||
|
|
||||||
if aiohttp.__version__ >= '3.0':
|
if aiohttp_version >= version.parse("3.3.0"):
|
||||||
|
async def connect(self, req, traces, timeout):
|
||||||
|
d_req = DelayableTCPConnector.\
|
||||||
|
RequestContextManager(req, self._pre_request_delay)
|
||||||
|
conn = await super(DelayableTCPConnector, self).\
|
||||||
|
connect(req, traces, timeout)
|
||||||
|
if self._post_connect_delay and self._post_connect_delay > 0:
|
||||||
|
await asyncio.sleep(self._post_connect_delay,
|
||||||
|
loop=self._loop)
|
||||||
|
req.send = d_req.send
|
||||||
|
t = req.loop.time()
|
||||||
|
print("Connected at {}".format(t), flush=True)
|
||||||
|
return conn
|
||||||
|
elif aiohttp_version >= version.parse("3.0.0"):
|
||||||
async def connect(self, req, traces=None):
|
async def connect(self, req, traces=None):
|
||||||
d_req = DelayableTCPConnector.\
|
d_req = DelayableTCPConnector.\
|
||||||
RequestContextManager(req, self._pre_request_delay)
|
RequestContextManager(req, self._pre_request_delay)
|
||||||
conn = await super(DelayableTCPConnector, self).connect(req, traces=traces)
|
conn = await super(DelayableTCPConnector, self).\
|
||||||
|
connect(req, traces=traces)
|
||||||
if self._post_connect_delay and self._post_connect_delay > 0:
|
if self._post_connect_delay and self._post_connect_delay > 0:
|
||||||
await asyncio.sleep(self._post_connect_delay,
|
await asyncio.sleep(self._post_connect_delay,
|
||||||
loop=self._loop)
|
loop=self._loop)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user