diff --git a/sanic/errorpages.py b/sanic/errorpages.py index c035dce1..5fe72ae8 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -403,7 +403,8 @@ RENDERERS_BY_CONTENT_TYPE = { CONTENT_TYPE_BY_RENDERERS = { v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items() } - +# Handler source code is checked for which response types it returns +# If it returns (exactly) one of these, it will be used as render_format RESPONSE_MAPPING = { "empty": "html", "json": "json", @@ -436,98 +437,60 @@ def exception_response( """ Render a response for the default FALLBACK exception handler. """ - content_type = None - if not renderer: - # Make sure we have something set - renderer = base - render_format = fallback - - if request: - # If there is a request, try and get the format - # from the route - if request.route: - try: - if request.route.extra.error_format: - render_format = request.route.extra.error_format - except AttributeError: - ... - - content_type = request.headers.getone("content-type", "").split( - ";" - )[0] - - acceptable = request.accept - - # If the format is auto still, make a guess - if render_format == "auto": - # First, if there is an Accept header, check if text/html - # is the first option - # According to MDN Web Docs, all major browsers use text/html - # as the primary value in Accept (with the exception of IE 8, - # and, well, if you are supporting IE 8, then you have bigger - # problems to concern yourself with than what default exception - # renderer is used) - # Source: - # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values - - if acceptable and acceptable[0].match( - "text/html", - allow_type_wildcard=False, - allow_subtype_wildcard=False, - ): - renderer = HTMLRenderer - - # Second, if there is an Accept header, check if - # application/json is an option, or if the content-type - # is application/json - elif ( - acceptable - and acceptable.match( - "application/json", - allow_type_wildcard=False, - allow_subtype_wildcard=False, - ) - or content_type == "application/json" - ): - renderer = JSONRenderer - - # Third, if there is no Accept header, assume we want text. - # The likely use case here is a raw socket. - elif not acceptable: - renderer = TextRenderer - else: - # Fourth, look to see if there was a JSON body - # When in this situation, the request is probably coming - # from curl, an API client like Postman or Insomnia, or a - # package like requests or httpx - try: - # Give them the benefit of the doubt if they did: - # $ curl localhost:8000 -d '{"foo": "bar"}' - # And provide them with JSONRenderer - renderer = JSONRenderer if request.json else base - except BadRequest: - renderer = base - else: - renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) - - # Lastly, if there is an Accept header, make sure - # our choice is okay - if acceptable: - type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore - if type_ and type_ not in acceptable: - # If the renderer selected is not in the Accept header - # look through what is in the Accept header, and select - # the first option that matches. Otherwise, just drop back - # to the original default - for accept in acceptable: - mtype = f"{accept.type_}/{accept.subtype}" - maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) - if maybe: - renderer = maybe - break - else: - renderer = base + renderer = _guess_renderer(request, fallback, base) renderer = t.cast(t.Type[BaseRenderer], renderer) return renderer(request, exception, debug).render() + + +def _guess_renderer(request: Request, fallback: str, base: t.Type[BaseRenderer]) -> t.Type[BaseRenderer]: + # base/fallback is app.config.FALLBACK_ERROR_FORMAT + render_format = fallback + if not request: + return base + + # Try the format from the route via RESPONSE_MAPPING + if request.route: + try: + if request.route.extra.error_format: + render_format = request.route.extra.error_format + except AttributeError: + pass + + # Do we need to look at the request itself? + if render_format == "auto": + # Use the Accept header to choose one + mediatype, accept_q = request.accept.choose(*RENDERERS_BY_CONTENT_TYPE) + if accept_q: + return RENDERERS_BY_CONTENT_TYPE[mediatype] + + # Otherwise, try JSON content type or request body + if not accept_q and "*/*" in request.accept and _check_json_content(request): + return JSONRenderer + + return base + + # Use the format from the route if it doesn't contradict the Accept header + renderer = RENDERERS_BY_CONFIG.get(render_format, base) + type_ = CONTENT_TYPE_BY_RENDERERS[renderer] # type: ignore + acceptable = not request.accept or request.accept.match(type_) + return renderer if acceptable else base + + +def _check_json_content(request: Request) -> bool: + content_type = request.headers.getone("content-type", "").split(";")[0] + if content_type == "application/json": + return True + # Look to see if there was a JSON body + # When in this situation, the request is probably coming + # from curl, an API client like Postman or Insomnia, or a + # package like requests or httpx + try: + # Give them the benefit of the doubt if they did: + # $ curl localhost:8000 -d '{"foo": "bar"}' + # And provide them with JSONRenderer + if request.json: return True + except BadRequest: + pass + return False diff --git a/sanic/headers.py b/sanic/headers.py index ef805642..b648762e 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -35,149 +35,72 @@ _host_re = re.compile( def parse_arg_as_accept(f): def func(self, other, *args, **kwargs): - if not isinstance(other, Accept) and other: - other = Accept.parse(other) + if not isinstance(other, MediaType) and other: + other = MediaType._parse(other) return f(self, other, *args, **kwargs) return func -class MediaType(str): - def __new__(cls, value: str): - return str.__new__(cls, value) - - def __init__(self, value: str) -> None: - self.value = value - self.is_wildcard = self.check_if_wildcard(value) - - def __eq__(self, other): - if self.is_wildcard: - return True - - if self.match(other): - return True - - other_is_wildcard = ( - other.is_wildcard - if isinstance(other, MediaType) - else self.check_if_wildcard(other) - ) - - return other_is_wildcard - - def match(self, other): - other_value = other.value if isinstance(other, MediaType) else other - return self.value == other_value - - @staticmethod - def check_if_wildcard(value): - return value == "*" - - -class Accept(str): - def __new__(cls, value: str, *args, **kwargs): - return str.__new__(cls, value) - +class MediaType: + """A media type, as used in the Accept header.""" def __init__( self, - value: str, - type_: MediaType, - subtype: MediaType, - *, - q: str = "1.0", - **kwargs: str, + type_: str, + subtype: str, + **params: str, ): - qvalue = float(q) - if qvalue > 1 or qvalue < 0: - raise InvalidHeader( - f"Accept header qvalue must be between 0 and 1, not: {qvalue}" - ) - self.value = value self.type_ = type_ self.subtype = subtype - self.qvalue = qvalue - self.params = kwargs + self.q = float(params.get("q", "1.0")) + self.params = params + self.str = f"{type_}/{subtype}" - def _compare(self, other, method): - try: - return method(self.qvalue, other.qvalue) - except (AttributeError, TypeError): - return NotImplemented + def __repr__(self): + return self.str + "".join(f";{k}={v}" for k, v in self.params.items()) - @parse_arg_as_accept - def __lt__(self, other: AcceptLike): - return self._compare(other, lambda s, o: s < o) + def __eq__(self, media_type: str): + """Check if the type and subtype match exactly.""" + return self.str == media_type - @parse_arg_as_accept - def __le__(self, other: AcceptLike): - return self._compare(other, lambda s, o: s <= o) - - @parse_arg_as_accept - def __eq__(self, other: AcceptLike): # type: ignore - return self._compare(other, lambda s, o: s == o) - - @parse_arg_as_accept - def __ge__(self, other: AcceptLike): - return self._compare(other, lambda s, o: s >= o) - - @parse_arg_as_accept - def __gt__(self, other: AcceptLike): - return self._compare(other, lambda s, o: s > o) - - @parse_arg_as_accept - def __ne__(self, other: AcceptLike): # type: ignore - return self._compare(other, lambda s, o: s != o) - - @parse_arg_as_accept def match( self, - other, - *, - allow_type_wildcard: bool = True, - allow_subtype_wildcard: bool = True, - ) -> bool: - type_match = ( - self.type_ == other.type_ - if allow_type_wildcard - else ( - self.type_.match(other.type_) - and not self.type_.is_wildcard - and not other.type_.is_wildcard - ) - ) - subtype_match = ( - self.subtype == other.subtype - if allow_subtype_wildcard - else ( - self.subtype.match(other.subtype) - and not self.subtype.is_wildcard - and not other.subtype.is_wildcard - ) - ) + media_type: str, + ) -> Optional[MediaType]: + """Check if this media type matches the given media type. - return type_match and subtype_match + Wildcards are supported both ways on both type and subtype. + + Note: Use the `==` operator instead to check for literal matches + without expanding wildcards. + + @param media_type: A type/subtype string to match. + @return `self` if the media types are compatible, else `None` + """ + mt = MediaType._parse(media_type) + return self if ( + # Subtype match + (self.subtype in (mt.subtype, "*") or mt.subtype == "*") + # Type match + and (self.type_ in (mt.type_, "*") or mt.type_ == "*") + ) else None @property def has_wildcard(self) -> bool: - return self.type_.is_wildcard or self.subtype.is_wildcard + """Return True if this media type has a wildcard in it.""" + return "*" in (self.subtype, self.type_) @property def is_wildcard(self) -> bool: - return self.type_.is_wildcard and self.subtype.is_wildcard + """Return True if this is the wildcard `*/*`""" + return self.type_ == "*" and self.subtype == "*" @classmethod - def parse(cls, raw: AcceptLike) -> Accept: - invalid = False + def _parse(cls, raw: AcceptLike) -> MediaType: mtype = raw.strip() - try: - media, *raw_params = mtype.split(";") - type_, subtype = media.split("/") - except ValueError: - invalid = True - - if invalid or not type_ or not subtype: - raise InvalidHeader(f"Header contains invalid Accept value: {raw}") + media, *raw_params = mtype.split(";") + type_, subtype = media.split("/", 1) params = dict( [ @@ -186,51 +109,81 @@ class Accept(str): ] ) - return cls(mtype, MediaType(type_), MediaType(subtype), **params) + return cls(type_.lstrip(), subtype.rstrip(), **params) -AcceptLike = Union[str, Accept] +class AcceptList(list): + """A list of media types, as used in the Accept header. + + The Accept header entries are listed in order of preference, starting + with the most preferred. This class is a list of `MediaType` objects, + that encapsulate also the q value or any other parameters. + + Three separate methods are provided for searching the list, for + different use cases. The first two match wildcards with anything, + while `in` and other operators handle wildcards as literal values. + + - `choose` for choosing one of its arguments to use in response. + - 'match' for the best MediaType of the accept header, or None. + - operator 'in' for checking explicit matches (wildcards as is). + """ + + def match(self, *media_types: List[str]) -> Optional[MediaType]: + """Find a media type accepted by the client. + + This method can be used to find which of the media types requested by + the client is most preferred while matching any of the arguments. + + Wildcards are supported. Most clients include */* as the last item in + their Accept header, so this method will always return a match unless + a custom header is used, but it may return a more specific match if + the client has requested any suitable types explicitly. + + @param media_types: Any type/subtype strings to find. + @return A matching `MediaType` or `None` if nothing matches. + """ + for accepted in self: + if any(accepted.match(mt) for mt in media_types): + return accepted -class AcceptContainer(list): - def __contains__(self, o: object) -> bool: - return any(item.match(o) for item in self) + def choose(self, *media_types: List[str], omit_wildcard=True) -> str: + """Choose a most suitable media type based on the Accept header. + + This is the recommended way to choose a response format based on the + Accept header. The q values and the order of the Accept header are + respected, and if due to wildcards multiple arguments match the same + accept header entry, the first one matching is returned. - def match( - self, - o: AcceptLike, - *, - allow_type_wildcard: bool = True, - allow_subtype_wildcard: bool = True, - ) -> bool: - return any( - item.match( - o, - allow_type_wildcard=allow_type_wildcard, - allow_subtype_wildcard=allow_subtype_wildcard, - ) - for item in self - ) + Should none of the arguments be acceptable, the first argument is + returned with the q value of 0.0 (i.e. the lowest possible). - def find_first( - self, - others: List[AcceptLike], - default: AcceptLike = "*/*", - ) -> Accept: - filtered = [ - accept - for accept in self - if any(accept.match(other) for other in others) - ] - if filtered: - return filtered[0] - return Accept.parse(default) + @param media_types: Any type/subtype strings to find. + @param omit_wildcard: Ignore full wildcard */* in the Accept header. + @return A tuple of one of the arguments and the q value of the match. + """ + # Find the preferred MediaType if any match + for accepted in self: + if omit_wildcard and accepted.is_wildcard: + continue + for mt in media_types: + if accepted.match(mt): + return mt, accepted.q + # Fall back to the first argument + return media_types[0], 0.0 + + +def parse_accept(accept: str) -> AcceptList: + """Parse an Accept header and order the acceptable media types in + accorsing to RFC 7231, s. 5.3.2 + https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + """ + try: + accept_list = [MediaType._parse(mtype) for mtype in accept.split(",")] + return AcceptList(sorted(accept_list, key=lambda mtype: -mtype.q)) + except ValueError: + raise InvalidHeader(f"Invalid header value in Accept: {accept}") - @parse_arg_as_accept - def match_explicit(self, other: AcceptLike) -> bool: - return self.match( - other, allow_type_wildcard=False, allow_subtype_wildcard=False - ) def parse_content_header(value: str) -> Tuple[str, Options]: @@ -399,34 +352,6 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: return ret -def _sort_accept_value(accept: Accept): - return ( - accept.qvalue, - len(accept.params), - accept.subtype != "*", - accept.type_ != "*", - ) - - -def parse_accept(accept: str) -> AcceptContainer: - """Parse an Accept header and order the acceptable media types in - accorsing to RFC 7231, s. 5.3.2 - https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 - """ - media_types = accept.split(",") - accept_list: List[Accept] = [] - - for mtype in media_types: - if not mtype: - continue - - accept_list.append(Accept.parse(mtype)) - - return AcceptContainer( - sorted(accept_list, key=_sort_accept_value, reverse=True) - ) - - def parse_credentials( header: Optional[str], prefixes: Union[List, Tuple, Set] = None, diff --git a/sanic/request.py b/sanic/request.py index 592869e5..c8639c41 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -47,7 +47,7 @@ from sanic.constants import ( ) from sanic.exceptions import BadRequest, BadURL, ServerError from sanic.headers import ( - AcceptContainer, + AcceptList, Options, parse_accept, parse_content_header, @@ -168,7 +168,7 @@ class Request: self.conn_info: Optional[ConnInfo] = None self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None - self.parsed_accept: Optional[AcceptContainer] = None + self.parsed_accept: Optional[AcceptList] = None self.parsed_credentials: Optional[Credentials] = None self.parsed_json = None self.parsed_form: Optional[RequestParameters] = None @@ -500,7 +500,7 @@ class Request: return self.parsed_json @property - def accept(self) -> AcceptContainer: + def accept(self) -> AcceptList: """ :return: The ``Accept`` header parsed :rtype: AcceptContainer diff --git a/tests/test_headers.py b/tests/test_headers.py index 7b2235ed..48dec620 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -239,7 +239,7 @@ def test_wildcard_accept_set_ok(): def test_accept_parsed_against_str(): - accept = headers.Accept.parse("foo/bar") + accept = headers.MediaType._parse("foo/bar") assert accept > "foo/bar; q=0.1" @@ -266,66 +266,66 @@ def test_media_type_matching(): ( # ALLOW BOTH ("foo/bar", "foo/bar", True, True, True), - ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/bar", headers.MediaType._parse("foo/bar"), True, True, True), ("foo/bar", "foo/*", True, True, True), - ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), + ("foo/bar", headers.MediaType._parse("foo/*"), True, True, True), ("foo/bar", "*/*", True, True, True), - ("foo/bar", headers.Accept.parse("*/*"), True, True, True), + ("foo/bar", headers.MediaType._parse("*/*"), True, True, True), ("foo/*", "foo/bar", True, True, True), - ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/*", headers.MediaType._parse("foo/bar"), True, True, True), ("foo/*", "foo/*", True, True, True), - ("foo/*", headers.Accept.parse("foo/*"), True, True, True), + ("foo/*", headers.MediaType._parse("foo/*"), True, True, True), ("foo/*", "*/*", True, True, True), - ("foo/*", headers.Accept.parse("*/*"), True, True, True), + ("foo/*", headers.MediaType._parse("*/*"), True, True, True), ("*/*", "foo/bar", True, True, True), - ("*/*", headers.Accept.parse("foo/bar"), True, True, True), + ("*/*", headers.MediaType._parse("foo/bar"), True, True, True), ("*/*", "foo/*", True, True, True), - ("*/*", headers.Accept.parse("foo/*"), True, True, True), + ("*/*", headers.MediaType._parse("foo/*"), True, True, True), ("*/*", "*/*", True, True, True), - ("*/*", headers.Accept.parse("*/*"), True, True, True), + ("*/*", headers.MediaType._parse("*/*"), True, True, True), # ALLOW TYPE ("foo/bar", "foo/bar", True, True, False), - ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), + ("foo/bar", headers.MediaType._parse("foo/bar"), True, True, False), ("foo/bar", "foo/*", False, True, False), - ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), + ("foo/bar", headers.MediaType._parse("foo/*"), False, True, False), ("foo/bar", "*/*", False, True, False), - ("foo/bar", headers.Accept.parse("*/*"), False, True, False), + ("foo/bar", headers.MediaType._parse("*/*"), False, True, False), ("foo/*", "foo/bar", False, True, False), - ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), + ("foo/*", headers.MediaType._parse("foo/bar"), False, True, False), ("foo/*", "foo/*", False, True, False), - ("foo/*", headers.Accept.parse("foo/*"), False, True, False), + ("foo/*", headers.MediaType._parse("foo/*"), False, True, False), ("foo/*", "*/*", False, True, False), - ("foo/*", headers.Accept.parse("*/*"), False, True, False), + ("foo/*", headers.MediaType._parse("*/*"), False, True, False), ("*/*", "foo/bar", False, True, False), - ("*/*", headers.Accept.parse("foo/bar"), False, True, False), + ("*/*", headers.MediaType._parse("foo/bar"), False, True, False), ("*/*", "foo/*", False, True, False), - ("*/*", headers.Accept.parse("foo/*"), False, True, False), + ("*/*", headers.MediaType._parse("foo/*"), False, True, False), ("*/*", "*/*", False, True, False), - ("*/*", headers.Accept.parse("*/*"), False, True, False), + ("*/*", headers.MediaType._parse("*/*"), False, True, False), # ALLOW SUBTYPE ("foo/bar", "foo/bar", True, False, True), - ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/bar", headers.MediaType._parse("foo/bar"), True, False, True), ("foo/bar", "foo/*", True, False, True), - ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), + ("foo/bar", headers.MediaType._parse("foo/*"), True, False, True), ("foo/bar", "*/*", False, False, True), - ("foo/bar", headers.Accept.parse("*/*"), False, False, True), + ("foo/bar", headers.MediaType._parse("*/*"), False, False, True), ("foo/*", "foo/bar", True, False, True), - ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/*", headers.MediaType._parse("foo/bar"), True, False, True), ("foo/*", "foo/*", True, False, True), - ("foo/*", headers.Accept.parse("foo/*"), True, False, True), + ("foo/*", headers.MediaType._parse("foo/*"), True, False, True), ("foo/*", "*/*", False, False, True), - ("foo/*", headers.Accept.parse("*/*"), False, False, True), + ("foo/*", headers.MediaType._parse("*/*"), False, False, True), ("*/*", "foo/bar", False, False, True), - ("*/*", headers.Accept.parse("foo/bar"), False, False, True), + ("*/*", headers.MediaType._parse("foo/bar"), False, False, True), ("*/*", "foo/*", False, False, True), - ("*/*", headers.Accept.parse("foo/*"), False, False, True), + ("*/*", headers.MediaType._parse("foo/*"), False, False, True), ("*/*", "*/*", False, False, True), - ("*/*", headers.Accept.parse("*/*"), False, False, True), + ("*/*", headers.MediaType._parse("*/*"), False, False, True), ), ) def test_accept_matching(value, other, outcome, allow_type, allow_subtype): assert ( - headers.Accept.parse(value).match( + headers.MediaType._parse(value).match( other, allow_type_wildcard=allow_type, allow_subtype_wildcard=allow_subtype,