Refactor _static_request_handler (#2533)

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Zhiwei 2022-09-20 16:45:03 -05:00 committed by GitHub
parent 16503319e5
commit 43ba381e7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 33 deletions

View File

@ -1,12 +1,12 @@
from ast import NodeVisitor, Return, parse from ast import NodeVisitor, Return, parse
from contextlib import suppress from contextlib import suppress
from email.utils import formatdate
from functools import partial, wraps from functools import partial, wraps
from inspect import getsource, signature from inspect import getsource, signature
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from pathlib import Path, PurePath from pathlib import Path, PurePath
from textwrap import dedent from textwrap import dedent
from time import gmtime, strftime
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -31,7 +31,7 @@ from sanic.handlers import ContentRangeHandler
from sanic.log import error_logger from sanic.log import error_logger
from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.futures import FutureRoute, FutureStatic
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
from sanic.response import HTTPResponse, file, file_stream from sanic.response import HTTPResponse, file, file_stream, validate_file
from sanic.types import HashableDict from sanic.types import HashableDict
@ -790,24 +790,9 @@ class RouteMixin(metaclass=SanicMeta):
return name return name
async def _static_request_handler( async def _get_file_path(self, file_or_directory, __file_uri__, not_found):
self,
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
__file_uri__=None,
):
# Merge served directory and requested file if provided
file_path_raw = Path(unquote(file_or_directory)) file_path_raw = Path(unquote(file_or_directory))
root_path = file_path = file_path_raw.resolve() root_path = file_path = file_path_raw.resolve()
not_found = FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
if __file_uri__: if __file_uri__:
# Strip all / that in the beginning of the URL to help prevent # Strip all / that in the beginning of the URL to help prevent
@ -834,6 +819,29 @@ class RouteMixin(metaclass=SanicMeta):
f"relative_url={__file_uri__}" f"relative_url={__file_uri__}"
) )
raise not_found raise not_found
return file_path
async def _static_request_handler(
self,
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
__file_uri__=None,
):
not_found = FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
# Merge served directory and requested file if provided
file_path = await self._get_file_path(
file_or_directory, __file_uri__, not_found
)
try: try:
headers = {} headers = {}
# Check if the client has been sent this file before # Check if the client has been sent this file before
@ -841,15 +849,13 @@ class RouteMixin(metaclass=SanicMeta):
stats = None stats = None
if use_modified_since: if use_modified_since:
stats = await stat_async(file_path) stats = await stat_async(file_path)
modified_since = strftime( modified_since = stats.st_mtime
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) response = await validate_file(request.headers, modified_since)
if response:
return response
headers["Last-Modified"] = formatdate(
modified_since, usegmt=True
) )
if (
request.headers.getone("if-modified-since", None)
== modified_since
):
return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since
_range = None _range = None
if use_content_range: if use_content_range:
_range = None _range = None
@ -864,8 +870,7 @@ class RouteMixin(metaclass=SanicMeta):
pass pass
else: else:
del headers["Content-Length"] del headers["Content-Length"]
for key, value in _range.headers.items(): headers.update(_range.headers)
headers[key] = value
if "content-type" not in headers: if "content-type" not in headers:
content_type = ( content_type = (

View File

@ -4,8 +4,8 @@ import os
import time import time
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime, timedelta
from email.utils import formatdate from email.utils import formatdate, parsedate_to_datetime
from logging import ERROR, LogRecord from logging import ERROR, LogRecord
from mimetypes import guess_type from mimetypes import guess_type
from pathlib import Path from pathlib import Path
@ -665,13 +665,11 @@ def test_multiple_responses(
with caplog.at_level(ERROR): with caplog.at_level(ERROR):
_, response = app.test_client.get("/4") _, response = app.test_client.get("/4")
print(response.json)
assert response.status == 200 assert response.status == 200
assert "foo" not in response.text assert "foo" not in response.text
assert "one" in response.headers assert "one" in response.headers
assert response.headers["one"] == "one" assert response.headers["one"] == "one"
print(response.headers)
assert message_in_records(caplog.records, error_msg2) assert message_in_records(caplog.records, error_msg2)
with caplog.at_level(ERROR): with caplog.at_level(ERROR):
@ -841,10 +839,10 @@ def test_file_validate(app: Sanic, static_file_directory: str):
time.sleep(1) time.sleep(1)
with open(file_path, "a") as f: with open(file_path, "a") as f:
f.write("bar\n") f.write("bar\n")
_, response = app.test_client.get( _, response = app.test_client.get(
"/validate", headers={"If-Modified-Since": last_modified} "/validate", headers={"If-Modified-Since": last_modified}
) )
assert response.status == 200 assert response.status == 200
assert response.body == b"foo\nbar\n" assert response.body == b"foo\nbar\n"
@ -921,3 +919,28 @@ def test_file_validating_304_response(
) )
assert response.status == 304 assert response.status == 304
assert response.body == b"" assert response.body == b""
@pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_validating_304_response(
app: Sanic, file_name: str, static_file_directory: str
):
app.static("static", Path(static_file_directory) / file_name)
_, response = app.test_client.get("/static")
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
last_modified = parsedate_to_datetime(response.headers["Last-Modified"])
last_modified += timedelta(seconds=1)
_, response = app.test_client.get(
"/static",
headers={
"if-modified-since": formatdate(
last_modified.timestamp(), usegmt=True
)
},
)
assert response.status == 304
assert response.body == b""