Add header accessors (#2696)

This commit is contained in:
Adam Hopkins 2023-02-28 00:26:53 +02:00 committed by GitHub
parent dfc0704831
commit cb49c2b26d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 1 deletions

View File

@ -88,6 +88,12 @@ class Header(CIMultiDict):
very similar to a regular dictionary. very similar to a regular dictionary.
""" """
def __getattr__(self, key: str) -> str:
if key.startswith("_"):
return self.__getattribute__(key)
key = key.rstrip("_").replace("_", "-")
return ",".join(self.getall(key, default=[]))
def get_all(self, key: str): def get_all(self, key: str):
""" """
Convenience method mapped to ``getall()``. Convenience method mapped to ``getall()``.

View File

@ -2,12 +2,16 @@ from unittest.mock import Mock
import pytest import pytest
from sanic import headers, text from sanic import Sanic, headers, json, text
from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.exceptions import InvalidHeader, PayloadTooLarge
from sanic.http import Http from sanic.http import Http
from sanic.request import Request from sanic.request import Request
def make_request(headers) -> Request:
return Request(b"/", headers, "1.1", "GET", None, None)
@pytest.fixture @pytest.fixture
def raised_ceiling(): def raised_ceiling():
Http.HEADER_CEILING = 32_768 Http.HEADER_CEILING = 32_768
@ -435,3 +439,46 @@ def test_accept_misc():
a = headers.parse_accept("") a = headers.parse_accept("")
assert a == [] assert a == []
assert not a.match("foo/bar") assert not a.match("foo/bar")
@pytest.mark.parametrize(
"headers,expected",
(
({"foo": "bar"}, "bar"),
((("foo", "bar"), ("foo", "baz")), "bar,baz"),
({}, ""),
),
)
def test_field_simple_accessor(headers, expected):
request = make_request(headers)
assert request.headers.foo == request.headers.foo_ == expected
@pytest.mark.parametrize(
"headers,expected",
(
({"foo-bar": "bar"}, "bar"),
((("foo-bar", "bar"), ("foo-bar", "baz")), "bar,baz"),
),
)
def test_field_hyphenated_accessor(headers, expected):
request = make_request(headers)
assert request.headers.foo_bar == request.headers.foo_bar_ == expected
def test_bad_accessor():
request = make_request({})
msg = "'Header' object has no attribute '_foo'"
with pytest.raises(AttributeError, match=msg):
request.headers._foo
def test_multiple_fields_accessor(app: Sanic):
@app.get("")
async def handler(request: Request):
return json({"field": request.headers.example_field})
_, response = app.test_client.get(
"/", headers=(("Example-Field", "Foo, Bar"), ("Example-Field", "Baz"))
)
assert response.json["field"] == "Foo, Bar,Baz"