File Cache Control Headers Support (#2447)

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Zhiwei 2022-06-16 08:24:39 -05:00 committed by GitHub
parent 2f90a85df1
commit a744041e38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 145 additions and 5 deletions

View File

@ -1,9 +1,12 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
from email.utils import formatdate
from functools import partial from functools import partial
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from pathlib import PurePath from pathlib import Path, PurePath
from time import time
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -23,7 +26,12 @@ from sanic.compat import Header, open_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
from sanic.exceptions import SanicException, ServerError from sanic.exceptions import SanicException, ServerError
from sanic.helpers import has_message_body, remove_entity_headers from sanic.helpers import (
Default,
_default,
has_message_body,
remove_entity_headers,
)
from sanic.http import Http from sanic.http import Http
from sanic.models.protocol_types import HTMLProtocol, Range from sanic.models.protocol_types import HTMLProtocol, Range
@ -309,6 +317,9 @@ async def file(
mime_type: Optional[str] = None, mime_type: Optional[str] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
filename: Optional[str] = None, filename: Optional[str] = None,
last_modified: Optional[Union[datetime, float, int, Default]] = _default,
max_age: Optional[Union[float, int]] = None,
no_store: Optional[bool] = None,
_range: Optional[Range] = None, _range: Optional[Range] = None,
) -> HTTPResponse: ) -> HTTPResponse:
"""Return a response object with file data. """Return a response object with file data.
@ -317,6 +328,9 @@ async def file(
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
:param filename: Override filename. :param filename: Override filename.
:param last_modified: The last modified date and time of the file.
:param max_age: Max age for cache control.
:param no_store: Any cache should not store this response.
:param _range: :param _range:
""" """
headers = headers or {} headers = headers or {}
@ -324,6 +338,33 @@ async def file(
headers.setdefault( headers.setdefault(
"Content-Disposition", f'attachment; filename="{filename}"' "Content-Disposition", f'attachment; filename="{filename}"'
) )
if isinstance(last_modified, datetime):
last_modified = last_modified.timestamp()
elif isinstance(last_modified, Default):
last_modified = Path(location).stat().st_mtime
if last_modified:
headers.setdefault(
"last-modified", formatdate(last_modified, usegmt=True)
)
if no_store:
cache_control = "no-store"
elif max_age:
cache_control = f"public, max-age={max_age}"
headers.setdefault(
"expires",
formatdate(
time() + max_age,
usegmt=True,
),
)
else:
cache_control = "no-cache"
headers.setdefault("cache-control", cache_control)
filename = filename or path.split(location)[-1] filename = filename or path.split(location)[-1]
async with await open_async(location, mode="rb") as f: async with await open_async(location, mode="rb") as f:

View File

@ -3,10 +3,13 @@ import inspect
import os import os
from collections import namedtuple from collections import namedtuple
from datetime import datetime
from email.utils import formatdate
from logging import ERROR, LogRecord from logging import ERROR, LogRecord
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path
from random import choice from random import choice
from typing import Callable, List from typing import Callable, List, Union
from urllib.parse import unquote from urllib.parse import unquote
import pytest import pytest
@ -328,12 +331,27 @@ def static_file_directory():
return static_directory return static_directory
def path_str_to_path_obj(static_file_directory: Union[Path, str]) -> Path:
    """Return ``static_file_directory`` as a :class:`pathlib.Path`.

    :param static_file_directory: Directory given either as a ``Path`` or
        as a string.
    :return: The argument itself when it already is a ``Path``; otherwise
        a new ``Path`` built from it.
    """
    if isinstance(static_file_directory, Path):
        # Hand back the very same object so callers keep identity.
        return static_file_directory
    return Path(static_file_directory)
def get_file_content(
    static_file_directory: Union[Path, str], file_name: str
) -> bytes:
    """The content of the static file to check

    :param static_file_directory: Directory holding the file, as a
        ``Path`` or a string.
    :param file_name: Name of the file inside that directory.
    :return: The file's raw bytes.
    """
    # read_bytes() opens in binary mode and closes the handle itself; it
    # also avoids a local named ``file`` shadowing the response helper.
    return (Path(static_file_directory) / file_name).read_bytes()
def get_file_last_modified_timestamp(
    static_file_directory: Union[Path, str], file_name: str
) -> float:
    """The last modified timestamp of the static file to check

    :param static_file_directory: Directory holding the file, as a
        ``Path`` or a string.
    :param file_name: Name of the file inside that directory.
    :return: The file's ``st_mtime`` as reported by ``Path.stat()``.
    """
    return (Path(static_file_directory) / file_name).stat().st_mtime
@pytest.mark.parametrize( @pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"] "file_name", ["test.file", "decode me.txt", "python.png"]
) )
@ -711,3 +729,84 @@ def send_response_after_eof_should_fail(
assert "foo, " in response.text assert "foo, " in response.text
assert message_in_records(caplog.records, error_msg1) assert message_in_records(caplog.records, error_msg1)
assert message_in_records(caplog.records, error_msg2) assert message_in_records(caplog.records, error_msg2)
@pytest.mark.parametrize(
    "file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_response_headers(
    app: Sanic, file_name: str, static_file_directory: str
):
    """Check the cache-related headers that ``file()`` sets on a response.

    Four routes are registered, each serving the same static file:

    * ``max_age`` plus an explicit ``last_modified``,
    * ``max_age`` only (``last_modified`` left at its default),
    * no caching arguments at all,
    * ``no_store=True``.
    """
    # Imported locally so this test does not depend on the module's
    # top-level import list.
    from email.utils import parsedate_to_datetime

    test_last_modified = datetime.now()
    test_max_age = 10
    test_expires = test_last_modified.timestamp() + test_max_age
    # Upper bound, in seconds, on the gap between the moment captured
    # above and the moment the handler computes the Expires header.
    expires_tolerance = 60
    file_path = (Path(static_file_directory) / file_name).absolute()

    @app.route("/files/cached/<filename>", methods=["GET"])
    def file_route_cache(request, filename):
        return file(
            file_path, max_age=test_max_age, last_modified=test_last_modified
        )

    @app.route(
        "/files/cached_default_last_modified/<filename>", methods=["GET"]
    )
    def file_route_cache_default_last_modified(request, filename):
        return file(file_path, max_age=test_max_age)

    @app.route("/files/no_cache/<filename>", methods=["GET"])
    def file_route_no_cache(request, filename):
        return file(file_path)

    @app.route("/files/no_store/<filename>", methods=["GET"])
    def file_route_no_store(request, filename):
        return file(file_path, no_store=True)

    # max_age + explicit last_modified
    _, response = app.test_client.get(f"/files/cached/{file_name}")
    assert response.body == get_file_content(static_file_directory, file_name)
    headers = response.headers

    assert "cache-control" in headers
    cache_control = headers.get("cache-control")
    assert f"max-age={test_max_age}" in cache_control
    assert "public" in cache_control

    assert "expires" in headers
    # Compare as timestamps rather than by slicing the formatted string:
    # a textual prefix match breaks when the two instants fall on either
    # side of a minute boundary, e.g.
    #   Thu, 26 May 2022 05:36:59 GMT
    #   Thu, 26 May 2022 05:37:00 GMT
    expires = parsedate_to_datetime(headers.get("expires")).timestamp()
    assert abs(expires - test_expires) <= expires_tolerance

    assert "last-modified" in headers
    assert headers.get("last-modified") == formatdate(
        test_last_modified.timestamp(), usegmt=True
    )

    # max_age only: last-modified should come from the file itself
    _, response = app.test_client.get(
        f"/files/cached_default_last_modified/{file_name}"
    )
    file_last_modified = get_file_last_modified_timestamp(
        static_file_directory, file_name
    )
    headers = response.headers
    assert "last-modified" in headers
    assert headers.get("last-modified") == formatdate(
        file_last_modified, usegmt=True
    )

    # no caching arguments
    _, response = app.test_client.get(f"/files/no_cache/{file_name}")
    headers = response.headers
    assert "cache-control" in headers
    assert headers.get("cache-control") == "no-cache"

    # no_store=True
    _, response = app.test_client.get(f"/files/no_store/{file_name}")
    headers = response.headers
    assert "cache-control" in headers
    assert headers.get("cache-control") == "no-store"