From ad97cac31393477380ee421281d6356e604353d4 Mon Sep 17 00:00:00 2001 From: ENT8R Date: Thu, 8 Apr 2021 12:30:12 +0200 Subject: [PATCH] Explicit usage of CIMultiDict getters (#2104) --- sanic/app.py | 4 +++- sanic/errorpages.py | 2 +- sanic/handlers.py | 2 +- sanic/headers.py | 4 ++-- sanic/http.py | 6 +++--- sanic/mixins/routes.py | 5 ++++- sanic/request.py | 20 +++++++++++--------- tests/test_request.py | 2 ++ 8 files changed, 27 insertions(+), 18 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index d80bd242..3b3a45cf 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -673,7 +673,9 @@ class Sanic(BaseSanic): try: # Fetch handler from router route, handler, kwargs = self.router.get( - request.path, request.method, request.headers.get("host") + request.path, + request.method, + request.headers.getone("host", None), ) request._match_info = kwargs diff --git a/sanic/errorpages.py b/sanic/errorpages.py index ceb8c92f..5fc10de1 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -366,7 +366,7 @@ def exception_response( except InvalidUsage: renderer = HTMLRenderer - content_type, *_ = request.headers.get( + content_type, *_ = request.headers.getone( "content-type", "" ).split(";") renderer = RENDERERS_BY_CONTENT_TYPE.get( diff --git a/sanic/handlers.py b/sanic/handlers.py index 58754749..2f15c143 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -165,7 +165,7 @@ class ContentRangeHandler: def __init__(self, request, stats): self.total = stats.st_size - _range = request.headers.get("Range") + _range = request.headers.getone("range", None) if _range is None: raise HeaderNotFound("Range Header Not Found") unit, _, value = tuple(map(str.strip, _range.partition("="))) diff --git a/sanic/headers.py b/sanic/headers.py index c41cfcac..66427442 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -102,7 +102,7 @@ def parse_xforwarded(headers, config) -> Optional[Options]: """Parse traditional proxy headers.""" real_ip_header = config.REAL_IP_HEADER proxies_count = config.PROXIES_COUNT - addr = real_ip_header and headers.get(real_ip_header) + addr = real_ip_header and headers.getone(real_ip_header, None) if not addr and proxies_count: assert proxies_count > 0 try: @@ -131,7 +131,7 @@ def parse_xforwarded(headers, config) -> Optional[Options]: ("port", "x-forwarded-port"), ("path", "x-forwarded-path"), ): - yield key, headers.get(header) + yield key, headers.getone(header, None) return fwd_normalize(options()) diff --git a/sanic/http.py b/sanic/http.py index 7604dede..3dcfb22c 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -219,7 +219,7 @@ class Http: headers_instance = Header(headers) self.upgrade_websocket = ( - headers_instance.get("upgrade", "").lower() == "websocket" + headers_instance.getone("upgrade", "").lower() == "websocket" ) # Prepare a Request object @@ -237,7 +237,7 @@ class Http: self.request_bytes_left = self.request_bytes = 0 if request_body: headers = request.headers - expect = headers.get("expect") + expect = headers.getone("expect", None) if expect is not None: if expect.lower() == "100-continue": @@ -245,7 +245,7 @@ class Http: else: raise HeaderExpectationFailed(f"Unknown Expect: {expect}") - if headers.get("transfer-encoding") == "chunked": + if headers.getone("transfer-encoding", None) == "chunked": self.request_body = "chunked" pos -= 2 # One CRLF stays in buffer else: diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index a790ae16..be494a2a 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -665,7 +665,10 @@ class RouteMixin: modified_since = strftime( "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) ) - if request.headers.get("If-Modified-Since") == modified_since: + if ( + request.headers.getone("if-modified-since", None) + == modified_since + ): return HTTPResponse(status=304) headers["Last-Modified"] = modified_since _range = None diff --git a/sanic/request.py b/sanic/request.py index 91e0c808..6ba809ca 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -125,7 +125,7 @@ class Request: self._name: Optional[str] = None self.app = app - self.headers = headers + self.headers = Header(headers) self.version = version self.method = method self.transport = transport @@ -262,7 +262,7 @@ class Request: app = Sanic("MyApp", request_class=IntRequest) """ if not self._id: - self._id = self.headers.get( + self._id = self.headers.getone( self.app.config.REQUEST_ID_HEADER, self.__class__.generate_id(self), # type: ignore ) @@ -303,7 +303,7 @@ class Request: :return: token related to request """ prefixes = ("Bearer", "Token") - auth_header = self.headers.get("Authorization") + auth_header = self.headers.getone("authorization", None) if auth_header is not None: for prefix in prefixes: @@ -317,8 +317,8 @@ class Request: if self.parsed_form is None: self.parsed_form = RequestParameters() self.parsed_files = RequestParameters() - content_type = self.headers.get( - "Content-Type", DEFAULT_HTTP_CONTENT_TYPE + content_type = self.headers.getone( + "content-type", DEFAULT_HTTP_CONTENT_TYPE ) content_type, parameters = parse_content_header(content_type) try: @@ -465,7 +465,7 @@ class Request: """ if self._cookies is None: - cookie = self.headers.get("Cookie") + cookie = self.headers.getone("cookie", None) if cookie is not None: cookies: SimpleCookie = SimpleCookie() cookies.load(cookie) @@ -482,7 +482,7 @@ class Request: :return: Content-Type header form the request :rtype: str """ - return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) + return self.headers.getone("content-type", DEFAULT_HTTP_CONTENT_TYPE) @property def match_info(self): @@ -581,7 +581,7 @@ class Request: if ( self.app.websocket_enabled - and self.headers.get("upgrade", "").lower() == "websocket" + and self.headers.getone("upgrade", "").lower() == "websocket" ): scheme = "ws" else: @@ -608,7 +608,9 @@ class Request: server_name = self.app.config.get("SERVER_NAME") if server_name: return server_name.split("//", 1)[-1].split("/", 1)[0] - return str(self.forwarded.get("host") or self.headers.get("host", "")) + return str( + self.forwarded.get("host") or self.headers.getone("host", "") + ) @property def server_name(self) -> str: diff --git a/tests/test_request.py b/tests/test_request.py index 5b79b1e3..0cbf0994 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -20,6 +20,7 @@ def test_request_id_generates_from_request(monkeypatch): monkeypatch.setattr(Request, "generate_id", Mock()) Request.generate_id.return_value = 1 request = Request(b"/", {}, None, "GET", None, Mock()) + request.app.config.REQUEST_ID_HEADER = "foo" for _ in range(10): request.id @@ -28,6 +29,7 @@ def test_request_id_generates_from_request(monkeypatch): def test_request_id_defaults_uuid(): request = Request(b"/", {}, None, "GET", None, Mock()) + request.app.config.REQUEST_ID_HEADER = "foo" assert isinstance(request.id, UUID)