From bffdb3b5c243f7fc2b534ea6bf3e91a40725a641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= <98187+Tronic@users.noreply.github.com> Date: Mon, 20 Jan 2020 18:34:32 +0200 Subject: [PATCH] More robust response datatype handling (#1674) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * HTTP1 header formatting moved to headers.format_headers and rewritten. - New implementation is one line of code and twice faster than the old one. - Whole header block encoded to UTF-8 in one pass. - No longer supports custom encode method on header values. - Cookie objects now have __str__ in addition to encode, to work with this. * Linter * format_http1_response * Replace encode_body with faster implementation based on f-string. Benchmarks: def encode_body(data): try: # Try to encode it regularly return data.encode() except AttributeError: # Convert it to a str if you can't return str(data).encode() def encode_body2(data): return f"{data}".encode() def encode_body3(data): return str(data).encode() data_str, data_int = "foo", 123 %timeit encode_body(data_int) 928 ns ± 2.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) %timeit encode_body2(data_int) 280 ns ± 2.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) %timeit encode_body3(data_int) 387 ns ± 1.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) %timeit encode_body(data_str) 202 ns ± 1.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) %timeit encode_body2(data_str) 197 ns ± 0.507 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each) %timeit encode_body3(data_str) 313 ns ± 1.28 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) * Wtf linter * Content-type fixes. * Body encoding sanitation, first pass. - body/data type autodetection fixed. - do not repr(body).encode() bytes-ish values. - support __html__ and _repr_html_ in sanic.response.html(). 
* Any-type-to-str response autoconversion limited to sanic.response.text() only. * Workaround MyPy issue. * Add an empty line to make isort happy. * Add html test for __html__ and _repr_html_. * Remove StreamingHTTPResponse.get_headers helper function. * Add back HTTPResponse Keep-Alive removed by earlier merge or something. * Revert "Remove StreamingHTTPResponse.get_headers helper function." Tests depend on this otherwise useless function. This reverts commit 9651e6ae017b61bed6dd88af6631cdd6b01eb347. * Add deprecation warnings; instead of assert for wrong HTTP version, and for non-string response.text. * Add back missing import. * Avoid duplicate response header tweaking code. * Linter errors --- sanic/headers.py | 18 ++++++ sanic/response.py | 127 ++++++++++++++++++++--------------------- tests/test_requests.py | 37 +++++++++++- tests/test_response.py | 3 +- 4 files changed, 119 insertions(+), 66 deletions(-) diff --git a/sanic/headers.py b/sanic/headers.py index eef468f9..78140e83 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -3,6 +3,8 @@ import re from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import unquote +from sanic.helpers import STATUS_CODES + HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str Options = Dict[str, Union[int, str]] # key=value fields in various headers @@ -180,3 +182,19 @@ def format_http1(headers: HeaderIterable) -> bytes: - Values are converted into strings if necessary. """ return "".join(f"{name}: {val}\r\n" for name, val in headers).encode() + + +def format_http1_response( + status: int, headers: HeaderIterable, body=b"" +) -> bytes: + """Format a full HTTP/1.1 response. + + - If `body` is included, content-length must be specified in headers.
+ """ + headerbytes = format_http1(headers) + return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % ( + status, + STATUS_CODES.get(status, b"UNKNOWN"), + headerbytes, + body, + ) diff --git a/sanic/response.py b/sanic/response.py index 60e6bb37..1857a4db 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -1,3 +1,5 @@ +import warnings + from functools import partial from mimetypes import guess_type from os import path @@ -5,8 +7,8 @@ from urllib.parse import quote_plus from sanic.compat import Header, open_async from sanic.cookies import CookieJar -from sanic.headers import format_http1 -from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers +from sanic.headers import format_http1, format_http1_response +from sanic.helpers import has_message_body, remove_entity_headers try: @@ -21,12 +23,7 @@ except ImportError: class BaseHTTPResponse: def _encode_body(self, data): - try: - # Try to encode it regularly - return data.encode() - except AttributeError: - # Convert it to a str if you can't - return str(data).encode() + return data.encode() if hasattr(data, "encode") else data def _parse_headers(self): return format_http1(self.headers.items()) @@ -37,6 +34,37 @@ class BaseHTTPResponse: self._cookies = CookieJar(self.headers) return self._cookies + def get_headers( + self, + version="1.1", + keep_alive=False, + keep_alive_timeout=None, + body=b"", + ): + """.. 
deprecated:: 20.3: + This function is not public API and will be removed.""" + if version != "1.1": + warnings.warn( + f"Only HTTP/1.1 is currently supported (got {version})", + DeprecationWarning, + ) + + # self.headers get priority over content_type + if self.content_type and "Content-Type" not in self.headers: + self.headers["Content-Type"] = self.content_type + + if keep_alive: + self.headers["Connection"] = "keep-alive" + if keep_alive_timeout is not None: + self.headers["Keep-Alive"] = keep_alive_timeout + else: + self.headers["Connection"] = "close" + + if self.status in (304, 412): + self.headers = remove_entity_headers(self.headers) + + return format_http1_response(self.status, self.headers.items(), body) + class StreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( @@ -54,7 +82,7 @@ class StreamingHTTPResponse(BaseHTTPResponse): streaming_fn, status=200, headers=None, - content_type="text/plain", + content_type="text/plain; charset=utf-8", chunked=True, ): self.content_type = content_type @@ -67,10 +95,9 @@ class StreamingHTTPResponse(BaseHTTPResponse): async def write(self, data): """Writes a chunk of data to the streaming response. - :param data: bytes-ish data to be written. + :param data: str or bytes-ish data to be written.
""" - if type(data) != bytes: - data = self._encode_body(data) + data = self._encode_body(data) if self.chunked: await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) @@ -102,28 +129,11 @@ class StreamingHTTPResponse(BaseHTTPResponse): def get_headers( self, version="1.1", keep_alive=False, keep_alive_timeout=None ): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" - if keep_alive and keep_alive_timeout is not None: - timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout - if self.chunked and version == "1.1": self.headers["Transfer-Encoding"] = "chunked" self.headers.pop("Content-Length", None) - self.headers["Content-Type"] = self.headers.get( - "Content-Type", self.content_type - ) - headers = self._parse_headers() - status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % ( - version.encode(), - self.status, - status, - timeout_header, - headers, - ) + return super().get_headers(version, keep_alive, keep_alive_timeout) class HTTPResponse(BaseHTTPResponse): @@ -138,23 +148,12 @@ class HTTPResponse(BaseHTTPResponse): body_bytes=b"", ): self.content_type = content_type - - if body is not None: - self.body = self._encode_body(body) - else: - self.body = body_bytes - + self.body = body_bytes if body is None else self._encode_body(body) self.status = status self.headers = Header(headers or {}) self._cookies = None def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" - if keep_alive and keep_alive_timeout is not None: - timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout - body = b"" if has_message_body(self.status): body = self.body @@ -162,26 +161,7 @@ class HTTPResponse(BaseHTTPResponse): "Content-Length", len(self.body) ) - # self.headers get priority 
over content_type - if self.content_type and "Content-Type" not in self.headers: - self.headers["Content-Type"] = self.content_type - - if self.status in (304, 412): - self.headers = remove_entity_headers(self.headers) - - headers = self._parse_headers() - status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - return ( - b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b" - ) % ( - version.encode(), - self.status, - status, - b"keep-alive" if keep_alive else b"close", - timeout_header, - headers, - body, - ) + return self.get_headers(version, keep_alive, keep_alive_timeout, body) @property def cookies(self): @@ -206,7 +186,7 @@ def json( headers=None, content_type="application/json", dumps=json_dumps, - **kwargs + **kwargs, ): """ Returns response object with body in json format. @@ -235,6 +215,21 @@ def text( :param headers: Custom Headers. :param content_type: the content type (string) of the response """ + if not isinstance(body, str): + warnings.warn( + "Types other than str will be deprecated in future versions for" + f" response.text, got type {type(body).__name__}", + DeprecationWarning, + ) + # Type conversions are deprecated and quite b0rked but still supported for + # text() until applications get fixed. This try-except should be removed. + try: + # Avoid repr(body).encode() b0rkage for body that is already encoded. + # memoryview used only to test bytes-ishness. + with memoryview(body): + pass + except TypeError: + body = f"{body}" # no-op if body is already str return HTTPResponse( body, status=status, headers=headers, content_type=content_type ) @@ -263,10 +258,14 @@ def html(body, status=200, headers=None): """ Returns response object with body in html format. - :param body: Response data to be encoded. + :param body: str or bytes-ish, or an object with __html__ or _repr_html_. :param status: Response code. :param headers: Custom Headers.
""" + if hasattr(body, "__html__"): + body = body.__html__() + elif hasattr(body, "_repr_html_"): + body = body._repr_html_() return HTTPResponse( body, status=status, diff --git a/tests/test_requests.py b/tests/test_requests.py index 93516e00..e5abb34f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -11,7 +11,7 @@ import pytest from sanic import Blueprint, Sanic from sanic.exceptions import ServerError from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters -from sanic.response import json, text +from sanic.response import html, json, text from sanic.testing import ASGI_HOST, HOST, PORT @@ -72,6 +72,41 @@ def test_text(app): assert response.text == "Hello" +def test_html(app): + class Foo: + def __html__(self): + return "

Foo

" + + def _repr_html_(self): + return "

Foo object repr

" + + class Bar: + def _repr_html_(self): + return "

Bar object repr

" + + @app.route("/") + async def handler(request): + return html("

Hello

") + + @app.route("/foo") + async def handler(request): + return html(Foo()) + + @app.route("/bar") + async def handler(request): + return html(Bar()) + + request, response = app.test_client.get("/") + assert response.content_type == "text/html; charset=utf-8" + assert response.text == "

Hello

" + + request, response = app.test_client.get("/foo") + assert response.text == "

Foo

" + + request, response = app.test_client.get("/bar") + assert response.text == "

Bar object repr

" + + @pytest.mark.asyncio async def test_text_asgi(app): @app.route("/") diff --git a/tests/test_response.py b/tests/test_response.py index 87bda1bf..ca508af7 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -20,6 +20,7 @@ from sanic.response import ( json, raw, stream, + text, ) from sanic.response import empty from sanic.server import HttpProtocol @@ -35,7 +36,7 @@ def test_response_body_not_a_string(app): @app.route("/hello") async def hello_route(request): - return HTTPResponse(body=random_num) + return text(random_num) request, response = app.test_client.get("/hello") assert response.text == str(random_num)