From d238995f1be11510aad81c4e877485ae446afb5c Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 21 Feb 2023 08:22:51 +0200 Subject: [PATCH] Refresh Request.accept functionality (#2687) --- sanic/errorpages.py | 14 +- sanic/headers.py | 366 +++++++++++++++++++++++------------------- sanic/request.py | 8 +- tests/test_asgi.py | 2 + tests/test_headers.py | 293 ++++++++++++++++++++------------- tests/test_request.py | 8 +- 6 files changed, 396 insertions(+), 295 deletions(-) diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 6354210a..c20896db 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -470,10 +470,8 @@ def exception_response( # Source: # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values - if acceptable and acceptable[0].match( - "text/html", - allow_type_wildcard=False, - allow_subtype_wildcard=False, + if acceptable and acceptable.match( + "text/html", accept_wildcards=False ): renderer = HTMLRenderer @@ -483,9 +481,7 @@ def exception_response( elif ( acceptable and acceptable.match( - "application/json", - allow_type_wildcard=False, - allow_subtype_wildcard=False, + "application/json", accept_wildcards=False ) or content_type == "application/json" ): @@ -514,13 +510,13 @@ def exception_response( # our choice is okay if acceptable: type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore - if type_ and type_ not in acceptable: + if type_ and not acceptable.match(type_): # If the renderer selected is not in the Accept header # look through what is in the Accept header, and select # the first option that matches. Otherwise, just drop back # to the original default for accept in acceptable: - mtype = f"{accept.type_}/{accept.subtype}" + mtype = f"{accept.type}/{accept.subtype}" maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) if maybe: renderer = maybe diff --git a/sanic/headers.py b/sanic/headers.py index 5e4ded03..b192fa24 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -33,143 +33,96 @@ _host_re = re.compile( # For more information, consult ../tests/test_requests.py -def parse_arg_as_accept(f): - def func(self, other, *args, **kwargs): - if not isinstance(other, Accept) and other: - other = Accept.parse(other) - return f(self, other, *args, **kwargs) - - return func - - -class MediaType(str): - def __new__(cls, value: str): - return str.__new__(cls, value) - - def __init__(self, value: str) -> None: - self.value = value - self.is_wildcard = self.check_if_wildcard(value) - - def __eq__(self, other): - if self.is_wildcard: - return True - - if self.match(other): - return True - - other_is_wildcard = ( - other.is_wildcard - if isinstance(other, MediaType) - else self.check_if_wildcard(other) - ) - - return other_is_wildcard - - def match(self, other): - other_value = other.value if isinstance(other, MediaType) else other - return self.value == other_value - - @staticmethod - def check_if_wildcard(value): - return value == "*" - - -class Accept(str): - def __new__(cls, value: str, *args, **kwargs): - return str.__new__(cls, value) +class MediaType: + """A media type, as used in the Accept header.""" def __init__( self, - value: str, - type_: MediaType, - subtype: MediaType, - *, - q: str = "1.0", - **kwargs: str, + type_: str, + subtype: str, + **params: str, ): - qvalue = float(q) - if qvalue > 1 or qvalue < 0: - raise InvalidHeader( - f"Accept header qvalue must be between 0 and 1, not: {qvalue}" - ) - self.value = value - self.type_ = type_ + self.type = type_ self.subtype = subtype - self.qvalue = qvalue - self.params = kwargs + self.q = float(params.get("q", "1.0")) + self.params = params + self.mime = f"{type_}/{subtype}" + self.key = ( + -1 * self.q, + -1 * len(self.params), + self.subtype == "*", + self.type == "*", + ) - def _compare(self, other, method): - try: - return method(self.qvalue, other.qvalue) - except (AttributeError, TypeError): - return NotImplemented + def __repr__(self): + return self.mime + "".join(f";{k}={v}" for k, v in self.params.items()) - @parse_arg_as_accept - def __lt__(self, other: Union[str, Accept]): - return self._compare(other, lambda s, o: s < o) + def __eq__(self, other): + """Check for mime (str or MediaType) identical type/subtype. + Parameters such as q are not considered.""" + if isinstance(other, str): + # Give a friendly reminder if str contains parameters + if ";" in other: + raise ValueError("Use match() to compare with parameters") + return self.mime == other + if isinstance(other, MediaType): + # Ignore parameters silently with MediaType objects + return self.mime == other.mime + return NotImplemented - @parse_arg_as_accept - def __le__(self, other: Union[str, Accept]): - return self._compare(other, lambda s, o: s <= o) - - @parse_arg_as_accept - def __eq__(self, other: Union[str, Accept]): # type: ignore - return self._compare(other, lambda s, o: s == o) - - @parse_arg_as_accept - def __ge__(self, other: Union[str, Accept]): - return self._compare(other, lambda s, o: s >= o) - - @parse_arg_as_accept - def __gt__(self, other: Union[str, Accept]): - return self._compare(other, lambda s, o: s > o) - - @parse_arg_as_accept - def __ne__(self, other: Union[str, Accept]): # type: ignore - return self._compare(other, lambda s, o: s != o) - - @parse_arg_as_accept def match( self, - other, - *, - allow_type_wildcard: bool = True, - allow_subtype_wildcard: bool = True, - ) -> bool: - type_match = ( - self.type_ == other.type_ - if allow_type_wildcard - else ( - self.type_.match(other.type_) - and not self.type_.is_wildcard - and not other.type_.is_wildcard - ) + mime_with_params: Union[str, MediaType], + ) -> Optional[MediaType]: + """Check if this media type matches the given mime type/subtype. + Wildcards are supported both ways on both type and subtype. + If mime contains a semicolon, optionally followed by parameters, + the parameters of the two media types must match exactly. + Note: Use the `==` operator instead to check for literal matches + without expanding wildcards. + @param media_type: A type/subtype string to match. + @return `self` if the media types are compatible, else `None` + """ + mt = ( + MediaType._parse(mime_with_params) + if isinstance(mime_with_params, str) + else mime_with_params ) - subtype_match = ( - self.subtype == other.subtype - if allow_subtype_wildcard - else ( - self.subtype.match(other.subtype) - and not self.subtype.is_wildcard - and not other.subtype.is_wildcard + return ( + self + if ( + mt + # All parameters given in the other media type must match + and all(self.params.get(k) == v for k, v in mt.params.items()) + # Subtype match + and ( + self.subtype == mt.subtype + or self.subtype == "*" + or mt.subtype == "*" + ) + # Type match + and ( + self.type == mt.type or self.type == "*" or mt.type == "*" + ) ) + else None ) - return type_match and subtype_match + @property + def has_wildcard(self) -> bool: + """Return True if this media type has a wildcard in it.""" + return any(part == "*" for part in (self.subtype, self.type)) @classmethod - def parse(cls, raw: str) -> Accept: - invalid = False - mtype = raw.strip() + def _parse(cls, mime_with_params: str) -> Optional[MediaType]: + mtype = mime_with_params.strip() + if "/" not in mime_with_params: + return None - try: - media, *raw_params = mtype.split(";") - type_, subtype = media.split("/") - except ValueError: - invalid = True - - if invalid or not type_ or not subtype: - raise InvalidHeader(f"Header contains invalid Accept value: {raw}") + mime, *raw_params = mtype.split(";") + type_, subtype = mime.split("/", 1) + if not type_ or not subtype: + raise ValueError(f"Invalid media type: {mtype}") params = dict( [ @@ -178,29 +131,140 @@ class Accept(str): ] ) - return cls(mtype, MediaType(type_), MediaType(subtype), **params) + return cls(type_.lstrip(), subtype.rstrip(), **params) -class AcceptContainer(list): - def __contains__(self, o: object) -> bool: - return any(item.match(o) for item in self) +class Matched: + """A matching result of a MIME string against a header.""" - def match( - self, - o: object, - *, - allow_type_wildcard: bool = True, - allow_subtype_wildcard: bool = True, - ) -> bool: - return any( - item.match( - o, - allow_type_wildcard=allow_type_wildcard, - allow_subtype_wildcard=allow_subtype_wildcard, + def __init__(self, mime: str, header: Optional[MediaType]): + self.mime = mime + self.header = header + + def __repr__(self): + return f"<{self} matched {self.header}>" if self else "" + + def __str__(self): + return self.mime + + def __bool__(self): + return self.header is not None + + def __eq__(self, other: Any) -> bool: + try: + comp, other_accept = self._compare(other) + except TypeError: + return False + + return bool( + comp + and ( + (self.header and other_accept.header) + or (not self.header and not other_accept.header) ) - for item in self ) + def _compare(self, other) -> Tuple[bool, Matched]: + if isinstance(other, str): + # return self.mime == other, Accept.parse(other) + parsed = Matched.parse(other) + if self.mime == other: + return True, parsed + other = parsed + + if isinstance(other, Matched): + return self.header == other.header, other + + raise TypeError( + "Comparison not supported between unequal " + f"mime types of '{self.mime}' and '{other}'" + ) + + def match(self, other: Union[str, Matched]) -> Optional[Matched]: + accept = Matched.parse(other) if isinstance(other, str) else other + if not self.header or not accept.header: + return None + if self.header.match(accept.header): + return accept + return None + + @classmethod + def parse(cls, raw: str) -> Matched: + media_type = MediaType._parse(raw) + return cls(raw, media_type) + + +class AcceptList(list): + """A list of media types, as used in the Accept header. + + The Accept header entries are listed in order of preference, starting + with the most preferred. This class is a list of `MediaType` objects, + that encapsulate also the q value or any other parameters. + + Two separate methods are provided for searching the list: + - 'match' for finding the most preferred match (wildcards supported) + - operator 'in' for checking explicit matches (wildcards as literals) + """ + + def match(self, *mimes: str, accept_wildcards=True) -> Matched: + """Find a media type accepted by the client. + + This method can be used to find which of the media types requested by + the client is most preferred against the ones given as arguments. + + The ordering of preference is set by: + 1. The order set by RFC 7231, s. 5.3.2, giving a higher priority + to q values and more specific type definitions, + 2. The order of the arguments (first is most preferred), and + 3. The first matching entry on the Accept header. + + Wildcards are matched both ways. A match is usually found, as the + Accept headers typically include `*/*`, in particular if the header + is missing, is not manually set, or if the client is a browser. + + Note: the returned object behaves as a string of the mime argument + that matched, and is empty/falsy if no match was found. The matched + header entry `MediaType` or `None` is available as the `m` attribute. + + @param mimes: Any MIME types to search for in order of preference. + @param accept_wildcards: Match Accept entries with wildcards in them. + @return A match object with the mime string and the MediaType object. + """ + a = sorted( + (-acc.q, i, j, mime, acc) + for j, acc in enumerate(self) + if accept_wildcards or not acc.has_wildcard + for i, mime in enumerate(mimes) + if acc.match(mime) + ) + return Matched(*(a[0][-2:] if a else ("", None))) + + def __str__(self): + """Format as Accept header value (parsed, not original).""" + return ", ".join(str(m) for m in self) + + +def parse_accept(accept: Optional[str]) -> AcceptList: + """Parse an Accept header and order the acceptable media types in + according to RFC 7231, s. 5.3.2 + https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + """ + if not accept: + if accept == "": + return AcceptList() # Empty header, accept nothing + accept = "*/*" # No header means that all types are accepted + try: + a = [ + mt + for mt in [MediaType._parse(mtype) for mtype in accept.split(",")] + if mt + ] + if not a: + raise ValueError + return AcceptList(sorted(a, key=lambda x: x.key)) + except ValueError: + raise InvalidHeader(f"Invalid header value in Accept: {accept}") + def parse_content_header(value: str) -> Tuple[str, Options]: """Parse content-type and content-disposition header values. @@ -368,34 +432,6 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: return ret -def _sort_accept_value(accept: Accept): - return ( - accept.qvalue, - len(accept.params), - accept.subtype != "*", - accept.type_ != "*", - ) - - -def parse_accept(accept: str) -> AcceptContainer: - """Parse an Accept header and order the acceptable media types in - accorsing to RFC 7231, s. 5.3.2 - https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 - """ - media_types = accept.split(",") - accept_list: List[Accept] = [] - - for mtype in media_types: - if not mtype: - continue - - accept_list.append(Accept.parse(mtype)) - - return AcceptContainer( - sorted(accept_list, key=_sort_accept_value, reverse=True) - ) - - def parse_credentials( header: Optional[str], prefixes: Union[List, Tuple, Set] = None, diff --git a/sanic/request.py b/sanic/request.py index bd7b5663..62ed6cf2 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -47,7 +47,7 @@ from sanic.constants import ( ) from sanic.exceptions import BadRequest, BadURL, ServerError from sanic.headers import ( - AcceptContainer, + AcceptList, Options, parse_accept, parse_content_header, @@ -167,7 +167,7 @@ class Request: self.conn_info: Optional[ConnInfo] = None self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None - self.parsed_accept: Optional[AcceptContainer] = None + self.parsed_accept: Optional[AcceptList] = None self.parsed_credentials: Optional[Credentials] = None self.parsed_json = None self.parsed_form: Optional[RequestParameters] = None @@ -499,10 +499,10 @@ class Request: return self.parsed_json @property - def accept(self) -> AcceptContainer: + def accept(self) -> AcceptList: """ :return: The ``Accept`` header parsed - :rtype: AcceptContainer + :rtype: AcceptList """ if self.parsed_accept is None: accept_header = self.headers.getone("accept", "") diff --git a/tests/test_asgi.py b/tests/test_asgi.py index fe9ff306..1b57436d 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -351,6 +351,7 @@ async def test_websocket_text_receive(send, receive, message_stack): assert text == msg["text"] + @pytest.mark.asyncio async def test_websocket_bytes_receive(send, receive, message_stack): msg = {"bytes": b"hello", "type": "websocket.receive"} @@ -361,6 +362,7 @@ async def test_websocket_bytes_receive(send, receive, message_stack): assert data == msg["bytes"] + @pytest.mark.asyncio async def test_websocket_accept_with_no_subprotocols( send, receive, message_stack diff --git a/tests/test_headers.py b/tests/test_headers.py index 7b2235ed..7f3a56dc 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -50,7 +50,10 @@ def raised_ceiling(): # cgi.parse_header: # ('form-data', {'name': 'files', 'filename': 'fo"o;bar\\'}) # werkzeug.parse_options_header: - # ('form-data', {'name': 'files', 'filename': '"fo\\"o', 'bar\\"': None}) + # ( + # "form-data", + # {"name": "files", "filename": '"fo\\"o', 'bar\\"': None}, + # ), ), # with Unicode filename! ( @@ -187,27 +190,24 @@ def test_request_line(app): @pytest.mark.parametrize( - "raw", + "raw,expected_subtype", ( - "show/first, show/second", - "show/*, show/first", - "*/*, show/first", - "*/*, show/*", - "other/*; q=0.1, show/*; q=0.2", - "show/first; q=0.5, show/second; q=0.5", - "show/first; foo=bar, show/second; foo=bar", - "show/second, show/first; foo=bar", - "show/second; q=0.5, show/first; foo=bar; q=0.5", - "show/second; q=0.5, show/first; q=1.0", - "show/first, show/second; q=1.0", + ("show/first, show/second", "first"), + ("show/*, show/first", "first"), + ("*/*, show/first", "first"), + ("*/*, show/*", "*"), + ("other/*; q=0.1, show/*; q=0.2", "*"), + ("show/first; q=0.5, show/second; q=0.5", "first"), + ("show/first; foo=bar, show/second; foo=bar", "first"), + ("show/second, show/first; foo=bar", "first"), + ("show/second; q=0.5, show/first; foo=bar; q=0.5", "first"), + ("show/second; q=0.5, show/first; q=1.0", "first"), + ("show/first, show/second; q=1.0", "second"), ), ) -def test_parse_accept_ordered_okay(raw): +def test_parse_accept_ordered_okay(raw, expected_subtype): ordered = headers.parse_accept(raw) - expected_subtype = ( - "*" if all(q.subtype.is_wildcard for q in ordered) else "first" - ) - assert ordered[0].type_ == "show" + assert ordered[0].type == "show" assert ordered[0].subtype == expected_subtype @@ -217,6 +217,7 @@ def test_parse_accept_ordered_okay(raw): "missing", "missing/", "/missing", + "/", ), ) def test_bad_accept(raw): @@ -225,128 +226,83 @@ def test_bad_accept(raw): def test_empty_accept(): - assert headers.parse_accept("") == [] + a = headers.parse_accept("") + assert a == [] + assert not a.match("*/*") def test_wildcard_accept_set_ok(): accept = headers.parse_accept("*/*")[0] - assert accept.type_.is_wildcard - assert accept.subtype.is_wildcard + assert accept.type == "*" + assert accept.subtype == "*" + assert accept.has_wildcard + + accept = headers.parse_accept("foo/*")[0] + assert accept.type == "foo" + assert accept.subtype == "*" + assert accept.has_wildcard accept = headers.parse_accept("foo/bar")[0] - assert not accept.type_.is_wildcard - assert not accept.subtype.is_wildcard + assert accept.type == "foo" + assert accept.subtype == "bar" + assert not accept.has_wildcard def test_accept_parsed_against_str(): - accept = headers.Accept.parse("foo/bar") - assert accept > "foo/bar; q=0.1" - - -def test_media_type_equality(): - assert headers.MediaType("foo") == headers.MediaType("foo") == "foo" - assert headers.MediaType("foo") == headers.MediaType("*") == "*" - assert headers.MediaType("foo") != headers.MediaType("bar") - assert headers.MediaType("foo") != "bar" + accept = headers.Matched.parse("foo/bar") + assert accept == "foo/bar; q=0.1" def test_media_type_matching(): - assert headers.MediaType("foo").match(headers.MediaType("foo")) - assert headers.MediaType("foo").match("foo") - - assert not headers.MediaType("foo").match(headers.MediaType("*")) - assert not headers.MediaType("foo").match("*") - - assert not headers.MediaType("foo").match(headers.MediaType("bar")) - assert not headers.MediaType("foo").match("bar") + assert headers.MediaType("foo", "bar").match( + headers.MediaType("foo", "bar") + ) + assert headers.MediaType("foo", "bar").match("foo/bar") @pytest.mark.parametrize( - "value,other,outcome,allow_type,allow_subtype", + "value,other,outcome", ( # ALLOW BOTH - ("foo/bar", "foo/bar", True, True, True), - ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), - ("foo/bar", "foo/*", True, True, True), - ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), - ("foo/bar", "*/*", True, True, True), - ("foo/bar", headers.Accept.parse("*/*"), True, True, True), - ("foo/*", "foo/bar", True, True, True), - ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), - ("foo/*", "foo/*", True, True, True), - ("foo/*", headers.Accept.parse("foo/*"), True, True, True), - ("foo/*", "*/*", True, True, True), - ("foo/*", headers.Accept.parse("*/*"), True, True, True), - ("*/*", "foo/bar", True, True, True), - ("*/*", headers.Accept.parse("foo/bar"), True, True, True), - ("*/*", "foo/*", True, True, True), - ("*/*", headers.Accept.parse("foo/*"), True, True, True), - ("*/*", "*/*", True, True, True), - ("*/*", headers.Accept.parse("*/*"), True, True, True), - # ALLOW TYPE - ("foo/bar", "foo/bar", True, True, False), - ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), - ("foo/bar", "foo/*", False, True, False), - ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), - ("foo/bar", "*/*", False, True, False), - ("foo/bar", headers.Accept.parse("*/*"), False, True, False), - ("foo/*", "foo/bar", False, True, False), - ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), - ("foo/*", "foo/*", False, True, False), - ("foo/*", headers.Accept.parse("foo/*"), False, True, False), - ("foo/*", "*/*", False, True, False), - ("foo/*", headers.Accept.parse("*/*"), False, True, False), - ("*/*", "foo/bar", False, True, False), - ("*/*", headers.Accept.parse("foo/bar"), False, True, False), - ("*/*", "foo/*", False, True, False), - ("*/*", headers.Accept.parse("foo/*"), False, True, False), - ("*/*", "*/*", False, True, False), - ("*/*", headers.Accept.parse("*/*"), False, True, False), - # ALLOW SUBTYPE - ("foo/bar", "foo/bar", True, False, True), - ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), - ("foo/bar", "foo/*", True, False, True), - ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), - ("foo/bar", "*/*", False, False, True), - ("foo/bar", headers.Accept.parse("*/*"), False, False, True), - ("foo/*", "foo/bar", True, False, True), - ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), - ("foo/*", "foo/*", True, False, True), - ("foo/*", headers.Accept.parse("foo/*"), True, False, True), - ("foo/*", "*/*", False, False, True), - ("foo/*", headers.Accept.parse("*/*"), False, False, True), - ("*/*", "foo/bar", False, False, True), - ("*/*", headers.Accept.parse("foo/bar"), False, False, True), - ("*/*", "foo/*", False, False, True), - ("*/*", headers.Accept.parse("foo/*"), False, False, True), - ("*/*", "*/*", False, False, True), - ("*/*", headers.Accept.parse("*/*"), False, False, True), + ("foo/bar", "foo/bar", True), + ("foo/bar", headers.Matched.parse("foo/bar"), True), + ("foo/bar", "foo/*", True), + ("foo/bar", headers.Matched.parse("foo/*"), True), + ("foo/bar", "*/*", True), + ("foo/bar", headers.Matched.parse("*/*"), True), + ("foo/*", "foo/bar", True), + ("foo/*", headers.Matched.parse("foo/bar"), True), + ("foo/*", "foo/*", True), + ("foo/*", headers.Matched.parse("foo/*"), True), + ("foo/*", "*/*", True), + ("foo/*", headers.Matched.parse("*/*"), True), + ("*/*", "foo/bar", True), + ("*/*", headers.Matched.parse("foo/bar"), True), + ("*/*", "foo/*", True), + ("*/*", headers.Matched.parse("foo/*"), True), + ("*/*", "*/*", True), + ("*/*", headers.Matched.parse("*/*"), True), ), ) -def test_accept_matching(value, other, outcome, allow_type, allow_subtype): - assert ( - headers.Accept.parse(value).match( - other, - allow_type_wildcard=allow_type, - allow_subtype_wildcard=allow_subtype, - ) - is outcome - ) +def test_accept_matching(value, other, outcome): + assert bool(headers.Matched.parse(value).match(other)) is outcome @pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) def test_value_in_accept(value): acceptable = headers.parse_accept(value) - assert "foo/bar" in acceptable - assert "foo/*" in acceptable - assert "*/*" in acceptable + assert acceptable.match("foo/bar") + assert acceptable.match("foo/*") + assert acceptable.match("*/*") @pytest.mark.parametrize("value", ("foo/bar", "foo/*")) def test_value_not_in_accept(value): acceptable = headers.parse_accept(value) - assert "no/match" not in acceptable - assert "no/*" not in acceptable + assert not acceptable.match("no/match") + assert not acceptable.match("no/*") + assert "*/*" not in acceptable + assert "*/bar" not in acceptable @pytest.mark.parametrize( @@ -365,6 +321,117 @@ def test_value_not_in_accept(value): ), ), ) -def test_browser_headers(header, expected): +def test_browser_headers_general(header, expected): request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) - assert request.accept == expected + assert [str(item) for item in request.accept] == expected + + +@pytest.mark.parametrize( + "header,expected", + ( + ( + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501 + [ + ("text/html", 1.0), + ("application/xhtml+xml", 1.0), + ("image/avif", 1.0), + ("image/webp", 1.0), + ("application/xml", 0.9), + ("*/*", 0.8), + ], + ), + ), +) +def test_browser_headers_specific(header, expected): + mimes = [e[0] for e in expected] + qs = [e[1] for e in expected] + request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) + assert request.accept == mimes + for a, m, q in zip(request.accept, mimes, qs): + assert a == m + assert a.mime == m + assert a.q == q + + +@pytest.mark.parametrize( + "raw", + ( + "text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8", + "application/xml;q=0.9, */*;q=0.8, text/html, application/xhtml+xml", + ( + "foo/bar;q=0.9, */*;q=0.8, text/html=0.8, " + "text/plain, application/xhtml+xml" + ), + ), +) +def test_accept_ordering(raw): + """Should sort by q but also be stable.""" + accept = headers.parse_accept(raw) + assert accept[0].type == "text" + raw1 = ", ".join(str(a) for a in accept) + accept = headers.parse_accept(raw1) + raw2 = ", ".join(str(a) for a in accept) + assert raw1 == raw2 + + +def test_not_accept_wildcard(): + accept = headers.parse_accept("*/*, foo/*, */bar, foo/bar;q=0.1") + assert not accept.match( + "text/html", "foo/foo", "bar/bar", accept_wildcards=False + ) + # Should ignore wildcards in accept but still matches them from mimes + m = accept.match("text/plain", "*/*", accept_wildcards=False) + assert m.mime == "*/*" + assert m.match("*/*") + assert m.header == "foo/bar" + assert not accept.match( + "text/html", "foo/foo", "bar/bar", accept_wildcards=False + ) + + +def test_accept_misc(): + header = ( + "foo/bar;q=0.0, */plain;param=123, text/plain, text/*, foo/bar;q=0.5" + ) + a = headers.parse_accept(header) + assert repr(a) == ( + "[*/plain;param=123, text/plain, text/*, " + "foo/bar;q=0.5, foo/bar;q=0.0]" + ) # noqa: E501 + assert str(a) == ( + "*/plain;param=123, text/plain, text/*, " + "foo/bar;q=0.5, foo/bar;q=0.0" + ) # noqa: E501 + # q=1 types don't match foo/bar but match the two others, + # text/* comes first and matches */plain because it + # comes first in the header + m = a.match("foo/bar", "text/*", "text/plain") + assert repr(m) == "" + assert m == "text/*" + assert m.mime == "text/*" + assert m.header.mime == "*/plain" + assert m.header.type == "*" + assert m.header.subtype == "plain" + assert m.header.q == 1.0 + assert m.header.params == dict(param="123") + # Matches object against another Matched object (by mime and header) + assert m == a.match("text/*") + # Against unsupported type falls back to object id matching + assert m != 123 + # Matches the highest q value + m = a.match("foo/bar") + assert repr(m) == "" + assert m == "foo/bar" + assert m == "foo/bar;q=0.5" + # Matching nothing special case + m = a.match() + assert m == "" + assert m.header is None + # No header means anything + a = headers.parse_accept(None) + assert a == ["*/*"] + assert a.match("foo/bar") + # Empty header means nothing + a = headers.parse_accept("") + assert a == [] + assert not a.match("foo/bar") diff --git a/tests/test_request.py b/tests/test_request.py index 3a4132ae..a4757b52 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -156,7 +156,7 @@ def test_request_accept(): "Accept": "text/*, text/plain, text/plain;format=flowed, */*" }, ) - assert request.accept == [ + assert [str(i) for i in request.accept] == [ "text/plain;format=flowed", "text/plain", "text/*", @@ -171,11 +171,11 @@ def test_request_accept(): ) }, ) - assert request.accept == [ + assert [str(i) for i in request.accept] == [ "text/html", "text/x-c", - "text/x-dvi; q=0.8", - "text/plain; q=0.5", + "text/x-dvi;q=0.8", + "text/plain;q=0.5", ]