From 5351cda979540534fb75073dfa73f5fffe6d8ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Sun, 8 Mar 2020 17:33:56 +0200 Subject: [PATCH] Unify response header processing of ASGI and asyncio modes. --- sanic/asgi.py | 2 +- sanic/headers.py | 20 ++++---------------- sanic/http.py | 9 ++------- sanic/response.py | 30 ++++++++++++++---------------- 4 files changed, 21 insertions(+), 40 deletions(-) diff --git a/sanic/asgi.py b/sanic/asgi.py index 8f379358..59d91990 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -281,7 +281,7 @@ class ASGIApp: await self.transport.send({ "type": "http.response.start", "status": response.status, - "headers": response.full_headers, + "headers": response.processed_headers, }) response_body = getattr(response, "body", None) if response_body: diff --git a/sanic/headers.py b/sanic/headers.py index 78140e83..74a797ee 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -7,6 +7,7 @@ from sanic.helpers import STATUS_CODES HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str +HeaderBytesIterable = Iterable[Tuple[bytes, bytes]] Options = Dict[str, Union[int, str]] # key=value fields in various headers OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys @@ -175,26 +176,13 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]: return host.lower(), int(port) if port is not None else None -def format_http1(headers: HeaderIterable) -> bytes: - """Convert a headers iterable into HTTP/1 header format. - - - Outputs UTF-8 bytes where each header line ends with \\r\\n. - - Values are converted into strings if necessary. - """ - return "".join(f"{name}: {val}\r\n" for name, val in headers).encode() - - def format_http1_response( - status: int, headers: HeaderIterable, body=b"" + status: int, headers: HeaderBytesIterable, body=b"" ) -> bytes: - """Format a full HTTP/1.1 response. - - - If `body` is included, content-length must be specified in headers. - """ - headerbytes = format_http1(headers) + """Format a full HTTP/1.1 response.""" return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % ( status, STATUS_CODES.get(status, b"UNKNOWN"), - headerbytes, + b"".join(b"%b: %b\r\n" % h for h in headers), body, ) diff --git a/sanic/http.py b/sanic/http.py index 97cdb0a8..29821b07 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -10,7 +10,7 @@ from sanic.exceptions import ( ServiceUnavailable, ) from sanic.headers import format_http1_response -from sanic.helpers import has_message_body, remove_entity_headers +from sanic.helpers import has_message_body from sanic.log import access_logger, logger @@ -182,14 +182,9 @@ class Http: data, end_stream = res.body, True size = len(data) headers = res.headers - if res.content_type and "content-type" not in headers: - headers["content-type"] = res.content_type status = res.status if not isinstance(status, int) or status < 200: raise RuntimeError(f"Invalid response status {status!r}") - # Not Modified, Precondition Failed - if status in (304, 412): - headers = remove_entity_headers(headers) if not has_message_body(status): # Header-only response status self.response_func = None @@ -227,7 +222,7 @@ class Http: data = b"" self.response_func = self.head_response_ignored headers["connection"] = "keep-alive" if self.keep_alive else "close" - ret = format_http1_response(status, headers.items(), data) + ret = format_http1_response(status, res.processed_headers, data) # Send a 100-continue if expected and not Expectation Failed if self.expecting_continue: self.expecting_continue = False diff --git a/sanic/response.py b/sanic/response.py index da6929e6..2cc74e1e 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -7,8 +7,7 @@ from urllib.parse import quote_plus from sanic.compat import Header, open_async from sanic.cookies import CookieJar -from sanic.headers import format_http1 -from sanic.helpers import has_message_body +from sanic.helpers import has_message_body, remove_entity_headers try: @@ -25,9 +24,6 @@ class BaseHTTPResponse: def _encode_body(self, data): return data.encode() if hasattr(data, "encode") else data - def _parse_headers(self): - return format_http1(self.full_headers) - @property def cookies(self): if self._cookies is None: @@ -35,14 +31,22 @@ class BaseHTTPResponse: return self._cookies @property - def full_headers(self): - """Obtain an encoded tuple of headers for a response to be sent.""" + def processed_headers(self): + """Obtain a list of header tuples encoded in bytes for sending. + + Add and remove headers based on status and content_type. + """ headers = [] cookies = {} - if self.content_type and not "content-type" in self.headers: - headers += (b"content-type", self.content_type.encode()), + status = self.status + # TODO: Make a blacklist set of header names and then filter with that + if status in (304, 412): # Not Modified, Precondition Failed + self.headers = remove_entity_headers(self.headers) + if has_message_body(status): + if self.content_type and not "content-type" in self.headers: + headers += (b"content-type", self.content_type.encode()), for name, value in self.headers.items(): - name = f"{name}" + name = f"{name}".lower() if name.lower() == "set-cookie": cookies[value.key] = value else: @@ -126,12 +130,6 @@ class HTTPResponse(BaseHTTPResponse): self.headers = Header(headers or {}) self._cookies = None - @property - def cookies(self): - if self._cookies is None: - self._cookies = CookieJar(self.headers) - return self._cookies - def empty(status=204, headers=None): """