Accept header parsing (#2200)

* Add some tests

* docstring

* Add accept matching

* Add some more tests on matching

* Add matching flags for wildcards

* Add mathing controls to accept

* Limit uvicorn 14 in testing
This commit is contained in:
Adam Hopkins 2021-08-19 21:09:40 +03:00 committed by GitHub
parent e2eefaac55
commit f32ef20b74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 387 additions and 4 deletions

View File

@ -126,8 +126,11 @@ class HeaderNotFound(InvalidUsage):
**Status**: 400 Bad Request **Status**: 400 Bad Request
""" """
status_code = 400
quiet = True class InvalidHeader(InvalidUsage):
"""
**Status**: 400 Bad Request
"""
class ContentRangeError(SanicException): class ContentRangeError(SanicException):

View File

@ -1,8 +1,11 @@
from __future__ import annotations
import re import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import unquote from urllib.parse import unquote
from sanic.exceptions import InvalidHeader
from sanic.helpers import STATUS_CODES from sanic.helpers import STATUS_CODES
@ -30,6 +33,154 @@ _host_re = re.compile(
# For more information, consult ../tests/test_requests.py # For more information, consult ../tests/test_requests.py
def parse_arg_as_accept(f):
def func(self, other, *args, **kwargs):
if not isinstance(other, Accept):
other = Accept.parse(other)
return f(self, other, *args, **kwargs)
return func
class MediaType(str):
def __new__(cls, value: str):
return str.__new__(cls, value)
def __init__(self, value: str) -> None:
self.value = value
self.is_wildcard = self.check_if_wildcard(value)
def __eq__(self, other):
if self.is_wildcard:
return True
if self.match(other):
return True
other_is_wildcard = (
other.is_wildcard
if isinstance(other, MediaType)
else self.check_if_wildcard(other)
)
return other_is_wildcard
def match(self, other):
other_value = other.value if isinstance(other, MediaType) else other
return self.value == other_value
@staticmethod
def check_if_wildcard(value):
return value == "*"
class Accept(str):
def __new__(cls, value: str, *args, **kwargs):
return str.__new__(cls, value)
def __init__(
self,
value: str,
type_: MediaType,
subtype: MediaType,
*,
q: str = "1.0",
**kwargs: str,
):
qvalue = float(q)
if qvalue > 1 or qvalue < 0:
raise InvalidHeader(
f"Accept header qvalue must be between 0 and 1, not: {qvalue}"
)
self.value = value
self.type_ = type_
self.subtype = subtype
self.qvalue = qvalue
self.params = kwargs
def _compare(self, other, method):
try:
return method(self.qvalue, other.qvalue)
except (AttributeError, TypeError):
return NotImplemented
@parse_arg_as_accept
def __lt__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s < o)
@parse_arg_as_accept
def __le__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s <= o)
@parse_arg_as_accept
def __eq__(self, other: Union[str, Accept]): # type: ignore
return self._compare(other, lambda s, o: s == o)
@parse_arg_as_accept
def __ge__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s >= o)
@parse_arg_as_accept
def __gt__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s > o)
@parse_arg_as_accept
def __ne__(self, other: Union[str, Accept]): # type: ignore
return self._compare(other, lambda s, o: s != o)
@parse_arg_as_accept
def match(
self,
other,
*,
allow_type_wildcard: bool = True,
allow_subtype_wildcard: bool = True,
) -> bool:
type_match = (
self.type_ == other.type_
if allow_type_wildcard
else (
self.type_.match(other.type_)
and not self.type_.is_wildcard
and not other.type_.is_wildcard
)
)
subtype_match = (
self.subtype == other.subtype
if allow_subtype_wildcard
else (
self.subtype.match(other.subtype)
and not self.subtype.is_wildcard
and not other.subtype.is_wildcard
)
)
return type_match and subtype_match
@classmethod
def parse(cls, raw: str) -> Accept:
invalid = False
mtype = raw.strip()
try:
media, *raw_params = mtype.split(";")
type_, subtype = media.split("/")
except ValueError:
invalid = True
if invalid or not type_ or not subtype:
raise InvalidHeader(f"Header contains invalid Accept value: {raw}")
params = dict(
[
(key.strip(), value.strip())
for key, value in (param.split("=", 1) for param in raw_params)
]
)
return cls(mtype, MediaType(type_), MediaType(subtype), **params)
def parse_content_header(value: str) -> Tuple[str, Options]: def parse_content_header(value: str) -> Tuple[str, Options]:
"""Parse content-type and content-disposition header values. """Parse content-type and content-disposition header values.
@ -194,3 +345,29 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
ret += b"%b: %b\r\n" % h ret += b"%b: %b\r\n" % h
ret += b"\r\n" ret += b"\r\n"
return ret return ret
def _sort_accept_value(accept: Accept):
return (
accept.qvalue,
len(accept.params),
accept.subtype != "*",
accept.type_ != "*",
)
def parse_accept(accept: str) -> List[Accept]:
"""Parse an Accept header and order the acceptable media types in
accorsing to RFC 7231, s. 5.3.2
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
"""
media_types = accept.split(",")
accept_list: List[Accept] = []
for mtype in media_types:
if not mtype:
continue
accept_list.append(Accept.parse(mtype))
return sorted(accept_list, key=_sort_accept_value, reverse=True)

View File

@ -34,7 +34,9 @@ from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
from sanic.headers import ( from sanic.headers import (
Accept,
Options, Options,
parse_accept,
parse_content_header, parse_content_header,
parse_forwarded, parse_forwarded,
parse_host, parse_host,
@ -94,6 +96,7 @@ class Request:
"head", "head",
"headers", "headers",
"method", "method",
"parsed_accept",
"parsed_args", "parsed_args",
"parsed_not_grouped_args", "parsed_not_grouped_args",
"parsed_files", "parsed_files",
@ -136,6 +139,7 @@ class Request:
self.conn_info: Optional[ConnInfo] = None self.conn_info: Optional[ConnInfo] = None
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.parsed_forwarded: Optional[Options] = None self.parsed_forwarded: Optional[Options] = None
self.parsed_accept: Optional[List[Accept]] = None
self.parsed_json = None self.parsed_json = None
self.parsed_form = None self.parsed_form = None
self.parsed_files = None self.parsed_files = None
@ -296,6 +300,13 @@ class Request:
return self.parsed_json return self.parsed_json
@property
def accept(self) -> List[Accept]:
if self.parsed_accept is None:
accept_header = self.headers.getone("accept", "")
self.parsed_accept = parse_accept(accept_header)
return self.parsed_accept
@property @property
def token(self): def token(self):
"""Attempt to return the auth header token. """Attempt to return the auth header token.

View File

@ -110,7 +110,7 @@ tests_require = [
"mypy>=0.901", "mypy>=0.901",
"docutils", "docutils",
"pygments", "pygments",
"uvicorn", "uvicorn<0.15.0",
types_ujson, types_ujson,
] ]

View File

@ -3,7 +3,7 @@ from unittest.mock import Mock
import pytest import pytest
from sanic import headers, text from sanic import headers, text
from sanic.exceptions import PayloadTooLarge from sanic.exceptions import InvalidHeader, PayloadTooLarge
from sanic.http import Http from sanic.http import Http
@ -182,3 +182,159 @@ def test_request_line(app):
) )
assert request.request_line == b"GET / HTTP/1.1" assert request.request_line == b"GET / HTTP/1.1"
@pytest.mark.parametrize(
"raw",
(
"show/first, show/second",
"show/*, show/first",
"*/*, show/first",
"*/*, show/*",
"other/*; q=0.1, show/*; q=0.2",
"show/first; q=0.5, show/second; q=0.5",
"show/first; foo=bar, show/second; foo=bar",
"show/second, show/first; foo=bar",
"show/second; q=0.5, show/first; foo=bar; q=0.5",
"show/second; q=0.5, show/first; q=1.0",
"show/first, show/second; q=1.0",
),
)
def test_parse_accept_ordered_okay(raw):
ordered = headers.parse_accept(raw)
expected_subtype = (
"*" if all(q.subtype.is_wildcard for q in ordered) else "first"
)
assert ordered[0].type_ == "show"
assert ordered[0].subtype == expected_subtype
@pytest.mark.parametrize(
"raw",
(
"missing",
"missing/",
"/missing",
),
)
def test_bad_accept(raw):
with pytest.raises(InvalidHeader):
headers.parse_accept(raw)
def test_empty_accept():
assert headers.parse_accept("") == []
def test_wildcard_accept_set_ok():
accept = headers.parse_accept("*/*")[0]
assert accept.type_.is_wildcard
assert accept.subtype.is_wildcard
accept = headers.parse_accept("foo/bar")[0]
assert not accept.type_.is_wildcard
assert not accept.subtype.is_wildcard
def test_accept_parsed_against_str():
accept = headers.Accept.parse("foo/bar")
assert accept > "foo/bar; q=0.1"
def test_media_type_equality():
assert headers.MediaType("foo") == headers.MediaType("foo") == "foo"
assert headers.MediaType("foo") == headers.MediaType("*") == "*"
assert headers.MediaType("foo") != headers.MediaType("bar")
assert headers.MediaType("foo") != "bar"
def test_media_type_matching():
assert headers.MediaType("foo").match(headers.MediaType("foo"))
assert headers.MediaType("foo").match("foo")
assert not headers.MediaType("foo").match(headers.MediaType("*"))
assert not headers.MediaType("foo").match("*")
assert not headers.MediaType("foo").match(headers.MediaType("bar"))
assert not headers.MediaType("foo").match("bar")
@pytest.mark.parametrize(
"value,other,outcome,allow_type,allow_subtype",
(
# ALLOW BOTH
("foo/bar", "foo/bar", True, True, True),
("foo/bar", headers.Accept.parse("foo/bar"), True, True, True),
("foo/bar", "foo/*", True, True, True),
("foo/bar", headers.Accept.parse("foo/*"), True, True, True),
("foo/bar", "*/*", True, True, True),
("foo/bar", headers.Accept.parse("*/*"), True, True, True),
("foo/*", "foo/bar", True, True, True),
("foo/*", headers.Accept.parse("foo/bar"), True, True, True),
("foo/*", "foo/*", True, True, True),
("foo/*", headers.Accept.parse("foo/*"), True, True, True),
("foo/*", "*/*", True, True, True),
("foo/*", headers.Accept.parse("*/*"), True, True, True),
("*/*", "foo/bar", True, True, True),
("*/*", headers.Accept.parse("foo/bar"), True, True, True),
("*/*", "foo/*", True, True, True),
("*/*", headers.Accept.parse("foo/*"), True, True, True),
("*/*", "*/*", True, True, True),
("*/*", headers.Accept.parse("*/*"), True, True, True),
# ALLOW TYPE
("foo/bar", "foo/bar", True, True, False),
("foo/bar", headers.Accept.parse("foo/bar"), True, True, False),
("foo/bar", "foo/*", False, True, False),
("foo/bar", headers.Accept.parse("foo/*"), False, True, False),
("foo/bar", "*/*", False, True, False),
("foo/bar", headers.Accept.parse("*/*"), False, True, False),
("foo/*", "foo/bar", False, True, False),
("foo/*", headers.Accept.parse("foo/bar"), False, True, False),
("foo/*", "foo/*", False, True, False),
("foo/*", headers.Accept.parse("foo/*"), False, True, False),
("foo/*", "*/*", False, True, False),
("foo/*", headers.Accept.parse("*/*"), False, True, False),
("*/*", "foo/bar", False, True, False),
("*/*", headers.Accept.parse("foo/bar"), False, True, False),
("*/*", "foo/*", False, True, False),
("*/*", headers.Accept.parse("foo/*"), False, True, False),
("*/*", "*/*", False, True, False),
("*/*", headers.Accept.parse("*/*"), False, True, False),
# ALLOW SUBTYPE
("foo/bar", "foo/bar", True, False, True),
("foo/bar", headers.Accept.parse("foo/bar"), True, False, True),
("foo/bar", "foo/*", True, False, True),
("foo/bar", headers.Accept.parse("foo/*"), True, False, True),
("foo/bar", "*/*", False, False, True),
("foo/bar", headers.Accept.parse("*/*"), False, False, True),
("foo/*", "foo/bar", True, False, True),
("foo/*", headers.Accept.parse("foo/bar"), True, False, True),
("foo/*", "foo/*", True, False, True),
("foo/*", headers.Accept.parse("foo/*"), True, False, True),
("foo/*", "*/*", False, False, True),
("foo/*", headers.Accept.parse("*/*"), False, False, True),
("*/*", "foo/bar", False, False, True),
("*/*", headers.Accept.parse("foo/bar"), False, False, True),
("*/*", "foo/*", False, False, True),
("*/*", headers.Accept.parse("foo/*"), False, False, True),
("*/*", "*/*", False, False, True),
("*/*", headers.Accept.parse("*/*"), False, False, True),
),
)
def test_accept_matching(value, other, outcome, allow_type, allow_subtype):
assert (
headers.Accept.parse(value).match(
other,
allow_type_wildcard=allow_type,
allow_subtype_wildcard=allow_subtype,
)
is outcome
)
@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*"))
def test_value_in_accept(value):
acceptable = headers.parse_accept(value)
assert "foo/bar" in acceptable
assert "foo/*" in acceptable
assert "*/*" in acceptable

View File

@ -140,3 +140,39 @@ def test_ipv6_address_is_not_wrapped(app):
assert resp.json["client"] == "[::1]" assert resp.json["client"] == "[::1]"
assert resp.json["client_ip"] == "::1" assert resp.json["client_ip"] == "::1"
assert request.ip == "::1" assert request.ip == "::1"
def test_request_accept():
app = Sanic("req-generator")
@app.get("/")
async def get(request):
return response.empty()
request, _ = app.test_client.get(
"/",
headers={
"Accept": "text/*, text/plain, text/plain;format=flowed, */*"
},
)
assert request.accept == [
"text/plain;format=flowed",
"text/plain",
"text/*",
"*/*",
]
request, _ = app.test_client.get(
"/",
headers={
"Accept": (
"text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"
)
},
)
assert request.accept == [
"text/html",
"text/x-c",
"text/x-dvi; q=0.8",
"text/plain; q=0.5",
]