fix: GIT-1623: handle cookie duplication and serialization issue

Signed-off-by: Harsha Narayana <harsha2k4@gmail.com>
This commit is contained in:
Harsha Narayana 2019-07-08 13:03:33 +05:30
parent 68d5039c5f
commit 97f288a534
No known key found for this signature in database
GPG Key ID: 8AF211CB60D4B28C
3 changed files with 101 additions and 52 deletions

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import warnings import warnings
from http.cookies import SimpleCookie
from inspect import isawaitable from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Union from typing import Any, Awaitable, Callable, MutableMapping, Union
from urllib.parse import quote from urllib.parse import quote
@ -288,11 +287,22 @@ class ASGIApp:
""" """
Write the response. Write the response.
""" """
headers = []
cookies = dict()
try: try:
headers = [ cookies = {
v.key: v
for _, v in list(
filter(
lambda item: item[0].lower() == "set-cookie",
response.headers.items(),
)
)
}
headers += [
(str(name).encode("latin-1"), str(value).encode("latin-1")) (str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items() for name, value in response.headers.items()
if name.lower() not in ["set-cookie"]
] ]
except AttributeError: except AttributeError:
logger.error( logger.error(
@ -319,12 +329,18 @@ class ASGIApp:
] ]
if response.cookies: if response.cookies:
cookies = SimpleCookie() cookies.update(
cookies.load(response.cookies) {
headers += [ v.key: v
(b"set-cookie", cookie.encode("utf-8")) for _, v in response.cookies.items()
for name, cookie in response.cookies.items() if v.key not in cookies.keys()
] }
)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for k, cookie in cookies.items()
]
await self.transport.send( await self.transport.send(
{ {

View File

@ -229,3 +229,30 @@ async def test_request_class_custom():
_, response = await app.asgi_client.get("/custom") _, response = await app.asgi_client.get("/custom")
assert response.body == b"MyCustomRequest" assert response.body == b"MyCustomRequest"
@pytest.mark.asyncio
async def test_cookie_customization(app):
@app.get("/cookie")
def get_cookie(request):
response = text("There's a cookie up in this response")
response.cookies["test"] = "Cookie1"
response.cookies["test"]["httponly"] = True
response.cookies["c2"] = "Cookie2"
response.cookies["c2"]["httponly"] = False
return response
_, response = await app.asgi_client.get("/cookie")
cookie_map = {
"test": {"value": "Cookie1", "HttpOnly": True},
"c2": {"value": "Cookie2", "HttpOnly": False},
}
for k, v in (
response.cookies._cookies.get("mockserver.local").get("/").items()
):
assert cookie_map.get(k).get("value") == v.value
if cookie_map.get(k).get("HttpOnly"):
assert "HttpOnly" in v._rest.keys()

View File

@ -635,13 +635,15 @@ def test_forwarded_scheme(app):
return text(request.remote_addr) return text(request.remote_addr)
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert request.scheme == 'http' assert request.scheme == "http"
request, response = app.test_client.get("/", headers={'X-Forwarded-Proto': 'https'}) request, response = app.test_client.get(
assert request.scheme == 'https' "/", headers={"X-Forwarded-Proto": "https"}
)
assert request.scheme == "https"
request, response = app.test_client.get("/", headers={'X-Scheme': 'https'}) request, response = app.test_client.get("/", headers={"X-Scheme": "https"})
assert request.scheme == 'https' assert request.scheme == "https"
def test_match_info(app): def test_match_info(app):
@ -1677,7 +1679,7 @@ def test_request_server_name(app):
return text("OK") return text("OK")
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
assert request.server_name == '127.0.0.1' assert request.server_name == "127.0.0.1"
def test_request_server_name_in_host_header(app): def test_request_server_name_in_host_header(app):
@ -1685,8 +1687,10 @@ def test_request_server_name_in_host_header(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={'Host': 'my_server:5555'}) request, response = app.test_client.get(
assert request.server_name == 'my_server' "/", headers={"Host": "my_server:5555"}
)
assert request.server_name == "my_server"
def test_request_server_name_forwarded(app): def test_request_server_name_forwarded(app):
@ -1694,11 +1698,11 @@ def test_request_server_name_forwarded(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={ request, response = app.test_client.get(
'Host': 'my_server:5555', "/",
'X-Forwarded-Host': 'your_server' headers={"Host": "my_server:5555", "X-Forwarded-Host": "your_server"},
}) )
assert request.server_name == 'your_server' assert request.server_name == "your_server"
def test_request_server_port(app): def test_request_server_port(app):
@ -1706,9 +1710,7 @@ def test_request_server_port(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={ request, response = app.test_client.get("/", headers={"Host": "my_server"})
'Host': 'my_server'
})
assert request.server_port == app.test_client.port assert request.server_port == app.test_client.port
@ -1717,9 +1719,9 @@ def test_request_server_port_in_host_header(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={ request, response = app.test_client.get(
'Host': 'my_server:5555' "/", headers={"Host": "my_server:5555"}
}) )
assert request.server_port == 5555 assert request.server_port == 5555
@ -1728,10 +1730,9 @@ def test_request_server_port_forwarded(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={ request, response = app.test_client.get(
'Host': 'my_server:5555', "/", headers={"Host": "my_server:5555", "X-Forwarded-Port": "4444"}
'X-Forwarded-Port': '4444' )
})
assert request.server_port == 4444 assert request.server_port == 4444
@ -1754,29 +1755,34 @@ def test_url_for_with_forwarded_request(app):
def view_name(request): def view_name(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={ request, response = app.test_client.get(
'X-Forwarded-Proto': 'https', "/", headers={"X-Forwarded-Proto": "https"}
}) )
assert app.url_for('view_name') == '/another_view' assert app.url_for("view_name") == "/another_view"
assert app.url_for('view_name', _external=True) == 'http:///another_view' assert app.url_for("view_name", _external=True) == "http:///another_view"
assert request.url_for('view_name') == 'https://127.0.0.1:{}/another_view'.format(app.test_client.port) assert request.url_for(
"view_name"
) == "https://127.0.0.1:{}/another_view".format(app.test_client.port)
app.config.SERVER_NAME = "my_server" app.config.SERVER_NAME = "my_server"
request, response = app.test_client.get("/", headers={ request, response = app.test_client.get(
'X-Forwarded-Proto': 'https', "/", headers={"X-Forwarded-Proto": "https", "X-Forwarded-Port": "6789"}
'X-Forwarded-Port': '6789', )
}) assert app.url_for("view_name") == "/another_view"
assert app.url_for('view_name') == '/another_view' assert (
assert app.url_for('view_name', _external=True) == 'http://my_server/another_view' app.url_for("view_name", _external=True)
assert request.url_for('view_name') == 'https://my_server:6789/another_view' == "http://my_server/another_view"
)
assert (
request.url_for("view_name") == "https://my_server:6789/another_view"
)
request, response = app.test_client.get(
"/", headers={"X-Forwarded-Proto": "https", "X-Forwarded-Port": "443"}
)
assert request.url_for("view_name") == "https://my_server/another_view"
request, response = app.test_client.get("/", headers={
'X-Forwarded-Proto': 'https',
'X-Forwarded-Port': '443',
})
assert request.url_for('view_name') == 'https://my_server/another_view'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_form_invalid_content_type_asgi(app): async def test_request_form_invalid_content_type_asgi(app):
@app.route("/", methods=["POST"]) @app.route("/", methods=["POST"])
@ -1787,7 +1793,7 @@ async def test_request_form_invalid_content_type_asgi(app):
assert request.form == {} assert request.form == {}
def test_endpoint_basic(): def test_endpoint_basic():
app = Sanic() app = Sanic()