Validate File When Requested (#2526)

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Zhiwei 2022-08-18 04:05:05 -05:00 committed by GitHub
parent 09089b1bd3
commit 753ee992a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 184 additions and 20 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime, timezone
from email.utils import formatdate from email.utils import formatdate, parsedate_to_datetime
from functools import partial from functools import partial
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
@ -33,6 +33,7 @@ from sanic.helpers import (
remove_entity_headers, remove_entity_headers,
) )
from sanic.http import Http from sanic.http import Http
from sanic.log import logger
from sanic.models.protocol_types import HTMLProtocol, Range from sanic.models.protocol_types import HTMLProtocol, Range
@ -319,9 +320,34 @@ def html(
) )
async def validate_file(
    request_headers: Header, last_modified: Union[datetime, float, int]
) -> Optional[HTTPResponse]:
    """Answer a conditional request based on ``If-Modified-Since``.

    :param request_headers: Headers of the incoming request. Only the
        ``If-Modified-Since`` header is read.
    :param last_modified: Modification time of the file, either as a
        ``datetime`` or as a POSIX timestamp. A naive ``datetime`` is
        interpreted as local time, which is what ``datetime.timestamp()``
        does with it as well.
    :return: An empty ``304`` response when the file has not been modified
        after the date sent by the client. ``None`` when the header is
        missing, cannot be parsed, or is older than the file.
    """
    try:
        raw_value = request_headers.getone("If-Modified-Since")
    except KeyError:
        return None

    try:
        if_modified_since = parsedate_to_datetime(raw_value)
    except (TypeError, ValueError):
        logger.warning(
            "Ignoring invalid If-Modified-Since header received: '%s'",
            raw_value,
        )
        return None

    if isinstance(last_modified, datetime):
        if last_modified.tzinfo is None:
            # Comparing a naive value with an aware one raises TypeError,
            # so attach the local zone before converting to UTC.
            last_modified = last_modified.astimezone(timezone.utc)
    else:
        last_modified = datetime.fromtimestamp(
            float(last_modified), tz=timezone.utc
        )
    # HTTP dates only carry whole seconds; drop anything finer so a file
    # is not reported as modified because of sub-second precision.
    last_modified = last_modified.replace(microsecond=0)

    if if_modified_since.tzinfo is None:
        # parsedate_to_datetime yields a naive datetime for a "-0000" zone;
        # the value is still expressed in UTC.
        if_modified_since = if_modified_since.replace(tzinfo=timezone.utc)

    if last_modified <= if_modified_since:
        return HTTPResponse(status=304)
    return None
async def file( async def file(
location: Union[str, PurePath], location: Union[str, PurePath],
status: int = 200, status: int = 200,
request_headers: Optional[Header] = None,
validate_when_requested: bool = True,
mime_type: Optional[str] = None, mime_type: Optional[str] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
filename: Optional[str] = None, filename: Optional[str] = None,
@ -331,7 +357,12 @@ async def file(
_range: Optional[Range] = None, _range: Optional[Range] = None,
) -> HTTPResponse: ) -> HTTPResponse:
"""Return a response object with file data. """Return a response object with file data.
:param status: HTTP response code. Won't enforce the passed in
status if only a part of the content will be sent (206)
or file is being validated (304).
:param request_headers: The request headers.
:param validate_when_requested: If True, will validate the
file when requested.
:param location: Location of file on system. :param location: Location of file on system.
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
@ -341,11 +372,6 @@ async def file(
:param no_store: Any cache should not store this response. :param no_store: Any cache should not store this response.
:param _range: :param _range:
""" """
headers = headers or {}
if filename:
headers.setdefault(
"Content-Disposition", f'attachment; filename="{filename}"'
)
if isinstance(last_modified, datetime): if isinstance(last_modified, datetime):
last_modified = last_modified.replace(microsecond=0).timestamp() last_modified = last_modified.replace(microsecond=0).timestamp()
@ -353,9 +379,24 @@ async def file(
stat = await stat_async(location) stat = await stat_async(location)
last_modified = stat.st_mtime last_modified = stat.st_mtime
if (
validate_when_requested
and request_headers is not None
and last_modified
):
response = await validate_file(request_headers, last_modified)
if response:
return response
headers = headers or {}
if last_modified: if last_modified:
headers.setdefault( headers.setdefault(
"last-modified", formatdate(last_modified, usegmt=True) "Last-Modified", formatdate(last_modified, usegmt=True)
)
if filename:
headers.setdefault(
"Content-Disposition", f'attachment; filename="{filename}"'
) )
if no_store: if no_store:

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import inspect import inspect
import os import os
import time
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime
@ -730,8 +731,10 @@ def test_file_response_headers(
test_expires = test_last_modified.timestamp() + test_max_age test_expires = test_last_modified.timestamp() + test_max_age
@app.route("/files/cached/<filename>", methods=["GET"]) @app.route("/files/cached/<filename>", methods=["GET"])
def file_route_cache(request, filename): def file_route_cache(request: Request, filename: str):
file_path = (Path(static_file_directory) / file_name).absolute() file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file( return file(
file_path, max_age=test_max_age, last_modified=test_last_modified file_path, max_age=test_max_age, last_modified=test_last_modified
) )
@ -739,18 +742,26 @@ def test_file_response_headers(
@app.route( @app.route(
"/files/cached_default_last_modified/<filename>", methods=["GET"] "/files/cached_default_last_modified/<filename>", methods=["GET"]
) )
def file_route_cache_default_last_modified(request, filename): def file_route_cache_default_last_modified(
file_path = (Path(static_file_directory) / file_name).absolute() request: Request, filename: str
):
file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(file_path, max_age=test_max_age) return file(file_path, max_age=test_max_age)
@app.route("/files/no_cache/<filename>", methods=["GET"]) @app.route("/files/no_cache/<filename>", methods=["GET"])
def file_route_no_cache(request, filename): def file_route_no_cache(request: Request, filename: str):
file_path = (Path(static_file_directory) / file_name).absolute() file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(file_path) return file(file_path)
@app.route("/files/no_store/<filename>", methods=["GET"]) @app.route("/files/no_store/<filename>", methods=["GET"])
def file_route_no_store(request, filename): def file_route_no_store(request: Request, filename: str):
file_path = (Path(static_file_directory) / file_name).absolute() file_path = (
Path(static_file_directory) / unquote(filename)
).absolute()
return file(file_path, no_store=True) return file(file_path, no_store=True)
_, response = app.test_client.get(f"/files/cached/{file_name}") _, response = app.test_client.get(f"/files/cached/{file_name}")
@ -767,11 +778,11 @@ def test_file_response_headers(
== formatdate(test_expires, usegmt=True)[:-6] == formatdate(test_expires, usegmt=True)[:-6]
# [:-6] to allow at most 1 min difference # [:-6] to allow at most 1 min difference
# It's minimal for cases like: # It's minimal for cases like:
# Thu, 26 May 2022 05:36:49 GMT # Thu, 26 May 2022 05:36:59 GMT
# AND # AND
# Thu, 26 May 2022 05:36:50 GMT # Thu, 26 May 2022 05:37:00 GMT
) )
assert response.status == 200
assert "last-modified" in headers and headers.get( assert "last-modified" in headers and headers.get(
"last-modified" "last-modified"
) == formatdate(test_last_modified.timestamp(), usegmt=True) ) == formatdate(test_last_modified.timestamp(), usegmt=True)
@ -786,15 +797,127 @@ def test_file_response_headers(
assert "last-modified" in headers and headers.get( assert "last-modified" in headers and headers.get(
"last-modified" "last-modified"
) == formatdate(file_last_modified, usegmt=True) ) == formatdate(file_last_modified, usegmt=True)
assert response.status == 200
_, response = app.test_client.get(f"/files/no_cache/{file_name}") _, response = app.test_client.get(f"/files/no_cache/{file_name}")
headers = response.headers headers = response.headers
assert "cache-control" in headers and f"no-cache" == headers.get( assert "cache-control" in headers and f"no-cache" == headers.get(
"cache-control" "cache-control"
) )
assert response.status == 200
_, response = app.test_client.get(f"/files/no_store/{file_name}") _, response = app.test_client.get(f"/files/no_store/{file_name}")
headers = response.headers headers = response.headers
assert "cache-control" in headers and f"no-store" == headers.get( assert "cache-control" in headers and f"no-store" == headers.get(
"cache-control" "cache-control"
) )
assert response.status == 200
def test_file_validate(app: Sanic, static_file_directory: str):
    """Check ``If-Modified-Since`` validation against a changing file.

    The file is served in full when no validator is sent, served in full
    again once it has been modified after the validator date, and answered
    with an empty ``304`` when the validator matches the current
    ``Last-Modified`` value.
    """
    file_name = "test_validate.txt"
    file_path = (Path(static_file_directory) / file_name).absolute()
    test_max_age = 10

    @app.route("/validate", methods=["GET"])
    def file_route_cache(request: Request):
        return file(
            file_path,
            request_headers=request.headers,
            max_age=test_max_age,
            validate_when_requested=True,
        )

    with open(file_path, "w+") as f:
        f.write("foo\n")

    # Remove the scratch file even when an assertion fails, so a broken run
    # does not leave it behind in the shared static directory.
    try:
        _, response = app.test_client.get("/validate")
        assert response.status == 200
        assert response.body == b"foo\n"
        last_modified = response.headers["Last-Modified"]

        # Last-Modified has one-second resolution; wait so that the next
        # write produces a strictly newer value.
        time.sleep(1)
        with open(file_path, "a") as f:
            f.write("bar\n")

        _, response = app.test_client.get(
            "/validate", headers={"If-Modified-Since": last_modified}
        )
        assert response.status == 200
        assert response.body == b"foo\nbar\n"
        last_modified = response.headers["Last-Modified"]

        # The header name is sent in lower case here on purpose.
        _, response = app.test_client.get(
            "/validate", headers={"if-modified-since": last_modified}
        )
        assert response.status == 304
        assert response.body == b""
    finally:
        file_path.unlink()
@pytest.mark.parametrize(
    "file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_validating_invalid_header(
    app: Sanic, file_name: str, static_file_directory: str
):
    """Unparsable ``If-Modified-Since`` values must still yield a full 200."""

    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request: Request, filename: str):
        directory = Path(static_file_directory)
        target = (directory / unquote(filename)).absolute()
        return file(
            target,
            request_headers=request.headers,
            validate_when_requested=True,
        )

    url = f"/files/{file_name}"

    # Baseline: no validator header at all.
    _, plain = app.test_client.get(url)
    assert plain.status == 200
    assert plain.body == get_file_content(static_file_directory, file_name)

    # A malformed date and an empty value are both ignored by the server.
    for bad_value in ("invalid-value", ""):
        _, reply = app.test_client.get(
            url, headers={"if-modified-since": bad_value}
        )
        assert reply.status == 200
        assert reply.body == get_file_content(
            static_file_directory, file_name
        )
@pytest.mark.parametrize(
    "file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_validating_304_response(
    app: Sanic, file_name: str, static_file_directory: str
):
    """Echoing back ``Last-Modified`` must be answered with an empty 304."""

    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request: Request, filename: str):
        directory = Path(static_file_directory)
        target = (directory / unquote(filename)).absolute()
        return file(
            target,
            request_headers=request.headers,
            validate_when_requested=True,
        )

    url = f"/files/{file_name}"

    # First request carries no validator and returns the whole file.
    _, first = app.test_client.get(url)
    assert first.status == 200
    assert first.body == get_file_content(static_file_directory, file_name)

    # Second request reuses the server's own Last-Modified as validator.
    validator = first.headers["Last-Modified"]
    _, second = app.test_client.get(
        url, headers={"if-modified-since": validator}
    )
    assert second.status == 304
    assert second.body == b""