Merge pull request #26 from huge-success/master

Merge upstream master branch
This commit is contained in:
7 2018-10-10 20:33:57 -07:00 committed by GitHub
commit 1bf1c9d006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 143 additions and 188 deletions

View File

@ -34,6 +34,6 @@ def authorized():
@app.route("/") @app.route("/")
@authorized() @authorized()
async def test(request): async def test(request):
return json({status: 'authorized'}) return json({'status': 'authorized'})
``` ```

View File

@ -1,6 +1,5 @@
import sys import sys
import json import json
import socket
from cgi import parse_header from cgi import parse_header
from collections import namedtuple from collections import namedtuple
from http.cookies import SimpleCookie from http.cookies import SimpleCookie
@ -192,18 +191,10 @@ class Request(dict):
return self._socket return self._socket
def _get_address(self): def _get_address(self):
sock = self.transport.get_extra_info('socket') self._socket = self.transport.get_extra_info('peername') or \
(None, None)
if sock.family == socket.AF_INET: self._ip = self._socket[0]
self._socket = (self.transport.get_extra_info('peername') or self._port = self._socket[1]
(None, None))
self._ip, self._port = self._socket
elif sock.family == socket.AF_INET6:
self._socket = (self.transport.get_extra_info('peername') or
(None, None, None, None))
self._ip, self._port, *_ = self._socket
else:
self._ip, self._port = (None, None)
@property @property
def remote_addr(self): def remote_addr(self):

View File

@ -1,7 +1,11 @@
import sys
import pytest import pytest
from sanic import Sanic from sanic import Sanic
if sys.platform in ['win32', 'cygwin']:
collect_ignore = ["test_worker.py"]
@pytest.fixture @pytest.fixture
def app(request): def app(request):

View File

@ -1,10 +1,20 @@
from os import environ from os import environ
from pathlib import Path
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from textwrap import dedent
import pytest import pytest
from tempfile import NamedTemporaryFile
from sanic import Sanic from sanic import Sanic
@contextmanager
def temp_path():
""" a simple cross platform replacement for NamedTemporaryFile """
with TemporaryDirectory() as td:
yield Path(td, 'file')
def test_load_from_object(app): def test_load_from_object(app):
class Config: class Config:
not_for_config = 'should not be used' not_for_config = 'should not be used'
@ -38,16 +48,15 @@ def test_load_env_prefix():
def test_load_from_file(app): def test_load_from_file(app):
config = b""" config = dedent("""
VALUE = 'some value' VALUE = 'some value'
condition = 1 == 1 condition = 1 == 1
if condition: if condition:
CONDITIONAL = 'should be set' CONDITIONAL = 'should be set'
""" """)
with NamedTemporaryFile() as config_file: with temp_path() as config_path:
config_file.write(config) config_path.write_text(config)
config_file.seek(0) app.config.from_pyfile(str(config_path))
app.config.from_pyfile(config_file.name)
assert 'VALUE' in app.config assert 'VALUE' in app.config
assert app.config.VALUE == 'some value' assert app.config.VALUE == 'some value'
assert 'CONDITIONAL' in app.config assert 'CONDITIONAL' in app.config
@ -61,11 +70,10 @@ def test_load_from_missing_file(app):
def test_load_from_envvar(app): def test_load_from_envvar(app):
config = b"VALUE = 'some value'" config = "VALUE = 'some value'"
with NamedTemporaryFile() as config_file: with temp_path() as config_path:
config_file.write(config) config_path.write_text(config)
config_file.seek(0) environ['APP_CONFIG'] = str(config_path)
environ['APP_CONFIG'] = config_file.name
app.config.from_envvar('APP_CONFIG') app.config.from_envvar('APP_CONFIG')
assert 'VALUE' in app.config assert 'VALUE' in app.config
assert app.config.VALUE == 'some value' assert app.config.VALUE == 'some value'

View File

@ -9,60 +9,22 @@ import aiohttp
from aiohttp import TCPConnector from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST, PORT from sanic.testing import SanicTestClient, HOST, PORT
try:
try:
import packaging # direct use
except ImportError:
# setuptools v39.0 and above.
try:
from setuptools.extern import packaging
except ImportError:
# Before setuptools v39.0
from pkg_resources.extern import packaging
version = packaging.version
except ImportError:
raise RuntimeError("The 'packaging' library is missing.")
aiohttp_version = version.parse(aiohttp.__version__)
class ReuseableTCPConnector(TCPConnector): class ReuseableTCPConnector(TCPConnector):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ReuseableTCPConnector, self).__init__(*args, **kwargs) super(ReuseableTCPConnector, self).__init__(*args, **kwargs)
self.old_proto = None self.old_proto = None
if aiohttp_version >= version.parse('3.3.0'): async def connect(self, req, *args, **kwargs):
async def connect(self, req, traces, timeout): new_conn = await super(ReuseableTCPConnector, self)\
new_conn = await super(ReuseableTCPConnector, self)\ .connect(req, *args, **kwargs)
.connect(req, traces, timeout) if self.old_proto is not None:
if self.old_proto is not None: if self.old_proto != new_conn._protocol:
if self.old_proto != new_conn._protocol: raise RuntimeError(
raise RuntimeError( "We got a new connection, wanted the same one!")
"We got a new connection, wanted the same one!") print(new_conn.__dict__)
print(new_conn.__dict__) self.old_proto = new_conn._protocol
self.old_proto = new_conn._protocol return new_conn
return new_conn
elif aiohttp_version >= version.parse('3.0.0'):
async def connect(self, req, traces=None):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req, traces=traces)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
else:
async def connect(self, req):
new_conn = await super(ReuseableTCPConnector, self)\
.connect(req)
if self.old_proto is not None:
if self.old_proto != new_conn._protocol:
raise RuntimeError(
"We got a new connection, wanted the same one!")
print(new_conn.__dict__)
self.old_proto = new_conn._protocol
return new_conn
class ReuseableSanicTestClient(SanicTestClient): class ReuseableSanicTestClient(SanicTestClient):
@ -147,14 +109,14 @@ class ReuseableSanicTestClient(SanicTestClient):
try: try:
request, response = results request, response = results
return request, response return request, response
except: except Exception:
raise ValueError( raise ValueError(
"Request and response object expected, got ({})".format( "Request and response object expected, got ({})".format(
results)) results))
else: else:
try: try:
return results[-1] return results[-1]
except: except Exception:
raise ValueError( raise ValueError(
"Request object expected, got ({})".format(results)) "Request object expected, got ({})".format(results))

View File

@ -1,14 +1,20 @@
import multiprocessing import multiprocessing
import random import random
import signal import signal
import pytest
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
@pytest.mark.skipif(
not hasattr(signal, 'SIGALRM'),
reason='SIGALRM is not implemented for this platform, we have to come '
'up with another timeout strategy to test these'
)
def test_multiprocessing(app): def test_multiprocessing(app):
"""Tests that the number of children we produce is correct""" """Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check # Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
process_list = set() process_list = set()
def stop_on_alarm(*args): def stop_on_alarm(*args):

View File

@ -5,13 +5,15 @@ import asyncio
from sanic.response import text from sanic.response import text
from sanic.config import Config from sanic.config import Config
import aiohttp import aiohttp
from aiohttp import TCPConnector, ClientResponse from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST, PORT from sanic.testing import SanicTestClient, HOST
try: try:
try: try:
import packaging # direct use # direct use
except ImportError: import packaging
version = packaging.version
except (ImportError, AttributeError):
# setuptools v39.0 and above. # setuptools v39.0 and above.
try: try:
from setuptools.extern import packaging from setuptools.extern import packaging
@ -22,8 +24,10 @@ try:
except ImportError: except ImportError:
raise RuntimeError("The 'packaging' library is missing.") raise RuntimeError("The 'packaging' library is missing.")
aiohttp_version = version.parse(aiohttp.__version__) aiohttp_version = version.parse(aiohttp.__version__)
class DelayableTCPConnector(TCPConnector): class DelayableTCPConnector(TCPConnector):
class RequestContextManager(object): class RequestContextManager(object):
@ -56,8 +60,7 @@ class DelayableTCPConnector(TCPConnector):
if aiohttp_version >= version.parse("3.3.0"): if aiohttp_version >= version.parse("3.3.0"):
ret = await self.orig_start(connection) ret = await self.orig_start(connection)
else: else:
ret = await self.orig_start(connection, ret = await self.orig_start(connection, read_until_eof)
read_until_eof)
except Exception as e: except Exception as e:
raise e raise e
return ret return ret
@ -71,57 +74,43 @@ class DelayableTCPConnector(TCPConnector):
async def delayed_send(self, *args, **kwargs): async def delayed_send(self, *args, **kwargs):
req = self.req req = self.req
if self.delay and self.delay > 0: if self.delay and self.delay > 0:
#sync_sleep(self.delay) # sync_sleep(self.delay)
await asyncio.sleep(self.delay) await asyncio.sleep(self.delay)
t = req.loop.time() t = req.loop.time()
print("sending at {}".format(t), flush=True) print("sending at {}".format(t), flush=True)
conn = next(iter(args)) # first arg is connection conn = next(iter(args)) # first arg is connection
if aiohttp_version >= version.parse("3.1.0"): try:
try: return await self.orig_send(*args, **kwargs)
delayed_resp = await self.orig_send(*args, **kwargs) except Exception as e:
except Exception as e: if aiohttp_version < version.parse("3.1.0"):
if aiohttp_version >= version.parse("3.3.0"):
return aiohttp.ClientResponse(req.method, req.url,
writer=None,
continue100=None,
timer=None,
request_info=None,
traces=[],
loop=req.loop,
session=None)
else:
return aiohttp.ClientResponse(req.method, req.url,
writer=None,
continue100=None,
timer=None,
request_info=None,
auto_decompress=None,
traces=[],
loop=req.loop,
session=None)
else:
try:
delayed_resp = self.orig_send(*args, **kwargs)
except Exception as e:
return aiohttp.ClientResponse(req.method, req.url) return aiohttp.ClientResponse(req.method, req.url)
return delayed_resp kw = dict(
writer=None,
continue100=None,
timer=None,
request_info=None,
traces=[],
loop=req.loop,
session=None
)
if aiohttp_version < version.parse("3.3.0"):
kw['auto_decompress'] = None
return aiohttp.ClientResponse(req.method, req.url, **kw)
def _send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
if aiohttp_version >= version.parse("3.1.0"): if aiohttp_version >= version.parse("3.1.0"):
# aiohttp changed the request.send method to async # aiohttp changed the request.send method to async
async def send(self, *args, **kwargs): async def send(self, *args, **kwargs):
gen = self.delayed_send(*args, **kwargs) return self._send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
else: else:
def send(self, *args, **kwargs): send = _send
gen = self.delayed_send(*args, **kwargs)
task = self.req.loop.create_task(gen)
self.send_task = task
self._acting_as = task
return self
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
_post_connect_delay = kwargs.pop('post_connect_delay', 0) _post_connect_delay = kwargs.pop('post_connect_delay', 0)
@ -130,45 +119,18 @@ class DelayableTCPConnector(TCPConnector):
self._post_connect_delay = _post_connect_delay self._post_connect_delay = _post_connect_delay
self._pre_request_delay = _pre_request_delay self._pre_request_delay = _pre_request_delay
if aiohttp_version >= version.parse("3.3.0"): async def connect(self, req, *args, **kwargs):
async def connect(self, req, traces, timeout): d_req = DelayableTCPConnector.\
d_req = DelayableTCPConnector.\ RequestContextManager(req, self._pre_request_delay)
RequestContextManager(req, self._pre_request_delay) conn = await super(DelayableTCPConnector, self).\
conn = await super(DelayableTCPConnector, self).\ connect(req, *args, **kwargs)
connect(req, traces, timeout) if self._post_connect_delay and self._post_connect_delay > 0:
if self._post_connect_delay and self._post_connect_delay > 0: await asyncio.sleep(self._post_connect_delay,
await asyncio.sleep(self._post_connect_delay, loop=self._loop)
loop=self._loop) req.send = d_req.send
req.send = d_req.send t = req.loop.time()
t = req.loop.time() print("Connected at {}".format(t), flush=True)
print("Connected at {}".format(t), flush=True) return conn
return conn
elif aiohttp_version >= version.parse("3.0.0"):
async def connect(self, req, traces=None):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).\
connect(req, traces=traces)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
else:
async def connect(self, req):
d_req = DelayableTCPConnector.\
RequestContextManager(req, self._pre_request_delay)
conn = await super(DelayableTCPConnector, self).connect(req)
if self._post_connect_delay and self._post_connect_delay > 0:
await asyncio.sleep(self._post_connect_delay,
loop=self._loop)
req.send = d_req.send
t = req.loop.time()
print("Connected at {}".format(t), flush=True)
return conn
class DelayableSanicTestClient(SanicTestClient): class DelayableSanicTestClient(SanicTestClient):

View File

@ -1,3 +1,4 @@
import sys
import asyncio import asyncio
import inspect import inspect
import os import os
@ -8,7 +9,9 @@ from urllib.parse import unquote
import pytest import pytest
from random import choice from random import choice
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json from sanic.response import (
HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
)
from sanic.server import HttpProtocol from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -72,11 +75,15 @@ def test_response_header(app):
'CONTENT-TYPE': 'application/json' 'CONTENT-TYPE': 'application/json'
}) })
is_windows = sys.platform in ['win32', 'cygwin']
request, response = app.test_client.get('/') request, response = app.test_client.get('/')
assert dict(response.headers) == { assert dict(response.headers) == {
'Connection': 'keep-alive', 'Connection': 'keep-alive',
'Keep-Alive': '2', 'Keep-Alive': str(app.config.KEEP_ALIVE_TIMEOUT),
'Content-Length': '11', # response body contains an extra \r at the end if its windows
# TODO: this is the only place this difference shows up in our tests
# we should figure out a way to unify testing on both platforms
'Content-Length': '12' if is_windows else '11',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
} }

View File

@ -165,40 +165,42 @@ def test_route_optional_slash(app):
request, response = app.test_client.get('/get/') request, response = app.test_client.get('/get/')
assert response.text == 'OK' assert response.text == 'OK'
def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): def test_route_strict_slashes_set_to_false_and_host_is_a_list(app):
#Part of regression test for issue #1120 # Part of regression test for issue #1120
site1 = 'localhost:{}'.format(app.test_client.port) site1 = '127.0.0.1:{}'.format(app.test_client.port)
#before fix, this raises a RouteExists error # before fix, this raises a RouteExists error
@app.get('/get', host=[site1, 'site2.com'], strict_slashes=False) @app.get('/get', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request): def get_handler(request):
return text('OK') return text('OK')
request, response = app.test_client.get('http://' + site1 + '/get') request, response = app.test_client.get('http://' + site1 + '/get')
assert response.text == 'OK' assert response.text == 'OK'
@app.post('/post', host=[site1, 'site2.com'], strict_slashes=False) @app.post('/post', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request): def post_handler(request):
return text('OK') return text('OK')
request, response = app.test_client.post('http://' + site1 +'/post') request, response = app.test_client.post('http://' + site1 + '/post')
assert response.text == 'OK' assert response.text == 'OK'
@app.put('/put', host=[site1, 'site2.com'], strict_slashes=False) @app.put('/put', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request): def put_handler(request):
return text('OK') return text('OK')
request, response = app.test_client.put('http://' + site1 +'/put') request, response = app.test_client.put('http://' + site1 + '/put')
assert response.text == 'OK' assert response.text == 'OK'
@app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False) @app.delete('/delete', host=[site1, 'site2.com'], strict_slashes=False)
def handler(request): def delete_handler(request):
return text('OK') return text('OK')
request, response = app.test_client.delete('http://' + site1 +'/delete') request, response = app.test_client.delete('http://' + site1 + '/delete')
assert response.text == 'OK' assert response.text == 'OK'
def test_shorthand_routes_post(app): def test_shorthand_routes_post(app):
@app.post('/post') @app.post('/post')

View File

@ -11,6 +11,12 @@ AVAILABLE_LISTENERS = [
'after_server_stop' 'after_server_stop'
] ]
skipif_no_alarm = pytest.mark.skipif(
not hasattr(signal, 'SIGALRM'),
reason='SIGALRM is not implemented for this platform, we have to come '
'up with another timeout strategy to test these'
)
def create_listener(listener_name, in_list): def create_listener(listener_name, in_list):
async def _listener(app, loop): async def _listener(app, loop):
@ -32,6 +38,7 @@ def start_stop_app(random_name_app, **run_kwargs):
pass pass
@skipif_no_alarm
@pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS) @pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS)
def test_single_listener(app, listener_name): def test_single_listener(app, listener_name):
"""Test that listeners on their own work""" """Test that listeners on their own work"""
@ -43,6 +50,7 @@ def test_single_listener(app, listener_name):
assert app.name + listener_name == output.pop() assert app.name + listener_name == output.pop()
@skipif_no_alarm
@pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS) @pytest.mark.parametrize('listener_name', AVAILABLE_LISTENERS)
def test_register_listener(app, listener_name): def test_register_listener(app, listener_name):
""" """
@ -52,12 +60,12 @@ def test_register_listener(app, listener_name):
output = [] output = []
# Register listener # Register listener
listener = create_listener(listener_name, output) listener = create_listener(listener_name, output)
app.register_listener(listener, app.register_listener(listener, event=listener_name)
event=listener_name)
start_stop_app(app) start_stop_app(app)
assert app.name + listener_name == output.pop() assert app.name + listener_name == output.pop()
@skipif_no_alarm
def test_all_listeners(app): def test_all_listeners(app):
output = [] output = []
for listener_name in AVAILABLE_LISTENERS: for listener_name in AVAILABLE_LISTENERS:

View File

@ -7,13 +7,17 @@ from unittest import mock
from sanic.worker import GunicornWorker from sanic.worker import GunicornWorker
from sanic.app import Sanic from sanic.app import Sanic
import asyncio import asyncio
import logging
import pytest import pytest
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def gunicorn_worker(): def gunicorn_worker():
command = 'gunicorn --bind 127.0.0.1:1337 --worker-class sanic.worker.GunicornWorker examples.simple_server:app' command = (
'gunicorn '
'--bind 127.0.0.1:1337 '
'--worker-class sanic.worker.GunicornWorker '
'examples.simple_server:app'
)
worker = subprocess.Popen(shlex.split(command)) worker = subprocess.Popen(shlex.split(command))
time.sleep(3) time.sleep(3)
yield yield
@ -96,11 +100,12 @@ def test_run_max_requests_exceeded(worker):
_runner = asyncio.ensure_future(worker._check_alive(), loop=loop) _runner = asyncio.ensure_future(worker._check_alive(), loop=loop)
loop.run_until_complete(_runner) loop.run_until_complete(_runner)
assert worker.alive == False assert not worker.alive
worker.notify.assert_called_with() worker.notify.assert_called_with()
worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s", worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s",
worker) worker)
def test_worker_close(worker): def test_worker_close(worker):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None)) asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None))
@ -113,8 +118,8 @@ def test_worker_close(worker):
conn = mock.Mock() conn = mock.Mock()
conn.websocket = mock.Mock() conn.websocket = mock.Mock()
conn.websocket.close_connection = mock.Mock( conn.websocket.close_connection = mock.Mock(
wraps=asyncio.coroutine(lambda *a, **kw: None) wraps=asyncio.coroutine(lambda *a, **kw: None)
) )
worker.connections = set([conn]) worker.connections = set([conn])
worker.log = mock.Mock() worker.log = mock.Mock()
worker.loop = loop worker.loop = loop
@ -130,6 +135,6 @@ def test_worker_close(worker):
_close = asyncio.ensure_future(worker.close(), loop=loop) _close = asyncio.ensure_future(worker.close(), loop=loop)
loop.run_until_complete(_close) loop.run_until_complete(_close)
assert worker.signal.stopped == True assert worker.signal.stopped
assert conn.websocket.close_connection.called == True assert conn.websocket.close_connection.called
assert len(worker.servers) == 0 assert len(worker.servers) == 0