From 43ba381e7b6a81f37268f7bcab1816b01087b4e0 Mon Sep 17 00:00:00 2001 From: Zhiwei Date: Tue, 20 Sep 2022 16:45:03 -0500 Subject: [PATCH] Refactor `_static_request_handler` (#2533) Co-authored-by: Adam Hopkins --- sanic/mixins/routes.py | 61 +++++++++++++++++++++++------------------- tests/test_response.py | 33 +++++++++++++++++++---- 2 files changed, 61 insertions(+), 33 deletions(-) diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 4f619714..2fb1e837 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -1,12 +1,12 @@ from ast import NodeVisitor, Return, parse from contextlib import suppress +from email.utils import formatdate from functools import partial, wraps from inspect import getsource, signature from mimetypes import guess_type from os import path from pathlib import Path, PurePath from textwrap import dedent -from time import gmtime, strftime from typing import ( Any, Callable, @@ -31,7 +31,7 @@ from sanic.handlers import ContentRangeHandler from sanic.log import error_logger from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.handler_types import RouteHandler -from sanic.response import HTTPResponse, file, file_stream +from sanic.response import HTTPResponse, file, file_stream, validate_file from sanic.types import HashableDict @@ -790,24 +790,9 @@ class RouteMixin(metaclass=SanicMeta): return name - async def _static_request_handler( - self, - file_or_directory, - use_modified_since, - use_content_range, - stream_large_files, - request, - content_type=None, - __file_uri__=None, - ): - # Merge served directory and requested file if provided + async def _get_file_path(self, file_or_directory, __file_uri__, not_found): file_path_raw = Path(unquote(file_or_directory)) root_path = file_path = file_path_raw.resolve() - not_found = FileNotFound( - "File not found", - path=file_or_directory, - relative_url=__file_uri__, - ) if __file_uri__: # Strip all / that in the beginning of the URL to help prevent @@ -834,6 +819,29 @@ class RouteMixin(metaclass=SanicMeta): f"relative_url={__file_uri__}" ) raise not_found + return file_path + + async def _static_request_handler( + self, + file_or_directory, + use_modified_since, + use_content_range, + stream_large_files, + request, + content_type=None, + __file_uri__=None, + ): + not_found = FileNotFound( + "File not found", + path=file_or_directory, + relative_url=__file_uri__, + ) + + # Merge served directory and requested file if provided + file_path = await self._get_file_path( + file_or_directory, __file_uri__, not_found + ) + try: headers = {} # Check if the client has been sent this file before @@ -841,15 +849,13 @@ class RouteMixin(metaclass=SanicMeta): stats = None if use_modified_since: stats = await stat_async(file_path) - modified_since = strftime( - "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) + modified_since = stats.st_mtime + response = await validate_file(request.headers, modified_since) + if response: + return response + headers["Last-Modified"] = formatdate( + modified_since, usegmt=True ) - if ( - request.headers.getone("if-modified-since", None) - == modified_since - ): - return HTTPResponse(status=304) - headers["Last-Modified"] = modified_since _range = None if use_content_range: _range = None @@ -864,8 +870,7 @@ class RouteMixin(metaclass=SanicMeta): pass else: del headers["Content-Length"] - for key, value in _range.headers.items(): - headers[key] = value + headers.update(_range.headers) if "content-type" not in headers: content_type = ( diff --git a/tests/test_response.py b/tests/test_response.py index cd61a621..9254fca3 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -4,8 +4,8 @@ import os import time from collections import namedtuple -from datetime import datetime -from email.utils import formatdate +from datetime import datetime, timedelta +from email.utils import formatdate, parsedate_to_datetime from logging import ERROR, LogRecord from mimetypes import guess_type from pathlib import Path @@ -665,13 +665,11 @@ def test_multiple_responses( with caplog.at_level(ERROR): _, response = app.test_client.get("/4") - print(response.json) assert response.status == 200 assert "foo" not in response.text assert "one" in response.headers assert response.headers["one"] == "one" - print(response.headers) assert message_in_records(caplog.records, error_msg2) with caplog.at_level(ERROR): @@ -841,10 +839,10 @@ def test_file_validate(app: Sanic, static_file_directory: str): time.sleep(1) with open(file_path, "a") as f: f.write("bar\n") - _, response = app.test_client.get( "/validate", headers={"If-Modified-Since": last_modified} ) + assert response.status == 200 assert response.body == b"foo\nbar\n" @@ -921,3 +919,28 @@ def test_file_validating_304_response( ) assert response.status == 304 assert response.body == b"" + + +@pytest.mark.parametrize( + "file_name", ["test.file", "decode me.txt", "python.png"] +) +def test_file_validating_304_response( + app: Sanic, file_name: str, static_file_directory: str +): + app.static("static", Path(static_file_directory) / file_name) + + _, response = app.test_client.get("/static") + assert response.status == 200 + assert response.body == get_file_content(static_file_directory, file_name) + last_modified = parsedate_to_datetime(response.headers["Last-Modified"]) + last_modified += timedelta(seconds=1) + _, response = app.test_client.get( + "/static", + headers={ + "if-modified-since": formatdate( + last_modified.timestamp(), usegmt=True + ) + }, + ) + assert response.status == 304 + assert response.body == b""