diff --git a/sanic/exceptions.py b/sanic/exceptions.py index ad2205dd..61077cea 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -126,8 +126,11 @@ class HeaderNotFound(InvalidUsage): **Status**: 400 Bad Request """ - status_code = 400 - quiet = True + +class InvalidHeader(InvalidUsage): + """ + **Status**: 400 Bad Request + """ class ContentRangeError(SanicException): diff --git a/sanic/headers.py b/sanic/headers.py index 66427442..cc05f8e0 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import re from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import unquote +from sanic.exceptions import InvalidHeader from sanic.helpers import STATUS_CODES @@ -30,6 +33,154 @@ _host_re = re.compile( # For more information, consult ../tests/test_requests.py +def parse_arg_as_accept(f): + def func(self, other, *args, **kwargs): + if not isinstance(other, Accept): + other = Accept.parse(other) + return f(self, other, *args, **kwargs) + + return func + + +class MediaType(str): + def __new__(cls, value: str): + return str.__new__(cls, value) + + def __init__(self, value: str) -> None: + self.value = value + self.is_wildcard = self.check_if_wildcard(value) + + def __eq__(self, other): + if self.is_wildcard: + return True + + if self.match(other): + return True + + other_is_wildcard = ( + other.is_wildcard + if isinstance(other, MediaType) + else self.check_if_wildcard(other) + ) + + return other_is_wildcard + + def match(self, other): + other_value = other.value if isinstance(other, MediaType) else other + return self.value == other_value + + @staticmethod + def check_if_wildcard(value): + return value == "*" + + +class Accept(str): + def __new__(cls, value: str, *args, **kwargs): + return str.__new__(cls, value) + + def __init__( + self, + value: str, + type_: MediaType, + subtype: MediaType, + *, + q: str = "1.0", + **kwargs: str, + ): + qvalue = float(q) + if qvalue > 1 or qvalue < 0: + raise InvalidHeader( + f"Accept header qvalue must be between 0 and 1, not: {qvalue}" + ) + self.value = value + self.type_ = type_ + self.subtype = subtype + self.qvalue = qvalue + self.params = kwargs + + def _compare(self, other, method): + try: + return method(self.qvalue, other.qvalue) + except (AttributeError, TypeError): + return NotImplemented + + @parse_arg_as_accept + def __lt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s < o) + + @parse_arg_as_accept + def __le__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s <= o) + + @parse_arg_as_accept + def __eq__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s == o) + + @parse_arg_as_accept + def __ge__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s >= o) + + @parse_arg_as_accept + def __gt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s > o) + + @parse_arg_as_accept + def __ne__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s != o) + + @parse_arg_as_accept + def match( + self, + other, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + type_match = ( + self.type_ == other.type_ + if allow_type_wildcard + else ( + self.type_.match(other.type_) + and not self.type_.is_wildcard + and not other.type_.is_wildcard + ) + ) + subtype_match = ( + self.subtype == other.subtype + if allow_subtype_wildcard + else ( + self.subtype.match(other.subtype) + and not self.subtype.is_wildcard + and not other.subtype.is_wildcard + ) + ) + + return type_match and subtype_match + + @classmethod + def parse(cls, raw: str) -> Accept: + invalid = False + mtype = raw.strip() + + try: + media, *raw_params = mtype.split(";") + type_, subtype = media.split("/") + except ValueError: + invalid = True + + if invalid or not type_ or not subtype: + raise InvalidHeader(f"Header contains invalid Accept value: {raw}") + + params = dict( + [ + (key.strip(), value.strip()) + for key, value in (param.split("=", 1) for param in raw_params) + ] + ) + + return cls(mtype, MediaType(type_), MediaType(subtype), **params) + + def parse_content_header(value: str) -> Tuple[str, Options]: """Parse content-type and content-disposition header values. @@ -194,3 +345,29 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: ret += b"%b: %b\r\n" % h ret += b"\r\n" return ret + + +def _sort_accept_value(accept: Accept): + return ( + accept.qvalue, + len(accept.params), + accept.subtype != "*", + accept.type_ != "*", + ) + + +def parse_accept(accept: str) -> List[Accept]: + """Parse an Accept header and order the acceptable media types in + accorsing to RFC 7231, s. 5.3.2 + https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + """ + media_types = accept.split(",") + accept_list: List[Accept] = [] + + for mtype in media_types: + if not mtype: + continue + + accept_list.append(Accept.parse(mtype)) + + return sorted(accept_list, key=_sort_accept_value, reverse=True) diff --git a/sanic/request.py b/sanic/request.py index e37de36a..7f94de2d 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -34,7 +34,9 @@ from sanic.compat import CancelledErrors, Header from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.exceptions import InvalidUsage from sanic.headers import ( + Accept, Options, + parse_accept, parse_content_header, parse_forwarded, parse_host, @@ -94,6 +96,7 @@ class Request: "head", "headers", "method", + "parsed_accept", "parsed_args", "parsed_not_grouped_args", "parsed_files", @@ -136,6 +139,7 @@ class Request: self.conn_info: Optional[ConnInfo] = None self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None + self.parsed_accept: Optional[List[Accept]] = None self.parsed_json = None self.parsed_form = None self.parsed_files = None @@ -296,6 +300,13 @@ class Request: return self.parsed_json + @property + def accept(self) -> List[Accept]: + if self.parsed_accept is None: + accept_header = self.headers.getone("accept", "") + self.parsed_accept = parse_accept(accept_header) + return self.parsed_accept + @property def token(self): """Attempt to return the auth header token. diff --git a/setup.py b/setup.py index 9341eae8..d65703ac 100644 --- a/setup.py +++ b/setup.py @@ -110,7 +110,7 @@ tests_require = [ "mypy>=0.901", "docutils", "pygments", - "uvicorn", + "uvicorn<0.15.0", types_ujson, ] diff --git a/tests/test_headers.py b/tests/test_headers.py index 546a9ef7..928847e0 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -3,7 +3,7 @@ from unittest.mock import Mock import pytest from sanic import headers, text -from sanic.exceptions import PayloadTooLarge +from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.http import Http @@ -182,3 +182,159 @@ def test_request_line(app): ) assert request.request_line == b"GET / HTTP/1.1" + + +@pytest.mark.parametrize( + "raw", + ( + "show/first, show/second", + "show/*, show/first", + "*/*, show/first", + "*/*, show/*", + "other/*; q=0.1, show/*; q=0.2", + "show/first; q=0.5, show/second; q=0.5", + "show/first; foo=bar, show/second; foo=bar", + "show/second, show/first; foo=bar", + "show/second; q=0.5, show/first; foo=bar; q=0.5", + "show/second; q=0.5, show/first; q=1.0", + "show/first, show/second; q=1.0", + ), +) +def test_parse_accept_ordered_okay(raw): + ordered = headers.parse_accept(raw) + expected_subtype = ( + "*" if all(q.subtype.is_wildcard for q in ordered) else "first" + ) + assert ordered[0].type_ == "show" + assert ordered[0].subtype == expected_subtype + + +@pytest.mark.parametrize( + "raw", + ( + "missing", + "missing/", + "/missing", + ), +) +def test_bad_accept(raw): + with pytest.raises(InvalidHeader): + headers.parse_accept(raw) + + +def test_empty_accept(): + assert headers.parse_accept("") == [] + + +def test_wildcard_accept_set_ok(): + accept = headers.parse_accept("*/*")[0] + assert accept.type_.is_wildcard + assert accept.subtype.is_wildcard + + accept = headers.parse_accept("foo/bar")[0] + assert not accept.type_.is_wildcard + assert not accept.subtype.is_wildcard + + +def test_accept_parsed_against_str(): + accept = headers.Accept.parse("foo/bar") + assert accept > "foo/bar; q=0.1" + + +def test_media_type_equality(): + assert headers.MediaType("foo") == headers.MediaType("foo") == "foo" + assert headers.MediaType("foo") == headers.MediaType("*") == "*" + assert headers.MediaType("foo") != headers.MediaType("bar") + assert headers.MediaType("foo") != "bar" + + +def test_media_type_matching(): + assert headers.MediaType("foo").match(headers.MediaType("foo")) + assert headers.MediaType("foo").match("foo") + + assert not headers.MediaType("foo").match(headers.MediaType("*")) + assert not headers.MediaType("foo").match("*") + + assert not headers.MediaType("foo").match(headers.MediaType("bar")) + assert not headers.MediaType("foo").match("bar") + + +@pytest.mark.parametrize( + "value,other,outcome,allow_type,allow_subtype", + ( + # ALLOW BOTH + ("foo/bar", "foo/bar", True, True, True), + ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/bar", "foo/*", True, True, True), + ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), + ("foo/bar", "*/*", True, True, True), + ("foo/bar", headers.Accept.parse("*/*"), True, True, True), + ("foo/*", "foo/bar", True, True, True), + ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/*", "foo/*", True, True, True), + ("foo/*", headers.Accept.parse("foo/*"), True, True, True), + ("foo/*", "*/*", True, True, True), + ("foo/*", headers.Accept.parse("*/*"), True, True, True), + ("*/*", "foo/bar", True, True, True), + ("*/*", headers.Accept.parse("foo/bar"), True, True, True), + ("*/*", "foo/*", True, True, True), + ("*/*", headers.Accept.parse("foo/*"), True, True, True), + ("*/*", "*/*", True, True, True), + ("*/*", headers.Accept.parse("*/*"), True, True, True), + # ALLOW TYPE + ("foo/bar", "foo/bar", True, True, False), + ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), + ("foo/bar", "foo/*", False, True, False), + ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), + ("foo/bar", "*/*", False, True, False), + ("foo/bar", headers.Accept.parse("*/*"), False, True, False), + ("foo/*", "foo/bar", False, True, False), + ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), + ("foo/*", "foo/*", False, True, False), + ("foo/*", headers.Accept.parse("foo/*"), False, True, False), + ("foo/*", "*/*", False, True, False), + ("foo/*", headers.Accept.parse("*/*"), False, True, False), + ("*/*", "foo/bar", False, True, False), + ("*/*", headers.Accept.parse("foo/bar"), False, True, False), + ("*/*", "foo/*", False, True, False), + ("*/*", headers.Accept.parse("foo/*"), False, True, False), + ("*/*", "*/*", False, True, False), + ("*/*", headers.Accept.parse("*/*"), False, True, False), + # ALLOW SUBTYPE + ("foo/bar", "foo/bar", True, False, True), + ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/bar", "foo/*", True, False, True), + ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), + ("foo/bar", "*/*", False, False, True), + ("foo/bar", headers.Accept.parse("*/*"), False, False, True), + ("foo/*", "foo/bar", True, False, True), + ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/*", "foo/*", True, False, True), + ("foo/*", headers.Accept.parse("foo/*"), True, False, True), + ("foo/*", "*/*", False, False, True), + ("foo/*", headers.Accept.parse("*/*"), False, False, True), + ("*/*", "foo/bar", False, False, True), + ("*/*", headers.Accept.parse("foo/bar"), False, False, True), + ("*/*", "foo/*", False, False, True), + ("*/*", headers.Accept.parse("foo/*"), False, False, True), + ("*/*", "*/*", False, False, True), + ("*/*", headers.Accept.parse("*/*"), False, False, True), + ), +) +def test_accept_matching(value, other, outcome, allow_type, allow_subtype): + assert ( + headers.Accept.parse(value).match( + other, + allow_type_wildcard=allow_type, + allow_subtype_wildcard=allow_subtype, + ) + is outcome + ) + + +@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) +def test_value_in_accept(value): + acceptable = headers.parse_accept(value) + assert "foo/bar" in acceptable + assert "foo/*" in acceptable + assert "*/*" in acceptable diff --git a/tests/test_request.py b/tests/test_request.py index e4b21f66..ca2c1e4a 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -140,3 +140,39 @@ def test_ipv6_address_is_not_wrapped(app): assert resp.json["client"] == "[::1]" assert resp.json["client_ip"] == "::1" assert request.ip == "::1" + + +def test_request_accept(): + app = Sanic("req-generator") + + @app.get("/") + async def get(request): + return response.empty() + + request, _ = app.test_client.get( + "/", + headers={ + "Accept": "text/*, text/plain, text/plain;format=flowed, */*" + }, + ) + assert request.accept == [ + "text/plain;format=flowed", + "text/plain", + "text/*", + "*/*", + ] + + request, _ = app.test_client.get( + "/", + headers={ + "Accept": ( + "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c" + ) + }, + ) + assert request.accept == [ + "text/html", + "text/x-c", + "text/x-dvi; q=0.8", + "text/plain; q=0.5", + ]