Decode headers as UTF-8 also in ASGI (#2606)

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Zhiwei 2023-03-20 09:39:57 -04:00 committed by GitHub
parent 71cd53b64e
commit 61aa16f6ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 7 deletions

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from urllib.parse import quote
from sanic.compat import Header
from sanic.exceptions import ServerError
from sanic.exceptions import BadRequest, ServerError
from sanic.helpers import Default
from sanic.http import Stage
from sanic.log import error_logger, logger
@ -132,12 +132,20 @@ class ASGIApp:
instance.sanic_app.state.is_started = True
setattr(instance.transport, "add_task", sanic_app.loop.create_task)
headers = Header(
[
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in scope.get("headers", [])
]
)
try:
headers = Header(
[
(
key.decode("ASCII"),
value.decode(errors="surrogateescape"),
)
for key, value in scope.get("headers", [])
]
)
except UnicodeDecodeError:
raise BadRequest(
"Header names can only contain US-ASCII characters"
)
path = (
scope["path"][1:]
if scope["path"].startswith("/")

View File

@ -7,6 +7,9 @@ from unittest.mock import call
import pytest
import uvicorn
from httpx import Headers
from pytest import MonkeyPatch
from sanic import Sanic
from sanic.application.state import Mode
from sanic.asgi import ASGIApp, Lifespan, MockTransport
@ -626,3 +629,26 @@ async def test_error_on_lifespan_exception_stop(app: Sanic):
)
]
)
@pytest.mark.asyncio
async def test_asgi_headers_decoding(app: Sanic, monkeypatch: MonkeyPatch):
@app.get("/")
def handler(request: Request):
return text("")
headers_init = Headers.__init__
def mocked_headers_init(self, *args, **kwargs):
if "encoding" in kwargs:
kwargs.pop("encoding")
headers_init(self, encoding="utf-8", *args, **kwargs)
monkeypatch.setattr(Headers, "__init__", mocked_headers_init)
message = "Header names can only contain US-ASCII characters"
with pytest.raises(BadRequest, match=message):
_, response = await app.asgi_client.get("/", headers={"😂": "😅"})
_, response = await app.asgi_client.get("/", headers={"Test-Header": "😅"})
assert response.status_code == 200