Merge pull request #26 from huge-success/master

Merge upstream master branch
This commit is contained in:
7 2018-10-10 20:33:57 -07:00 committed by GitHub
commit 1bf1c9d006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 143 additions and 188 deletions

View File

@ -34,6 +34,6 @@ def authorized():
@app.route("/")
@authorized()
async def test(request):
return json({status: 'authorized'})
return json({'status': 'authorized'})
```

View File

@ -1,6 +1,5 @@
import sys
import json
import socket
from cgi import parse_header
from collections import namedtuple
from http.cookies import SimpleCookie
@ -192,18 +191,10 @@ class Request(dict):
return self._socket
def _get_address(self):
sock = self.transport.get_extra_info('socket')
if sock.family == socket.AF_INET:
self._socket = (self.transport.get_extra_info('peername') or
(None, None))
self._ip, self._port = self._socket
elif sock.family == socket.AF_INET6:
self._socket = (self.transport.get_extra_info('peername') or
(None, None, None, None))
self._ip, self._port, *_ = self._socket
else:
self._ip, self._port = (None, None)
self._socket = self.transport.get_extra_info('peername') or \
(None, None)
self._ip = self._socket[0]
self._port = self._socket[1]
@property
def remote_addr(self):

View File

@ -1,7 +1,11 @@
import sys
import pytest
from sanic import Sanic
if sys.platform in ['win32', 'cygwin']:
collect_ignore = ["test_worker.py"]
@pytest.fixture
def app(request):

View File

@ -1,10 +1,20 @@
from os import environ
from pathlib import Path
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from textwrap import dedent
import pytest
from tempfile import NamedTemporaryFile
from sanic import Sanic
@contextmanager
def temp_path():
""" a simple cross platform replacement for NamedTemporaryFile """
with TemporaryDirectory() as td:
yield Path(td, 'file')
def test_load_from_object(app):
class Config:
not_for_config = 'should not be used'
@ -38,16 +48,15 @@ def test_load_env_prefix():
def test_load_from_file(app):
config = b"""
VALUE = 'some value'
condition = 1 == 1
if condition:
CONDITIONAL = 'should be set'
"""
with NamedTemporaryFile() as config_file:
config_file.write(config)
config_file.seek(0)
app.config.from_pyfile(config_file.name)
config = dedent("""
VALUE = 'some value'
condition = 1 == 1
if condition:
CONDITIONAL = 'should be set'
""")
with temp_path() as config_path:
config_path.write_text(config)
app.config.from_pyfile(str(config_path))
assert 'VALUE' in app.config
assert app.config.VALUE == 'some value'
assert 'CONDITIONAL' in app.config
@ -61,11 +70,10 @@ def test_load_from_missing_file(app):
def test_load_from_envvar(app):
config = b"VALUE = 'some value'"
with NamedTemporaryFile() as config_file:
config_file.write(config)
config_file.seek(0)
environ['APP_CONFIG'] = config_file.name
config = "VALUE = 'some value'"
with temp_path() as config_path:
config_path.write_text(config)
environ['APP_CONFIG'] = str(config_path)
app.config.from_envvar('APP_CONFIG')
assert 'VALUE' in app.config
assert app.config.VALUE == 'some value'

View File

@ -9,60 +9,22 @@ import aiohttp
from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST, PORT
try:
try:
import packaging # direct use
except ImportError:
# setuptools v39.0 and above.
try:
from setuptools.extern import packaging
except ImportError:
# Before setuptools v39.0
from pkg_resources.extern import packaging
version = packaging.version
except ImportError:
raise RuntimeError("The 'packaging' library is missing.")
aiohttp_version = version.parse(aiohttp.__version__)
class ReuseableTCPConnector(TCPConnector):
def __init__(self, *args, **kwargs):
super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
self.old_proto = None
if aiohttp_version >= version.parse('3.3.0'):
async def connect(self, req, traces, timeout):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req, traces, timeout)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
elif aiohttp_version >= version.parse('3.0.0'):
async def connect(self, req, traces=None):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req, traces=traces)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
else:
async def connect(self, req):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
async def connect(self, req, *args, **kwargs):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req, *args, **kwargs)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
class ReuseableSanicTestClient(SanicTestClient):
@ -147,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient):
try:
request, response = results
return request, response
except:
except Exception:
raise ValueError(
"Request and response object expected, got ({})".format(
results))
else:
try:
return results[-1]
except:
except Exception:
raise ValueError(
"Request object expected, got ({})".format(results))

View File

@ -1,14 +1,20 @@
import multiprocessing
import random
import signal
import pytest
from sanic.testing import HOST, PORT
@pytest.mark.skipif(
not hasattr(signal, 'SIGALRM'),
reason='SIGALRM is not implemented for this platform, we have to come '
'up with another timeout strategy to test these'
)
def test_multiprocessing(app):
"""Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set()
def stop_on_alarm(*args):

View File

@ -5,13 +5,15 @@ import asyncio
from sanic.response import text
from sanic.config import Config
import aiohttp
from aiohttp import TCPConnector, ClientResponse
from sanic.testing import SanicTestClient, HOST, PORT
from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST
try:
try:
import packaging # direct use
except ImportError:
# direct use
import packaging
version = packaging.version
except (ImportError, AttributeError):
# setuptools v39.0 and above.
try:
from setuptools.extern import packaging
@ -22,8 +24,10 @@ try:
except ImportError:
raise RuntimeError("The 'packaging' library is missing.")
aiohttp_version = version.parse(aiohttp.__version__)
class DelayableTCPConnector(TCPConnector):
class RequestContextManager(object):
@ -56,8 +60,7 @@ class DelayableTCPConnector(TCPConnector):
if aiohttp_version >= version.parse("3.3.0"):
ret = await self.orig_start(connection)
else:
ret = await self.orig_start(connection,
read_until_eof)
ret = await self.orig_start(connection, read_until_eof)
except Exception as e:
raise e
return ret
@ -71,57 +74,43 @@ class DelayableTCPConnector(TCPConnector):
async def delayed_send(self, *args, **kwargs):
req = self.req
if self.delay and self.delay > 0:
#sync_sleep(self.delay)
# sync_sleep(self.delay)
await asyncio.sleep(self.delay)
t = req.loop.time()
print("sending at {}".format(t), flush=True)
conn = next(iter(args)) # first arg is connection
if aiohttp_version >= version.parse("3.1.0"):
try:
delayed_resp = await self.orig_send(*args, **kwargs)
except Exception as e:
if aiohttp_version >= version.parse("3.3.0"):
return aiohttp.ClientResponse(req.method, req.url,
writer=None,
continue100=None,
timer=None,
request_info=None,
traces=[],
loop=req.loop,
session=None)
else:
return aiohttp.ClientResponse(req.method, req.url,
writer=None,
continue100=None,
timer=None,
request_info=None,
auto_decompress=None,
traces=[],
loop=req.loop,
session=None)
else:
try:
delayed_resp = self.orig_send(*args, **kwargs)
except Exception as e:
try:
return await self.orig_send(*args, **kwargs)
except Exception as e:
if aiohttp_version < version.parse("3.1.0"):
return aiohttp.ClientResponse(req.method, req.url)
return delayed_resp
kw = dict(
writer=None,
continue100=None,
timer=None,
request_info=None,
traces=[],
loop=req.loop,
session=None
)
if aiohttp_version < version.parse("3.3.0"):
kw['auto_decompress'] = None
return aiohttp.ClientResponse(req.method, req.url, **kw)
def _send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
if aiohttp_version >= version.parse("3.1.0"):
# aiohttp changed the request.send method to async
async def send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
return self._send(*args, **kwargs)
else:
def send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
send = _send
def __init__(self, *args, **kwargs):
_post_connect_delay = kwargs.pop('post_connect_delay', 0)
@ -130,45 +119,18 @@ class DelayableTCPConnector(TCPConnector):
self._post_connect_delay = _post_connect_delay
self._pre_request_delay = _pre_request_delay
if aiohttp_version >= version.parse("3.3.0"):
async def connect(self, req, traces, timeout):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).\
connect(req, traces, timeout)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
elif aiohttp_version >= version.parse("3.0.0"):
async def connect(self, req, traces=None):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).\
connect(req, traces=traces)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
else:
async def connect(self, req):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).connect(req)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
async def connect(self, req, *args, **kwargs):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).\
connect(req, *args, **kwargs)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
class DelayableSanicTestClient(SanicTestClient):

View File

@ -1,3 +1,4 @@
import sys
import asyncio
import inspect
import os
@ -8,7 +9,9 @@ from urllib.parse import unquote
import pytest
from random import choice
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
from sanic.response import (
HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
)
from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT
from unittest.mock import MagicMock
@ -72,11 +75,15 @@ def test_response_header(app):
'CONTENT-TYPE': 'application/json'
})
is_windows = sys.platform in ['win32', 'cygwin']
request, response = app.test_client.get('/')
assert dict(response.headers) == {
'Connection': 'keep-alive',
'Keep-Alive': '2',
'Content-Length': '11',
'Keep-Alive': str(app.config.KEEP_ALIVE_TIMEOUT),
# response body contains an extra \r at the end if its windows
# TODO: this is the only place this difference shows up in our tests
# we should figure out a way to unify testing on both platforms
'Content-Length': '12' if is_windows else '11',
'Content-Type': 'application/json',
}

View File

@ -165,40 +165,42 @@ def test_route_optional_slash(app):
request, response = app.test_client.get('/get/')
assert response.text == 'OK'
def test_route_strict_slashes_set_to_false_and_host_is_a_list(app):
#Part of regression test for issue #1120
# Part of regression test for issue #1120
site1 = 'localhost:{}'.format(app.test_client.port)
site1 = '127.0.0.1:{}'.format(app.test_client.port)
#before fix, this raises a RouteExists error
# before fix, this raises a RouteExists error
@app.get('/get', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request):
def get_handler(request):
return text('OK')
request, response = app.test_client.get('http://' + site1 + '/get')
assert response.text == 'OK'
@app.post('/post', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request):
def post_handler(request):
return text('OK')
request, response = app.test_client.post('http://' + site1 +'/post')
request, response = app.test_client.post('http://' + site1 + '/post')
assert response.text == 'OK'
@app.put('/put', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request):
def put_handler(request):
return text('OK')
request, response = app.test_client.put('http://' + site1 +'/put')
request, response = app.test_client.put('http://' + site1 + '/put')
assert response.text == 'OK'
@app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request):
def delete_handler(request):
return text('OK')
request, response = app.test_client.delete('http://' + site1 +'/delete')
request, response = app.test_client.delete('http://' + site1 + '/delete')
assert response.text == 'OK'
def test_shorthand_routes_post(app):
@app.post('/post')

View File

@ -11,6 +11,12 @@ AVAILABLE_LISTENERS = [
'after_server_stop'
]
skipif_no_alarm = pytest.mark.skipif(
not hasattr(signal, 'SIGALRM'),
reason='SIGALRM is not implemented for this platform, we have to come '
'up with another timeout strategy to test these'
)
def create_listener(listener_name, in_list):
async def _listener(app, loop):
@ -32,6 +38,7 @@ def start_stop_app(random_name_app, **run_kwargs):
pass
@skipif_no_alarm
@pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS)
def test_single_listener(app, listener_name):
"""Test that listeners on their own work"""
@ -43,6 +50,7 @@ def test_single_listener(app, listener_name):
assert app.name + listener_name == output.pop()
@skipif_no_alarm
@pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS)
def test_register_listener(app, listener_name):
"""
@ -52,12 +60,12 @@ def test_register_listener(app, listener_name):
output = []
# Register listener
listener = create_listener(listener_name, output)
app.register_listener(listener,
event=listener_name)
app.register_listener(listener, event=listener_name)
start_stop_app(app)
assert app.name + listener_name == output.pop()
@skipif_no_alarm
def test_all_listeners(app):
output = []
for listener_name in AVAILABLE_LISTENERS:

View File

@ -7,13 +7,17 @@ from unittest import mock
from sanic.worker import GunicornWorker
from sanic.app import Sanic
import asyncio
import logging
import pytest
@pytest.fixture(scope='module')
def gunicorn_worker():
command = 'gunicorn --bind 127.0.0.1:1337 --worker-class sanic.worker.GunicornWorker examples.simple_server:app'
command = (
'gunicorn '
'--bind 127.0.0.1:1337 '
'--worker-class sanic.worker.GunicornWorker '
'examples.simple_server:app'
)
worker = subprocess.Popen(shlex.split(command))
time.sleep(3)
yield
@ -96,11 +100,12 @@ def test_run_max_requests_exceeded(worker):
_runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
loop.run_until_complete(_runner)
assert worker.alive == False
assert not worker.alive
worker.notify.assert_called_with()
worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
worker)
def test_worker_close(worker):
loop = asyncio.new_event_loop()
asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
@ -113,8 +118,8 @@ def test_worker_close(worker):
conn = mock.Mock()
conn.websocket = mock.Mock()
conn.websocket.close_connection = mock.Mock(
wraps=asyncio.coroutine(lambda *a, **kw: None)
)
wraps=asyncio.coroutine(lambda *a, **kw: None)
)
worker.connections = set([conn])
worker.log = mock.Mock()
worker.loop = loop
@ -130,6 +135,6 @@ def test_worker_close(worker):
_close = asyncio.ensure_future(worker.close(), loop=loop)
loop.run_until_complete(_close)
assert worker.signal.stopped == True
assert conn.websocket.close_connection.called == True
assert worker.signal.stopped
assert conn.websocket.close_connection.called
assert len(worker.servers) == 0