Explicit usage of CIMultiDict getters (#2104)

This commit is contained in:
ENT8R 2021-04-08 12:30:12 +02:00 committed by GitHub
parent 1a352ddf55
commit ad97cac313
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 27 additions and 18 deletions

View File

@ -673,7 +673,9 @@ class Sanic(BaseSanic):
try: try:
# Fetch handler from router # Fetch handler from router
route, handler, kwargs = self.router.get( route, handler, kwargs = self.router.get(
request.path, request.method, request.headers.get("host") request.path,
request.method,
request.headers.getone("host", None),
) )
request._match_info = kwargs request._match_info = kwargs

View File

@ -366,7 +366,7 @@ def exception_response(
except InvalidUsage: except InvalidUsage:
renderer = HTMLRenderer renderer = HTMLRenderer
content_type, *_ = request.headers.get( content_type, *_ = request.headers.getone(
"content-type", "" "content-type", ""
).split(";") ).split(";")
renderer = RENDERERS_BY_CONTENT_TYPE.get( renderer = RENDERERS_BY_CONTENT_TYPE.get(

View File

@ -165,7 +165,7 @@ class ContentRangeHandler:
def __init__(self, request, stats): def __init__(self, request, stats):
self.total = stats.st_size self.total = stats.st_size
_range = request.headers.get("Range") _range = request.headers.getone("range", None)
if _range is None: if _range is None:
raise HeaderNotFound("Range Header Not Found") raise HeaderNotFound("Range Header Not Found")
unit, _, value = tuple(map(str.strip, _range.partition("="))) unit, _, value = tuple(map(str.strip, _range.partition("=")))

View File

@ -102,7 +102,7 @@ def parse_xforwarded(headers, config) -> Optional[Options]:
"""Parse traditional proxy headers.""" """Parse traditional proxy headers."""
real_ip_header = config.REAL_IP_HEADER real_ip_header = config.REAL_IP_HEADER
proxies_count = config.PROXIES_COUNT proxies_count = config.PROXIES_COUNT
addr = real_ip_header and headers.get(real_ip_header) addr = real_ip_header and headers.getone(real_ip_header, None)
if not addr and proxies_count: if not addr and proxies_count:
assert proxies_count > 0 assert proxies_count > 0
try: try:
@ -131,7 +131,7 @@ def parse_xforwarded(headers, config) -> Optional[Options]:
("port", "x-forwarded-port"), ("port", "x-forwarded-port"),
("path", "x-forwarded-path"), ("path", "x-forwarded-path"),
): ):
yield key, headers.get(header) yield key, headers.getone(header, None)
return fwd_normalize(options()) return fwd_normalize(options())

View File

@ -219,7 +219,7 @@ class Http:
headers_instance = Header(headers) headers_instance = Header(headers)
self.upgrade_websocket = ( self.upgrade_websocket = (
headers_instance.get("upgrade", "").lower() == "websocket" headers_instance.getone("upgrade", "").lower() == "websocket"
) )
# Prepare a Request object # Prepare a Request object
@ -237,7 +237,7 @@ class Http:
self.request_bytes_left = self.request_bytes = 0 self.request_bytes_left = self.request_bytes = 0
if request_body: if request_body:
headers = request.headers headers = request.headers
expect = headers.get("expect") expect = headers.getone("expect", None)
if expect is not None: if expect is not None:
if expect.lower() == "100-continue": if expect.lower() == "100-continue":
@ -245,7 +245,7 @@ class Http:
else: else:
raise HeaderExpectationFailed(f"Unknown Expect: {expect}") raise HeaderExpectationFailed(f"Unknown Expect: {expect}")
if headers.get("transfer-encoding") == "chunked": if headers.getone("transfer-encoding", None) == "chunked":
self.request_body = "chunked" self.request_body = "chunked"
pos -= 2 # One CRLF stays in buffer pos -= 2 # One CRLF stays in buffer
else: else:

View File

@ -665,7 +665,10 @@ class RouteMixin:
modified_since = strftime( modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
) )
if request.headers.get("If-Modified-Since") == modified_since: if (
request.headers.getone("if-modified-since", None)
== modified_since
):
return HTTPResponse(status=304) return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since headers["Last-Modified"] = modified_since
_range = None _range = None

View File

@ -125,7 +125,7 @@ class Request:
self._name: Optional[str] = None self._name: Optional[str] = None
self.app = app self.app = app
self.headers = headers self.headers = Header(headers)
self.version = version self.version = version
self.method = method self.method = method
self.transport = transport self.transport = transport
@ -262,7 +262,7 @@ class Request:
app = Sanic("MyApp", request_class=IntRequest) app = Sanic("MyApp", request_class=IntRequest)
""" """
if not self._id: if not self._id:
self._id = self.headers.get( self._id = self.headers.getone(
self.app.config.REQUEST_ID_HEADER, self.app.config.REQUEST_ID_HEADER,
self.__class__.generate_id(self), # type: ignore self.__class__.generate_id(self), # type: ignore
) )
@ -303,7 +303,7 @@ class Request:
:return: token related to request :return: token related to request
""" """
prefixes = ("Bearer", "Token") prefixes = ("Bearer", "Token")
auth_header = self.headers.get("Authorization") auth_header = self.headers.getone("authorization", None)
if auth_header is not None: if auth_header is not None:
for prefix in prefixes: for prefix in prefixes:
@ -317,8 +317,8 @@ class Request:
if self.parsed_form is None: if self.parsed_form is None:
self.parsed_form = RequestParameters() self.parsed_form = RequestParameters()
self.parsed_files = RequestParameters() self.parsed_files = RequestParameters()
content_type = self.headers.get( content_type = self.headers.getone(
"Content-Type", DEFAULT_HTTP_CONTENT_TYPE "content-type", DEFAULT_HTTP_CONTENT_TYPE
) )
content_type, parameters = parse_content_header(content_type) content_type, parameters = parse_content_header(content_type)
try: try:
@ -465,7 +465,7 @@ class Request:
""" """
if self._cookies is None: if self._cookies is None:
cookie = self.headers.get("Cookie") cookie = self.headers.getone("cookie", None)
if cookie is not None: if cookie is not None:
cookies: SimpleCookie = SimpleCookie() cookies: SimpleCookie = SimpleCookie()
cookies.load(cookie) cookies.load(cookie)
@ -482,7 +482,7 @@ class Request:
:return: Content-Type header form the request :return: Content-Type header form the request
:rtype: str :rtype: str
""" """
return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) return self.headers.getone("content-type", DEFAULT_HTTP_CONTENT_TYPE)
@property @property
def match_info(self): def match_info(self):
@ -581,7 +581,7 @@ class Request:
if ( if (
self.app.websocket_enabled self.app.websocket_enabled
and self.headers.get("upgrade", "").lower() == "websocket" and self.headers.getone("upgrade", "").lower() == "websocket"
): ):
scheme = "ws" scheme = "ws"
else: else:
@ -608,7 +608,9 @@ class Request:
server_name = self.app.config.get("SERVER_NAME") server_name = self.app.config.get("SERVER_NAME")
if server_name: if server_name:
return server_name.split("//", 1)[-1].split("/", 1)[0] return server_name.split("//", 1)[-1].split("/", 1)[0]
return str(self.forwarded.get("host") or self.headers.get("host", "")) return str(
self.forwarded.get("host") or self.headers.getone("host", "")
)
@property @property
def server_name(self) -> str: def server_name(self) -> str:

View File

@ -20,6 +20,7 @@ def test_request_id_generates_from_request(monkeypatch):
monkeypatch.setattr(Request, "generate_id", Mock()) monkeypatch.setattr(Request, "generate_id", Mock())
Request.generate_id.return_value = 1 Request.generate_id.return_value = 1
request = Request(b"/", {}, None, "GET", None, Mock()) request = Request(b"/", {}, None, "GET", None, Mock())
request.app.config.REQUEST_ID_HEADER = "foo"
for _ in range(10): for _ in range(10):
request.id request.id
@ -28,6 +29,7 @@ def test_request_id_generates_from_request(monkeypatch):
def test_request_id_defaults_uuid(): def test_request_id_defaults_uuid():
request = Request(b"/", {}, None, "GET", None, Mock()) request = Request(b"/", {}, None, "GET", None, Mock())
request.app.config.REQUEST_ID_HEADER = "foo"
assert isinstance(request.id, UUID) assert isinstance(request.id, UUID)