Refresh Request.accept functionality (#2687)

This commit is contained in:
Adam Hopkins 2023-02-21 08:22:51 +02:00 committed by GitHub
parent 6f5303e080
commit d238995f1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 396 additions and 295 deletions

View File

@ -470,10 +470,8 @@ def exception_response(
# Source: # Source:
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values
if acceptable and acceptable[0].match( if acceptable and acceptable.match(
"text/html", "text/html", accept_wildcards=False
allow_type_wildcard=False,
allow_subtype_wildcard=False,
): ):
renderer = HTMLRenderer renderer = HTMLRenderer
@ -483,9 +481,7 @@ def exception_response(
elif ( elif (
acceptable acceptable
and acceptable.match( and acceptable.match(
"application/json", "application/json", accept_wildcards=False
allow_type_wildcard=False,
allow_subtype_wildcard=False,
) )
or content_type == "application/json" or content_type == "application/json"
): ):
@ -514,13 +510,13 @@ def exception_response(
# our choice is okay # our choice is okay
if acceptable: if acceptable:
type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore
if type_ and type_ not in acceptable: if type_ and not acceptable.match(type_):
# If the renderer selected is not in the Accept header # If the renderer selected is not in the Accept header
# look through what is in the Accept header, and select # look through what is in the Accept header, and select
# the first option that matches. Otherwise, just drop back # the first option that matches. Otherwise, just drop back
# to the original default # to the original default
for accept in acceptable: for accept in acceptable:
mtype = f"{accept.type_}/{accept.subtype}" mtype = f"{accept.type}/{accept.subtype}"
maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype)
if maybe: if maybe:
renderer = maybe renderer = maybe

View File

@ -33,143 +33,96 @@ _host_re = re.compile(
# For more information, consult ../tests/test_requests.py # For more information, consult ../tests/test_requests.py
def parse_arg_as_accept(f): class MediaType:
def func(self, other, *args, **kwargs): """A media type, as used in the Accept header."""
if not isinstance(other, Accept) and other:
other = Accept.parse(other)
return f(self, other, *args, **kwargs)
return func
class MediaType(str):
def __new__(cls, value: str):
return str.__new__(cls, value)
def __init__(self, value: str) -> None:
self.value = value
self.is_wildcard = self.check_if_wildcard(value)
def __eq__(self, other):
if self.is_wildcard:
return True
if self.match(other):
return True
other_is_wildcard = (
other.is_wildcard
if isinstance(other, MediaType)
else self.check_if_wildcard(other)
)
return other_is_wildcard
def match(self, other):
other_value = other.value if isinstance(other, MediaType) else other
return self.value == other_value
@staticmethod
def check_if_wildcard(value):
return value == "*"
class Accept(str):
def __new__(cls, value: str, *args, **kwargs):
return str.__new__(cls, value)
def __init__( def __init__(
self, self,
value: str, type_: str,
type_: MediaType, subtype: str,
subtype: MediaType, **params: str,
*,
q: str = "1.0",
**kwargs: str,
): ):
qvalue = float(q) self.type = type_
if qvalue > 1 or qvalue < 0:
raise InvalidHeader(
f"Accept header qvalue must be between 0 and 1, not: {qvalue}"
)
self.value = value
self.type_ = type_
self.subtype = subtype self.subtype = subtype
self.qvalue = qvalue self.q = float(params.get("q", "1.0"))
self.params = kwargs self.params = params
self.mime = f"{type_}/{subtype}"
self.key = (
-1 * self.q,
-1 * len(self.params),
self.subtype == "*",
self.type == "*",
)
def _compare(self, other, method): def __repr__(self):
try: return self.mime + "".join(f";{k}={v}" for k, v in self.params.items())
return method(self.qvalue, other.qvalue)
except (AttributeError, TypeError):
return NotImplemented
@parse_arg_as_accept def __eq__(self, other):
def __lt__(self, other: Union[str, Accept]): """Check for mime (str or MediaType) identical type/subtype.
return self._compare(other, lambda s, o: s < o) Parameters such as q are not considered."""
if isinstance(other, str):
# Give a friendly reminder if str contains parameters
if ";" in other:
raise ValueError("Use match() to compare with parameters")
return self.mime == other
if isinstance(other, MediaType):
# Ignore parameters silently with MediaType objects
return self.mime == other.mime
return NotImplemented
@parse_arg_as_accept
def __le__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s <= o)
@parse_arg_as_accept
def __eq__(self, other: Union[str, Accept]): # type: ignore
return self._compare(other, lambda s, o: s == o)
@parse_arg_as_accept
def __ge__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s >= o)
@parse_arg_as_accept
def __gt__(self, other: Union[str, Accept]):
return self._compare(other, lambda s, o: s > o)
@parse_arg_as_accept
def __ne__(self, other: Union[str, Accept]): # type: ignore
return self._compare(other, lambda s, o: s != o)
@parse_arg_as_accept
def match( def match(
self, self,
other, mime_with_params: Union[str, MediaType],
*, ) -> Optional[MediaType]:
allow_type_wildcard: bool = True, """Check if this media type matches the given mime type/subtype.
allow_subtype_wildcard: bool = True, Wildcards are supported both ways on both type and subtype.
) -> bool: If mime contains a semicolon, optionally followed by parameters,
type_match = ( the parameters of the two media types must match exactly.
self.type_ == other.type_ Note: Use the `==` operator instead to check for literal matches
if allow_type_wildcard without expanding wildcards.
else ( @param media_type: A type/subtype string to match.
self.type_.match(other.type_) @return `self` if the media types are compatible, else `None`
and not self.type_.is_wildcard """
and not other.type_.is_wildcard mt = (
) MediaType._parse(mime_with_params)
if isinstance(mime_with_params, str)
else mime_with_params
) )
subtype_match = ( return (
self.subtype == other.subtype self
if allow_subtype_wildcard if (
else ( mt
self.subtype.match(other.subtype) # All parameters given in the other media type must match
and not self.subtype.is_wildcard and all(self.params.get(k) == v for k, v in mt.params.items())
and not other.subtype.is_wildcard # Subtype match
and (
self.subtype == mt.subtype
or self.subtype == "*"
or mt.subtype == "*"
)
# Type match
and (
self.type == mt.type or self.type == "*" or mt.type == "*"
)
) )
else None
) )
return type_match and subtype_match @property
def has_wildcard(self) -> bool:
"""Return True if this media type has a wildcard in it."""
return any(part == "*" for part in (self.subtype, self.type))
@classmethod @classmethod
def parse(cls, raw: str) -> Accept: def _parse(cls, mime_with_params: str) -> Optional[MediaType]:
invalid = False mtype = mime_with_params.strip()
mtype = raw.strip() if "/" not in mime_with_params:
return None
try: mime, *raw_params = mtype.split(";")
media, *raw_params = mtype.split(";") type_, subtype = mime.split("/", 1)
type_, subtype = media.split("/") if not type_ or not subtype:
except ValueError: raise ValueError(f"Invalid media type: {mtype}")
invalid = True
if invalid or not type_ or not subtype:
raise InvalidHeader(f"Header contains invalid Accept value: {raw}")
params = dict( params = dict(
[ [
@ -178,29 +131,140 @@ class Accept(str):
] ]
) )
return cls(mtype, MediaType(type_), MediaType(subtype), **params) return cls(type_.lstrip(), subtype.rstrip(), **params)
class AcceptContainer(list): class Matched:
def __contains__(self, o: object) -> bool: """A matching result of a MIME string against a header."""
return any(item.match(o) for item in self)
def match( def __init__(self, mime: str, header: Optional[MediaType]):
self, self.mime = mime
o: object, self.header = header
*,
allow_type_wildcard: bool = True, def __repr__(self):
allow_subtype_wildcard: bool = True, return f"<{self} matched {self.header}>" if self else "<no match>"
) -> bool:
return any( def __str__(self):
item.match( return self.mime
o,
allow_type_wildcard=allow_type_wildcard, def __bool__(self):
allow_subtype_wildcard=allow_subtype_wildcard, return self.header is not None
def __eq__(self, other: Any) -> bool:
try:
comp, other_accept = self._compare(other)
except TypeError:
return False
return bool(
comp
and (
(self.header and other_accept.header)
or (not self.header and not other_accept.header)
) )
for item in self
) )
def _compare(self, other) -> Tuple[bool, Matched]:
if isinstance(other, str):
# return self.mime == other, Accept.parse(other)
parsed = Matched.parse(other)
if self.mime == other:
return True, parsed
other = parsed
if isinstance(other, Matched):
return self.header == other.header, other
raise TypeError(
"Comparison not supported between unequal "
f"mime types of '{self.mime}' and '{other}'"
)
def match(self, other: Union[str, Matched]) -> Optional[Matched]:
accept = Matched.parse(other) if isinstance(other, str) else other
if not self.header or not accept.header:
return None
if self.header.match(accept.header):
return accept
return None
@classmethod
def parse(cls, raw: str) -> Matched:
media_type = MediaType._parse(raw)
return cls(raw, media_type)
class AcceptList(list):
"""A list of media types, as used in the Accept header.
The Accept header entries are listed in order of preference, starting
with the most preferred. This class is a list of `MediaType` objects,
that encapsulate also the q value or any other parameters.
Two separate methods are provided for searching the list:
- 'match' for finding the most preferred match (wildcards supported)
- operator 'in' for checking explicit matches (wildcards as literals)
"""
def match(self, *mimes: str, accept_wildcards=True) -> Matched:
"""Find a media type accepted by the client.
This method can be used to find which of the media types requested by
the client is most preferred against the ones given as arguments.
The ordering of preference is set by:
1. The order set by RFC 7231, s. 5.3.2, giving a higher priority
to q values and more specific type definitions,
2. The order of the arguments (first is most preferred), and
3. The first matching entry on the Accept header.
Wildcards are matched both ways. A match is usually found, as the
Accept headers typically include `*/*`, in particular if the header
is missing, is not manually set, or if the client is a browser.
Note: the returned object behaves as a string of the mime argument
that matched, and is empty/falsy if no match was found. The matched
header entry `MediaType` or `None` is available as the `m` attribute.
@param mimes: Any MIME types to search for in order of preference.
@param accept_wildcards: Match Accept entries with wildcards in them.
@return A match object with the mime string and the MediaType object.
"""
a = sorted(
(-acc.q, i, j, mime, acc)
for j, acc in enumerate(self)
if accept_wildcards or not acc.has_wildcard
for i, mime in enumerate(mimes)
if acc.match(mime)
)
return Matched(*(a[0][-2:] if a else ("", None)))
def __str__(self):
"""Format as Accept header value (parsed, not original)."""
return ", ".join(str(m) for m in self)
def parse_accept(accept: Optional[str]) -> AcceptList:
"""Parse an Accept header and order the acceptable media types in
according to RFC 7231, s. 5.3.2
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
"""
if not accept:
if accept == "":
return AcceptList() # Empty header, accept nothing
accept = "*/*" # No header means that all types are accepted
try:
a = [
mt
for mt in [MediaType._parse(mtype) for mtype in accept.split(",")]
if mt
]
if not a:
raise ValueError
return AcceptList(sorted(a, key=lambda x: x.key))
except ValueError:
raise InvalidHeader(f"Invalid header value in Accept: {accept}")
def parse_content_header(value: str) -> Tuple[str, Options]: def parse_content_header(value: str) -> Tuple[str, Options]:
"""Parse content-type and content-disposition header values. """Parse content-type and content-disposition header values.
@ -368,34 +432,6 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:
return ret return ret
def _sort_accept_value(accept: Accept):
return (
accept.qvalue,
len(accept.params),
accept.subtype != "*",
accept.type_ != "*",
)
def parse_accept(accept: str) -> AcceptContainer:
"""Parse an Accept header and order the acceptable media types in
accorsing to RFC 7231, s. 5.3.2
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
"""
media_types = accept.split(",")
accept_list: List[Accept] = []
for mtype in media_types:
if not mtype:
continue
accept_list.append(Accept.parse(mtype))
return AcceptContainer(
sorted(accept_list, key=_sort_accept_value, reverse=True)
)
def parse_credentials( def parse_credentials(
header: Optional[str], header: Optional[str],
prefixes: Union[List, Tuple, Set] = None, prefixes: Union[List, Tuple, Set] = None,

View File

@ -47,7 +47,7 @@ from sanic.constants import (
) )
from sanic.exceptions import BadRequest, BadURL, ServerError from sanic.exceptions import BadRequest, BadURL, ServerError
from sanic.headers import ( from sanic.headers import (
AcceptContainer, AcceptList,
Options, Options,
parse_accept, parse_accept,
parse_content_header, parse_content_header,
@ -167,7 +167,7 @@ class Request:
self.conn_info: Optional[ConnInfo] = None self.conn_info: Optional[ConnInfo] = None
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.parsed_forwarded: Optional[Options] = None self.parsed_forwarded: Optional[Options] = None
self.parsed_accept: Optional[AcceptContainer] = None self.parsed_accept: Optional[AcceptList] = None
self.parsed_credentials: Optional[Credentials] = None self.parsed_credentials: Optional[Credentials] = None
self.parsed_json = None self.parsed_json = None
self.parsed_form: Optional[RequestParameters] = None self.parsed_form: Optional[RequestParameters] = None
@ -499,10 +499,10 @@ class Request:
return self.parsed_json return self.parsed_json
@property @property
def accept(self) -> AcceptContainer: def accept(self) -> AcceptList:
""" """
:return: The ``Accept`` header parsed :return: The ``Accept`` header parsed
:rtype: AcceptContainer :rtype: AcceptList
""" """
if self.parsed_accept is None: if self.parsed_accept is None:
accept_header = self.headers.getone("accept", "") accept_header = self.headers.getone("accept", "")

View File

@ -351,6 +351,7 @@ async def test_websocket_text_receive(send, receive, message_stack):
assert text == msg["text"] assert text == msg["text"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_websocket_bytes_receive(send, receive, message_stack): async def test_websocket_bytes_receive(send, receive, message_stack):
msg = {"bytes": b"hello", "type": "websocket.receive"} msg = {"bytes": b"hello", "type": "websocket.receive"}
@ -361,6 +362,7 @@ async def test_websocket_bytes_receive(send, receive, message_stack):
assert data == msg["bytes"] assert data == msg["bytes"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_websocket_accept_with_no_subprotocols( async def test_websocket_accept_with_no_subprotocols(
send, receive, message_stack send, receive, message_stack

View File

@ -50,7 +50,10 @@ def raised_ceiling():
# cgi.parse_header: # cgi.parse_header:
# ('form-data', {'name': 'files', 'filename': 'fo"o;bar\\'}) # ('form-data', {'name': 'files', 'filename': 'fo"o;bar\\'})
# werkzeug.parse_options_header: # werkzeug.parse_options_header:
# ('form-data', {'name': 'files', 'filename': '"fo\\"o', 'bar\\"': None}) # (
# "form-data",
# {"name": "files", "filename": '"fo\\"o', 'bar\\"': None},
# ),
), ),
# <input type=file name="foo&quot;;bar\"> with Unicode filename! # <input type=file name="foo&quot;;bar\"> with Unicode filename!
( (
@ -187,27 +190,24 @@ def test_request_line(app):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"raw", "raw,expected_subtype",
( (
"show/first, show/second", ("show/first, show/second", "first"),
"show/*, show/first", ("show/*, show/first", "first"),
"*/*, show/first", ("*/*, show/first", "first"),
"*/*, show/*", ("*/*, show/*", "*"),
"other/*; q=0.1, show/*; q=0.2", ("other/*; q=0.1, show/*; q=0.2", "*"),
"show/first; q=0.5, show/second; q=0.5", ("show/first; q=0.5, show/second; q=0.5", "first"),
"show/first; foo=bar, show/second; foo=bar", ("show/first; foo=bar, show/second; foo=bar", "first"),
"show/second, show/first; foo=bar", ("show/second, show/first; foo=bar", "first"),
"show/second; q=0.5, show/first; foo=bar; q=0.5", ("show/second; q=0.5, show/first; foo=bar; q=0.5", "first"),
"show/second; q=0.5, show/first; q=1.0", ("show/second; q=0.5, show/first; q=1.0", "first"),
"show/first, show/second; q=1.0", ("show/first, show/second; q=1.0", "second"),
), ),
) )
def test_parse_accept_ordered_okay(raw): def test_parse_accept_ordered_okay(raw, expected_subtype):
ordered = headers.parse_accept(raw) ordered = headers.parse_accept(raw)
expected_subtype = ( assert ordered[0].type == "show"
"*" if all(q.subtype.is_wildcard for q in ordered) else "first"
)
assert ordered[0].type_ == "show"
assert ordered[0].subtype == expected_subtype assert ordered[0].subtype == expected_subtype
@ -217,6 +217,7 @@ def test_parse_accept_ordered_okay(raw):
"missing", "missing",
"missing/", "missing/",
"/missing", "/missing",
"/",
), ),
) )
def test_bad_accept(raw): def test_bad_accept(raw):
@ -225,128 +226,83 @@ def test_bad_accept(raw):
def test_empty_accept(): def test_empty_accept():
assert headers.parse_accept("") == [] a = headers.parse_accept("")
assert a == []
assert not a.match("*/*")
def test_wildcard_accept_set_ok(): def test_wildcard_accept_set_ok():
accept = headers.parse_accept("*/*")[0] accept = headers.parse_accept("*/*")[0]
assert accept.type_.is_wildcard assert accept.type == "*"
assert accept.subtype.is_wildcard assert accept.subtype == "*"
assert accept.has_wildcard
accept = headers.parse_accept("foo/*")[0]
assert accept.type == "foo"
assert accept.subtype == "*"
assert accept.has_wildcard
accept = headers.parse_accept("foo/bar")[0] accept = headers.parse_accept("foo/bar")[0]
assert not accept.type_.is_wildcard assert accept.type == "foo"
assert not accept.subtype.is_wildcard assert accept.subtype == "bar"
assert not accept.has_wildcard
def test_accept_parsed_against_str(): def test_accept_parsed_against_str():
accept = headers.Accept.parse("foo/bar") accept = headers.Matched.parse("foo/bar")
assert accept > "foo/bar; q=0.1" assert accept == "foo/bar; q=0.1"
def test_media_type_equality():
assert headers.MediaType("foo") == headers.MediaType("foo") == "foo"
assert headers.MediaType("foo") == headers.MediaType("*") == "*"
assert headers.MediaType("foo") != headers.MediaType("bar")
assert headers.MediaType("foo") != "bar"
def test_media_type_matching(): def test_media_type_matching():
assert headers.MediaType("foo").match(headers.MediaType("foo")) assert headers.MediaType("foo", "bar").match(
assert headers.MediaType("foo").match("foo") headers.MediaType("foo", "bar")
)
assert not headers.MediaType("foo").match(headers.MediaType("*")) assert headers.MediaType("foo", "bar").match("foo/bar")
assert not headers.MediaType("foo").match("*")
assert not headers.MediaType("foo").match(headers.MediaType("bar"))
assert not headers.MediaType("foo").match("bar")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value,other,outcome,allow_type,allow_subtype", "value,other,outcome",
( (
# ALLOW BOTH # ALLOW BOTH
("foo/bar", "foo/bar", True, True, True), ("foo/bar", "foo/bar", True),
("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), ("foo/bar", headers.Matched.parse("foo/bar"), True),
("foo/bar", "foo/*", True, True, True), ("foo/bar", "foo/*", True),
("foo/bar", headers.Accept.parse("foo/*"), True, True, True), ("foo/bar", headers.Matched.parse("foo/*"), True),
("foo/bar", "*/*", True, True, True), ("foo/bar", "*/*", True),
("foo/bar", headers.Accept.parse("*/*"), True, True, True), ("foo/bar", headers.Matched.parse("*/*"), True),
("foo/*", "foo/bar", True, True, True), ("foo/*", "foo/bar", True),
("foo/*", headers.Accept.parse("foo/bar"), True, True, True), ("foo/*", headers.Matched.parse("foo/bar"), True),
("foo/*", "foo/*", True, True, True), ("foo/*", "foo/*", True),
("foo/*", headers.Accept.parse("foo/*"), True, True, True), ("foo/*", headers.Matched.parse("foo/*"), True),
("foo/*", "*/*", True, True, True), ("foo/*", "*/*", True),
("foo/*", headers.Accept.parse("*/*"), True, True, True), ("foo/*", headers.Matched.parse("*/*"), True),
("*/*", "foo/bar", True, True, True), ("*/*", "foo/bar", True),
("*/*", headers.Accept.parse("foo/bar"), True, True, True), ("*/*", headers.Matched.parse("foo/bar"), True),
("*/*", "foo/*", True, True, True), ("*/*", "foo/*", True),
("*/*", headers.Accept.parse("foo/*"), True, True, True), ("*/*", headers.Matched.parse("foo/*"), True),
("*/*", "*/*", True, True, True), ("*/*", "*/*", True),
("*/*", headers.Accept.parse("*/*"), True, True, True), ("*/*", headers.Matched.parse("*/*"), True),
# ALLOW TYPE
("foo/bar", "foo/bar", True, True, False),
("foo/bar", headers.Accept.parse("foo/bar"), True, True, False),
("foo/bar", "foo/*", False, True, False),
("foo/bar", headers.Accept.parse("foo/*"), False, True, False),
("foo/bar", "*/*", False, True, False),
("foo/bar", headers.Accept.parse("*/*"), False, True, False),
("foo/*", "foo/bar", False, True, False),
("foo/*", headers.Accept.parse("foo/bar"), False, True, False),
("foo/*", "foo/*", False, True, False),
("foo/*", headers.Accept.parse("foo/*"), False, True, False),
("foo/*", "*/*", False, True, False),
("foo/*", headers.Accept.parse("*/*"), False, True, False),
("*/*", "foo/bar", False, True, False),
("*/*", headers.Accept.parse("foo/bar"), False, True, False),
("*/*", "foo/*", False, True, False),
("*/*", headers.Accept.parse("foo/*"), False, True, False),
("*/*", "*/*", False, True, False),
("*/*", headers.Accept.parse("*/*"), False, True, False),
# ALLOW SUBTYPE
("foo/bar", "foo/bar", True, False, True),
("foo/bar", headers.Accept.parse("foo/bar"), True, False, True),
("foo/bar", "foo/*", True, False, True),
("foo/bar", headers.Accept.parse("foo/*"), True, False, True),
("foo/bar", "*/*", False, False, True),
("foo/bar", headers.Accept.parse("*/*"), False, False, True),
("foo/*", "foo/bar", True, False, True),
("foo/*", headers.Accept.parse("foo/bar"), True, False, True),
("foo/*", "foo/*", True, False, True),
("foo/*", headers.Accept.parse("foo/*"), True, False, True),
("foo/*", "*/*", False, False, True),
("foo/*", headers.Accept.parse("*/*"), False, False, True),
("*/*", "foo/bar", False, False, True),
("*/*", headers.Accept.parse("foo/bar"), False, False, True),
("*/*", "foo/*", False, False, True),
("*/*", headers.Accept.parse("foo/*"), False, False, True),
("*/*", "*/*", False, False, True),
("*/*", headers.Accept.parse("*/*"), False, False, True),
), ),
) )
def test_accept_matching(value, other, outcome, allow_type, allow_subtype): def test_accept_matching(value, other, outcome):
assert ( assert bool(headers.Matched.parse(value).match(other)) is outcome
headers.Accept.parse(value).match(
other,
allow_type_wildcard=allow_type,
allow_subtype_wildcard=allow_subtype,
)
is outcome
)
@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) @pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*"))
def test_value_in_accept(value): def test_value_in_accept(value):
acceptable = headers.parse_accept(value) acceptable = headers.parse_accept(value)
assert "foo/bar" in acceptable assert acceptable.match("foo/bar")
assert "foo/*" in acceptable assert acceptable.match("foo/*")
assert "*/*" in acceptable assert acceptable.match("*/*")
@pytest.mark.parametrize("value", ("foo/bar", "foo/*")) @pytest.mark.parametrize("value", ("foo/bar", "foo/*"))
def test_value_not_in_accept(value): def test_value_not_in_accept(value):
acceptable = headers.parse_accept(value) acceptable = headers.parse_accept(value)
assert "no/match" not in acceptable assert not acceptable.match("no/match")
assert "no/*" not in acceptable assert not acceptable.match("no/*")
assert "*/*" not in acceptable
assert "*/bar" not in acceptable
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -365,6 +321,117 @@ def test_value_not_in_accept(value):
), ),
), ),
) )
def test_browser_headers(header, expected): def test_browser_headers_general(header, expected):
request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) request = Request(b"/", {"accept": header}, "1.1", "GET", None, None)
assert request.accept == expected assert [str(item) for item in request.accept] == expected
@pytest.mark.parametrize(
"header,expected",
(
(
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501
[
("text/html", 1.0),
("application/xhtml+xml", 1.0),
("image/avif", 1.0),
("image/webp", 1.0),
("application/xml", 0.9),
("*/*", 0.8),
],
),
),
)
def test_browser_headers_specific(header, expected):
mimes = [e[0] for e in expected]
qs = [e[1] for e in expected]
request = Request(b"/", {"accept": header}, "1.1", "GET", None, None)
assert request.accept == mimes
for a, m, q in zip(request.accept, mimes, qs):
assert a == m
assert a.mime == m
assert a.q == q
@pytest.mark.parametrize(
"raw",
(
"text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8",
"application/xml;q=0.9, */*;q=0.8, text/html, application/xhtml+xml",
(
"foo/bar;q=0.9, */*;q=0.8, text/html=0.8, "
"text/plain, application/xhtml+xml"
),
),
)
def test_accept_ordering(raw):
"""Should sort by q but also be stable."""
accept = headers.parse_accept(raw)
assert accept[0].type == "text"
raw1 = ", ".join(str(a) for a in accept)
accept = headers.parse_accept(raw1)
raw2 = ", ".join(str(a) for a in accept)
assert raw1 == raw2
def test_not_accept_wildcard():
accept = headers.parse_accept("*/*, foo/*, */bar, foo/bar;q=0.1")
assert not accept.match(
"text/html", "foo/foo", "bar/bar", accept_wildcards=False
)
# Should ignore wildcards in accept but still matches them from mimes
m = accept.match("text/plain", "*/*", accept_wildcards=False)
assert m.mime == "*/*"
assert m.match("*/*")
assert m.header == "foo/bar"
assert not accept.match(
"text/html", "foo/foo", "bar/bar", accept_wildcards=False
)
def test_accept_misc():
header = (
"foo/bar;q=0.0, */plain;param=123, text/plain, text/*, foo/bar;q=0.5"
)
a = headers.parse_accept(header)
assert repr(a) == (
"[*/plain;param=123, text/plain, text/*, "
"foo/bar;q=0.5, foo/bar;q=0.0]"
) # noqa: E501
assert str(a) == (
"*/plain;param=123, text/plain, text/*, "
"foo/bar;q=0.5, foo/bar;q=0.0"
) # noqa: E501
# q=1 types don't match foo/bar but match the two others,
# text/* comes first and matches */plain because it
# comes first in the header
m = a.match("foo/bar", "text/*", "text/plain")
assert repr(m) == "<text/* matched */plain;param=123>"
assert m == "text/*"
assert m.mime == "text/*"
assert m.header.mime == "*/plain"
assert m.header.type == "*"
assert m.header.subtype == "plain"
assert m.header.q == 1.0
assert m.header.params == dict(param="123")
# Matches object against another Matched object (by mime and header)
assert m == a.match("text/*")
# Against unsupported type falls back to object id matching
assert m != 123
# Matches the highest q value
m = a.match("foo/bar")
assert repr(m) == "<foo/bar matched foo/bar;q=0.5>"
assert m == "foo/bar"
assert m == "foo/bar;q=0.5"
# Matching nothing special case
m = a.match()
assert m == ""
assert m.header is None
# No header means anything
a = headers.parse_accept(None)
assert a == ["*/*"]
assert a.match("foo/bar")
# Empty header means nothing
a = headers.parse_accept("")
assert a == []
assert not a.match("foo/bar")

View File

@ -156,7 +156,7 @@ def test_request_accept():
"Accept": "text/*, text/plain, text/plain;format=flowed, */*" "Accept": "text/*, text/plain, text/plain;format=flowed, */*"
}, },
) )
assert request.accept == [ assert [str(i) for i in request.accept] == [
"text/plain;format=flowed", "text/plain;format=flowed",
"text/plain", "text/plain",
"text/*", "text/*",
@ -171,11 +171,11 @@ def test_request_accept():
) )
}, },
) )
assert request.accept == [ assert [str(i) for i in request.accept] == [
"text/html", "text/html",
"text/x-c", "text/x-c",
"text/x-dvi; q=0.8", "text/x-dvi;q=0.8",
"text/plain; q=0.5", "text/plain;q=0.5",
] ]