Merge branch 'headerformat' into bodybytes

This commit is contained in:
L. Kärkkäinen 2019-09-11 11:24:53 +03:00
commit 5e86cce761
3 changed files with 56 additions and 73 deletions

View File

@ -130,6 +130,10 @@ class Cookie(dict):
:return: Cookie encoded in a codec of choosing. :return: Cookie encoded in a codec of choosing.
:except: UnicodeEncodeError :except: UnicodeEncodeError
""" """
return str(self).encode(encoding)
def __str__(self):
"""Format as a Set-Cookie header value."""
output = ["%s=%s" % (self.key, _quote(self.value))] output = ["%s=%s" % (self.key, _quote(self.value))]
for key, value in self.items(): for key, value in self.items():
if key == "max-age": if key == "max-age":
@ -147,4 +151,4 @@ class Cookie(dict):
else: else:
output.append("%s=%s" % (self._keys[key], value)) output.append("%s=%s" % (self._keys[key], value))
return "; ".join(output).encode(encoding) return "; ".join(output)

View File

@ -1,9 +1,12 @@
import re import re
from typing import Dict, Iterable, Optional, Tuple from typing import Any, Dict, Iterable, Optional, Tuple
from urllib.parse import unquote from urllib.parse import unquote
from sanic.helpers import STATUS_CODES
HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str
Options = Dict[str, str] # key=value fields in various headers Options = Dict[str, str] # key=value fields in various headers
OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys
@ -165,3 +168,30 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]:
return None, None return None, None
host, port = m.groups() host, port = m.groups()
return host.lower(), port and int(port) return host.lower(), port and int(port)
def format_http1(headers: HeaderIterable) -> bytes:
"""Convert a headers iterable into HTTP/1 header format.
- Outputs UTF-8 bytes where each header line ends with \\r\\n.
- Values are converted into strings if necessary.
"""
return "".join(f"{name}: {val}\r\n" for name, val in headers).encode()
def format_http1_response(
status: int, headers: HeaderIterable, body=b""
) -> bytes:
"""Format a full HTTP/1.1 response.
- If `body` is included, content-length must be specified in headers.
"""
headers = format_http1(headers)
if status == 200:
return b"HTTP/1.1 200 OK\r\n%b\r\n%b" % (headers, body)
return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % (
status,
STATUS_CODES.get(status, b"UNKNOWN"),
headers,
body,
)

View File

@ -7,7 +7,8 @@ from aiofiles import open as open_async
from sanic.compat import Header from sanic.compat import Header
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers from sanic.headers import format_http1, format_http1_response
from sanic.helpers import has_message_body, remove_entity_headers
try: try:
@ -22,28 +23,10 @@ except ImportError:
class BaseHTTPResponse: class BaseHTTPResponse:
def _encode_body(self, data): def _encode_body(self, data):
try: return f"{data}".encode()
# Try to encode it regularly
return data.encode()
except AttributeError:
# Convert it to a str if you can't
return str(data).encode()
def _parse_headers(self): def _parse_headers(self):
headers = b"" return format_http1(self.headers.items())
for name, value in self.headers.items():
try:
headers += b"%b: %b\r\n" % (
name.encode(),
value.encode("utf-8"),
)
except AttributeError:
headers += b"%b: %b\r\n" % (
str(name).encode(),
str(value).encode("utf-8"),
)
return headers
@property @property
def cookies(self): def cookies(self):
@ -68,7 +51,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
streaming_fn, streaming_fn,
status=200, status=200,
headers=None, headers=None,
content_type="text/plain", content_type="text/plain; charset=utf-8",
chunked=True, chunked=True,
): ):
self.content_type = content_type self.content_type = content_type
@ -116,33 +99,17 @@ class StreamingHTTPResponse(BaseHTTPResponse):
def get_headers( def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None self, version="1.1", keep_alive=False, keep_alive_timeout=None
): ):
# This is all returned in a kind-of funky way if "Content-Type" not in self.headers:
# We tried to make this as fast as possible in pure python self.headers["Content-Type"] = self.content_type
timeout_header = b""
if keep_alive and keep_alive_timeout is not None: if keep_alive and keep_alive_timeout is not None:
timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout self.headers["Keep-Alive"] = keep_alive_timeout
if self.chunked and version == "1.1": if self.chunked and version == "1.1":
self.headers["Transfer-Encoding"] = "chunked" self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop("Content-Length", None) self.headers.pop("Content-Length", None)
self.headers["Content-Type"] = self.headers.get(
"Content-Type", self.content_type
)
headers = self._parse_headers() return format_http1_response(self.status, self.headers.items())
if self.status == 200:
status = b"OK"
else:
status = STATUS_CODES.get(self.status)
return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % (
version.encode(),
self.status,
status,
timeout_header,
headers,
)
class HTTPResponse(BaseHTTPResponse): class HTTPResponse(BaseHTTPResponse):
@ -153,7 +120,7 @@ class HTTPResponse(BaseHTTPResponse):
body=None, body=None,
status=200, status=200,
headers=None, headers=None,
content_type="text/plain", content_type="text/plain; charset=utf-8",
body_bytes=b"", body_bytes=b"",
): ):
self.content_type = content_type self.content_type = content_type
@ -168,11 +135,8 @@ class HTTPResponse(BaseHTTPResponse):
self._cookies = None self._cookies = None
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way if "Content-Type" not in self.headers:
# We tried to make this as fast as possible in pure python self.headers["Content-Type"] = self.content_type
timeout_header = b""
if keep_alive and keep_alive_timeout is not None:
timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
body = b"" body = b""
if has_message_body(self.status): if has_message_body(self.status):
@ -181,31 +145,16 @@ class HTTPResponse(BaseHTTPResponse):
"Content-Length", len(self.body) "Content-Length", len(self.body)
) )
self.headers["Content-Type"] = self.headers.get(
"Content-Type", self.content_type
)
if self.status in (304, 412): if self.status in (304, 412):
self.headers = remove_entity_headers(self.headers) self.headers = remove_entity_headers(self.headers)
headers = self._parse_headers() if keep_alive and keep_alive_timeout is not None:
self.headers["Connection"] = "keep-alive"
self.headers["Keep-Alive"] = keep_alive_timeout
elif not keep_alive:
self.headers["Connection"] = "close"
if self.status == 200: return format_http1_response(self.status, self.headers.items(), body)
status = b"OK"
else:
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
return (
b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b"
) % (
version.encode(),
self.status,
status,
b"keep-alive" if keep_alive else b"close",
timeout_header,
headers,
body,
)
@property @property
def cookies(self): def cookies(self):
@ -220,7 +169,7 @@ def json(
headers=None, headers=None,
content_type="application/json", content_type="application/json",
dumps=json_dumps, dumps=json_dumps,
**kwargs **kwargs,
): ):
""" """
Returns response object with body in json format. Returns response object with body in json format.