Enforce Datetime Type for Expires on Set-Cookie (#1484)

* Enforce Datetime Type for Expires on Set-Cookie

* Fix lint issues

* Format code and improve error type

* Fix import order
This commit is contained in:
Leonardo Teixeira Menezes 2019-02-06 16:29:33 -02:00 committed by Stephen Sadowski
parent 4f70dba935
commit 08794ae1cf
9 changed files with 70 additions and 69 deletions

View File

@ -1,6 +1,8 @@
import re
import string
from datetime import datetime
DEFAULT_MAX_AGE = 0
@ -108,6 +110,11 @@ class Cookie(dict):
if key.lower() == "max-age":
if not str(value).isdigit():
value = DEFAULT_MAX_AGE
elif key.lower() == "expires":
if not isinstance(value, datetime):
raise TypeError(
"Cookie 'expires' property must be a datetime"
)
return super().__setitem__(key, value)
def encode(self, encoding):
@ -131,16 +138,10 @@ class Cookie(dict):
except TypeError:
output.append("%s=%s" % (self._keys[key], value))
elif key == "expires":
try:
output.append(
"%s=%s"
% (
self._keys[key],
value.strftime("%a, %d-%b-%Y %T GMT"),
)
)
except AttributeError:
output.append("%s=%s" % (self._keys[key], value))
output.append(
"%s=%s"
% (self._keys[key], value.strftime("%a, %d-%b-%Y %T GMT"))
)
elif key in self._flags and self[key]:
output.append(self._keys[key])
else:

View File

@ -12,6 +12,7 @@ from sanic.response import text
def uvloop_installed():
try:
import uvloop
return True
except ImportError:
return False
@ -27,28 +28,28 @@ def test_app_loop_running(app):
assert response.text == "pass"
@pytest.mark.skipif(sys.version_info < (3, 7),
reason="requires python3.7 or higher")
@pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_create_asyncio_server(app):
if not uvloop_installed():
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
return_asyncio_server=True)
asyncio_srv_coro = app.create_server(return_asyncio_server=True)
assert isawaitable(asyncio_srv_coro)
srv = loop.run_until_complete(asyncio_srv_coro)
assert srv.is_serving() is True
@pytest.mark.skipif(sys.version_info < (3, 7),
reason="requires python3.7 or higher")
@pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_start_serving(app):
if not uvloop_installed():
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
return_asyncio_server=True,
asyncio_server_kwargs=dict(
start_serving=False
))
asyncio_server_kwargs=dict(start_serving=False),
)
srv = loop.run_until_complete(asyncio_srv_coro)
assert srv.is_serving() is False

View File

@ -166,7 +166,7 @@ def test_config_custom_defaults():
custom_defaults = {
"REQUEST_MAX_SIZE": 1,
"KEEP_ALIVE": False,
"ACCESS_LOG": False
"ACCESS_LOG": False,
}
conf = Config(defaults=custom_defaults)
for key, value in DEFAULT_CONFIG.items():
@ -182,13 +182,13 @@ def test_config_custom_defaults_with_env():
custom_defaults = {
"REQUEST_MAX_SIZE123": 1,
"KEEP_ALIVE123": False,
"ACCESS_LOG123": False
"ACCESS_LOG123": False,
}
environ_defaults = {
"SANIC_REQUEST_MAX_SIZE123": "2",
"SANIC_KEEP_ALIVE123": "True",
"SANIC_ACCESS_LOG123": "False"
"SANIC_ACCESS_LOG123": "False",
}
for key, value in environ_defaults.items():
@ -201,8 +201,8 @@ def test_config_custom_defaults_with_env():
try:
value = int(value)
except ValueError:
if value in ['True', 'False']:
value = value == 'True'
if value in ["True", "False"]:
value = value == "True"
assert getattr(conf, key) == value
@ -213,7 +213,7 @@ def test_config_custom_defaults_with_env():
def test_config_access_log_passing_in_run(app):
assert app.config.ACCESS_LOG == True
@app.listener('after_server_start')
@app.listener("after_server_start")
async def _request(sanic, loop):
app.stop()
@ -227,16 +227,18 @@ def test_config_access_log_passing_in_run(app):
async def test_config_access_log_passing_in_create_server(app):
assert app.config.ACCESS_LOG == True
@app.listener('after_server_start')
@app.listener("after_server_start")
async def _request(sanic, loop):
app.stop()
await app.create_server(port=1341, access_log=False,
return_asyncio_server=True)
await app.create_server(
port=1341, access_log=False, return_asyncio_server=True
)
assert app.config.ACCESS_LOG == False
await app.create_server(port=1342, access_log=True,
return_asyncio_server=True)
await app.create_server(
port=1342, access_log=True, return_asyncio_server=True
)
assert app.config.ACCESS_LOG == True

View File

@ -160,10 +160,7 @@ def test_cookie_max_age(app, max_age):
assert response.cookies["test"]["max-age"] == str(DEFAULT_MAX_AGE)
@pytest.mark.parametrize(
"expires",
[datetime.now() + timedelta(seconds=60), "Fri, 21-Dec-2018 15:30:00 GMT"],
)
@pytest.mark.parametrize("expires", [datetime.now() + timedelta(seconds=60)])
def test_cookie_expires(app, expires):
cookies = {"test": "wait"}
@ -183,3 +180,11 @@ def test_cookie_expires(app, expires):
expires = expires.strftime("%a, %d-%b-%Y %T GMT")
assert response.cookies["test"]["expires"] == expires
@pytest.mark.parametrize("expires", ["Fri, 21-Dec-2018 15:30:00 GMT"])
def test_cookie_expires_illegal_instance_type(expires):
c = Cookie("test_cookie", "value")
with pytest.raises(expected_exception=TypeError) as e:
c["expires"] = expires
assert e.message == "Cookie 'expires' property must be a datetime"

View File

@ -9,10 +9,8 @@ from aiohttp import TCPConnector
from sanic.testing import SanicTestClient, HOST, PORT
CONFIG_FOR_TESTS = {
"KEEP_ALIVE_TIMEOUT": 2,
"KEEP_ALIVE": True
}
CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
class ReuseableTCPConnector(TCPConnector):
def __init__(self, *args, **kwargs):
@ -51,9 +49,7 @@ class ReuseableSanicTestClient(SanicTestClient):
uri="/",
gather_request=True,
debug=False,
server_kwargs={
"return_asyncio_server": True,
},
server_kwargs={"return_asyncio_server": True},
*request_args,
**request_kwargs
):
@ -147,7 +143,7 @@ class ReuseableSanicTestClient(SanicTestClient):
# loop, so the changes above are required too.
async def _local_request(self, method, uri, cookies=None, *args, **kwargs):
request_keepalive = kwargs.pop(
"request_keepalive", CONFIG_FOR_TESTS['KEEP_ALIVE_TIMEOUT']
"request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"]
)
if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")):
url = uri

View File

@ -12,8 +12,7 @@ except BaseException:
def test_logo_base(app, caplog):
server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
@ -32,8 +31,7 @@ def test_logo_base(app, caplog):
def test_logo_false(app, caplog):
app.config.LOGO = False
server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
@ -52,8 +50,7 @@ def test_logo_false(app, caplog):
def test_logo_true(app, caplog):
app.config.LOGO = True
server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
@ -72,8 +69,7 @@ def test_logo_true(app, caplog):
def test_logo_custom(app, caplog):
app.config.LOGO = "My Custom Logo"
server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False

View File

@ -151,9 +151,7 @@ class DelayableSanicTestClient(SanicTestClient):
host=HOST, port=self.port, uri=uri
)
conn = DelayableTCPConnector(
pre_request_delay=self._request_delay,
ssl=False,
loop=self._loop,
pre_request_delay=self._request_delay, ssl=False, loop=self._loop
)
async with aiohttp.ClientSession(
cookies=cookies, connector=conn, loop=self._loop

View File

@ -83,8 +83,7 @@ async def test_trigger_before_events_create_server(app):
async def init_db(app, loop):
app.db = MySanicDb()
await app.create_server(
debug=True, return_asyncio_server=True)
await app.create_server(debug=True, return_asyncio_server=True)
assert hasattr(app, "db")
assert isinstance(app.db, MySanicDb)

View File

@ -24,28 +24,28 @@ def gunicorn_worker():
worker.kill()
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def gunicorn_worker_with_access_logs():
command = (
'gunicorn '
'--bind 127.0.0.1:1338 '
'--worker-class sanic.worker.GunicornWorker '
'examples.simple_server:app'
"gunicorn "
"--bind 127.0.0.1:1338 "
"--worker-class sanic.worker.GunicornWorker "
"examples.simple_server:app"
)
worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE)
time.sleep(2)
return worker
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def gunicorn_worker_with_env_var():
command = (
'env SANIC_ACCESS_LOG="False" '
'gunicorn '
'--bind 127.0.0.1:1339 '
'--worker-class sanic.worker.GunicornWorker '
'--log-level info '
'examples.simple_server:app'
"gunicorn "
"--bind 127.0.0.1:1339 "
"--worker-class sanic.worker.GunicornWorker "
"--log-level info "
"examples.simple_server:app"
)
worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE)
time.sleep(2)
@ -62,7 +62,7 @@ def test_gunicorn_worker_no_logs(gunicorn_worker_with_env_var):
"""
if SANIC_ACCESS_LOG was set to False do not show access logs
"""
with urllib.request.urlopen('http://localhost:1339/') as _:
with urllib.request.urlopen("http://localhost:1339/") as _:
gunicorn_worker_with_env_var.kill()
assert not gunicorn_worker_with_env_var.stdout.read()
@ -71,9 +71,12 @@ def test_gunicorn_worker_with_logs(gunicorn_worker_with_access_logs):
"""
default - show access logs
"""
with urllib.request.urlopen('http://localhost:1338/') as _:
with urllib.request.urlopen("http://localhost:1338/") as _:
gunicorn_worker_with_access_logs.kill()
assert b"(sanic.access)[INFO][127.0.0.1" in gunicorn_worker_with_access_logs.stdout.read()
assert (
b"(sanic.access)[INFO][127.0.0.1"
in gunicorn_worker_with_access_logs.stdout.read()
)
class GunicornTestWorker(GunicornWorker):