diff --git a/sanic/compat.py b/sanic/compat.py index 35769876..693b1b78 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -88,6 +88,12 @@ class Header(CIMultiDict): very similar to a regular dictionary. """ + def __getattr__(self, key: str) -> str: + if key.startswith("_"): + return self.__getattribute__(key) + key = key.rstrip("_").replace("_", "-") + return ",".join(self.getall(key, default=[])) + def get_all(self, key: str): """ Convenience method mapped to ``getall()``. diff --git a/tests/test_headers.py b/tests/test_headers.py index 7f3a56dc..b6cd5210 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -2,12 +2,16 @@ from unittest.mock import Mock import pytest -from sanic import headers, text +from sanic import Sanic, headers, json, text from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.http import Http from sanic.request import Request +def make_request(headers) -> Request: + return Request(b"/", headers, "1.1", "GET", None, None) + + @pytest.fixture def raised_ceiling(): Http.HEADER_CEILING = 32_768 @@ -435,3 +439,46 @@ def test_accept_misc(): a = headers.parse_accept("") assert a == [] assert not a.match("foo/bar") + + +@pytest.mark.parametrize( + "headers,expected", + ( + ({"foo": "bar"}, "bar"), + ((("foo", "bar"), ("foo", "baz")), "bar,baz"), + ({}, ""), + ), +) +def test_field_simple_accessor(headers, expected): + request = make_request(headers) + assert request.headers.foo == request.headers.foo_ == expected + + +@pytest.mark.parametrize( + "headers,expected", + ( + ({"foo-bar": "bar"}, "bar"), + ((("foo-bar", "bar"), ("foo-bar", "baz")), "bar,baz"), + ), +) +def test_field_hyphenated_accessor(headers, expected): + request = make_request(headers) + assert request.headers.foo_bar == request.headers.foo_bar_ == expected + + +def test_bad_accessor(): + request = make_request({}) + msg = "'Header' object has no attribute '_foo'" + with pytest.raises(AttributeError, match=msg): + request.headers._foo + + +def test_multiple_fields_accessor(app: Sanic): + @app.get("") + async def handler(request: Request): + return json({"field": request.headers.example_field}) + + _, response = app.test_client.get( + "/", headers=(("Example-Field", "Foo, Bar"), ("Example-Field", "Baz")) + ) + assert response.json["field"] == "Foo, Bar,Baz"