fix: GIT-1623: handle cookie duplication and serialization issue
Signed-off-by: Harsha Narayana <harsha2k4@gmail.com>
This commit is contained in:
parent
68d5039c5f
commit
97f288a534
@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from http.cookies import SimpleCookie
|
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
@ -288,11 +287,22 @@ class ASGIApp:
|
|||||||
"""
|
"""
|
||||||
Write the response.
|
Write the response.
|
||||||
"""
|
"""
|
||||||
|
headers = []
|
||||||
|
cookies = dict()
|
||||||
try:
|
try:
|
||||||
headers = [
|
cookies = {
|
||||||
|
v.key: v
|
||||||
|
for _, v in list(
|
||||||
|
filter(
|
||||||
|
lambda item: item[0].lower() == "set-cookie",
|
||||||
|
response.headers.items(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
headers += [
|
||||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||||
for name, value in response.headers.items()
|
for name, value in response.headers.items()
|
||||||
|
if name.lower() not in ["set-cookie"]
|
||||||
]
|
]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -319,12 +329,18 @@ class ASGIApp:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if response.cookies:
|
if response.cookies:
|
||||||
cookies = SimpleCookie()
|
cookies.update(
|
||||||
cookies.load(response.cookies)
|
{
|
||||||
headers += [
|
v.key: v
|
||||||
(b"set-cookie", cookie.encode("utf-8"))
|
for _, v in response.cookies.items()
|
||||||
for name, cookie in response.cookies.items()
|
if v.key not in cookies.keys()
|
||||||
]
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
headers += [
|
||||||
|
(b"set-cookie", cookie.encode("utf-8"))
|
||||||
|
for k, cookie in cookies.items()
|
||||||
|
]
|
||||||
|
|
||||||
await self.transport.send(
|
await self.transport.send(
|
||||||
{
|
{
|
||||||
|
@ -229,3 +229,30 @@ async def test_request_class_custom():
|
|||||||
|
|
||||||
_, response = await app.asgi_client.get("/custom")
|
_, response = await app.asgi_client.get("/custom")
|
||||||
assert response.body == b"MyCustomRequest"
|
assert response.body == b"MyCustomRequest"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cookie_customization(app):
|
||||||
|
@app.get("/cookie")
|
||||||
|
def get_cookie(request):
|
||||||
|
response = text("There's a cookie up in this response")
|
||||||
|
response.cookies["test"] = "Cookie1"
|
||||||
|
response.cookies["test"]["httponly"] = True
|
||||||
|
|
||||||
|
response.cookies["c2"] = "Cookie2"
|
||||||
|
response.cookies["c2"]["httponly"] = False
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
_, response = await app.asgi_client.get("/cookie")
|
||||||
|
cookie_map = {
|
||||||
|
"test": {"value": "Cookie1", "HttpOnly": True},
|
||||||
|
"c2": {"value": "Cookie2", "HttpOnly": False},
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v in (
|
||||||
|
response.cookies._cookies.get("mockserver.local").get("/").items()
|
||||||
|
):
|
||||||
|
assert cookie_map.get(k).get("value") == v.value
|
||||||
|
if cookie_map.get(k).get("HttpOnly"):
|
||||||
|
assert "HttpOnly" in v._rest.keys()
|
||||||
|
@ -635,13 +635,15 @@ def test_forwarded_scheme(app):
|
|||||||
return text(request.remote_addr)
|
return text(request.remote_addr)
|
||||||
|
|
||||||
request, response = app.test_client.get("/")
|
request, response = app.test_client.get("/")
|
||||||
assert request.scheme == 'http'
|
assert request.scheme == "http"
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={'X-Forwarded-Proto': 'https'})
|
request, response = app.test_client.get(
|
||||||
assert request.scheme == 'https'
|
"/", headers={"X-Forwarded-Proto": "https"}
|
||||||
|
)
|
||||||
|
assert request.scheme == "https"
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={'X-Scheme': 'https'})
|
request, response = app.test_client.get("/", headers={"X-Scheme": "https"})
|
||||||
assert request.scheme == 'https'
|
assert request.scheme == "https"
|
||||||
|
|
||||||
|
|
||||||
def test_match_info(app):
|
def test_match_info(app):
|
||||||
@ -1677,7 +1679,7 @@ def test_request_server_name(app):
|
|||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/")
|
request, response = app.test_client.get("/")
|
||||||
assert request.server_name == '127.0.0.1'
|
assert request.server_name == "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
def test_request_server_name_in_host_header(app):
|
def test_request_server_name_in_host_header(app):
|
||||||
@ -1685,8 +1687,10 @@ def test_request_server_name_in_host_header(app):
|
|||||||
def handler(request):
|
def handler(request):
|
||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={'Host': 'my_server:5555'})
|
request, response = app.test_client.get(
|
||||||
assert request.server_name == 'my_server'
|
"/", headers={"Host": "my_server:5555"}
|
||||||
|
)
|
||||||
|
assert request.server_name == "my_server"
|
||||||
|
|
||||||
|
|
||||||
def test_request_server_name_forwarded(app):
|
def test_request_server_name_forwarded(app):
|
||||||
@ -1694,11 +1698,11 @@ def test_request_server_name_forwarded(app):
|
|||||||
def handler(request):
|
def handler(request):
|
||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={
|
request, response = app.test_client.get(
|
||||||
'Host': 'my_server:5555',
|
"/",
|
||||||
'X-Forwarded-Host': 'your_server'
|
headers={"Host": "my_server:5555", "X-Forwarded-Host": "your_server"},
|
||||||
})
|
)
|
||||||
assert request.server_name == 'your_server'
|
assert request.server_name == "your_server"
|
||||||
|
|
||||||
|
|
||||||
def test_request_server_port(app):
|
def test_request_server_port(app):
|
||||||
@ -1706,9 +1710,7 @@ def test_request_server_port(app):
|
|||||||
def handler(request):
|
def handler(request):
|
||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={
|
request, response = app.test_client.get("/", headers={"Host": "my_server"})
|
||||||
'Host': 'my_server'
|
|
||||||
})
|
|
||||||
assert request.server_port == app.test_client.port
|
assert request.server_port == app.test_client.port
|
||||||
|
|
||||||
|
|
||||||
@ -1717,9 +1719,9 @@ def test_request_server_port_in_host_header(app):
|
|||||||
def handler(request):
|
def handler(request):
|
||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={
|
request, response = app.test_client.get(
|
||||||
'Host': 'my_server:5555'
|
"/", headers={"Host": "my_server:5555"}
|
||||||
})
|
)
|
||||||
assert request.server_port == 5555
|
assert request.server_port == 5555
|
||||||
|
|
||||||
|
|
||||||
@ -1728,10 +1730,9 @@ def test_request_server_port_forwarded(app):
|
|||||||
def handler(request):
|
def handler(request):
|
||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={
|
request, response = app.test_client.get(
|
||||||
'Host': 'my_server:5555',
|
"/", headers={"Host": "my_server:5555", "X-Forwarded-Port": "4444"}
|
||||||
'X-Forwarded-Port': '4444'
|
)
|
||||||
})
|
|
||||||
assert request.server_port == 4444
|
assert request.server_port == 4444
|
||||||
|
|
||||||
|
|
||||||
@ -1754,29 +1755,34 @@ def test_url_for_with_forwarded_request(app):
|
|||||||
def view_name(request):
|
def view_name(request):
|
||||||
return text("OK")
|
return text("OK")
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={
|
request, response = app.test_client.get(
|
||||||
'X-Forwarded-Proto': 'https',
|
"/", headers={"X-Forwarded-Proto": "https"}
|
||||||
})
|
)
|
||||||
assert app.url_for('view_name') == '/another_view'
|
assert app.url_for("view_name") == "/another_view"
|
||||||
assert app.url_for('view_name', _external=True) == 'http:///another_view'
|
assert app.url_for("view_name", _external=True) == "http:///another_view"
|
||||||
assert request.url_for('view_name') == 'https://127.0.0.1:{}/another_view'.format(app.test_client.port)
|
assert request.url_for(
|
||||||
|
"view_name"
|
||||||
|
) == "https://127.0.0.1:{}/another_view".format(app.test_client.port)
|
||||||
|
|
||||||
app.config.SERVER_NAME = "my_server"
|
app.config.SERVER_NAME = "my_server"
|
||||||
request, response = app.test_client.get("/", headers={
|
request, response = app.test_client.get(
|
||||||
'X-Forwarded-Proto': 'https',
|
"/", headers={"X-Forwarded-Proto": "https", "X-Forwarded-Port": "6789"}
|
||||||
'X-Forwarded-Port': '6789',
|
)
|
||||||
})
|
assert app.url_for("view_name") == "/another_view"
|
||||||
assert app.url_for('view_name') == '/another_view'
|
assert (
|
||||||
assert app.url_for('view_name', _external=True) == 'http://my_server/another_view'
|
app.url_for("view_name", _external=True)
|
||||||
assert request.url_for('view_name') == 'https://my_server:6789/another_view'
|
== "http://my_server/another_view"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
request.url_for("view_name") == "https://my_server:6789/another_view"
|
||||||
|
)
|
||||||
|
|
||||||
|
request, response = app.test_client.get(
|
||||||
|
"/", headers={"X-Forwarded-Proto": "https", "X-Forwarded-Port": "443"}
|
||||||
|
)
|
||||||
|
assert request.url_for("view_name") == "https://my_server/another_view"
|
||||||
|
|
||||||
request, response = app.test_client.get("/", headers={
|
|
||||||
'X-Forwarded-Proto': 'https',
|
|
||||||
'X-Forwarded-Port': '443',
|
|
||||||
})
|
|
||||||
assert request.url_for('view_name') == 'https://my_server/another_view'
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_request_form_invalid_content_type_asgi(app):
|
async def test_request_form_invalid_content_type_asgi(app):
|
||||||
@app.route("/", methods=["POST"])
|
@app.route("/", methods=["POST"])
|
||||||
@ -1787,7 +1793,7 @@ async def test_request_form_invalid_content_type_asgi(app):
|
|||||||
|
|
||||||
assert request.form == {}
|
assert request.form == {}
|
||||||
|
|
||||||
|
|
||||||
def test_endpoint_basic():
|
def test_endpoint_basic():
|
||||||
app = Sanic()
|
app = Sanic()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user