diff --git a/sanic/static.py b/sanic/static.py index 0a75d9dc..f0943a7d 100644 --- a/sanic/static.py +++ b/sanic/static.py @@ -1,8 +1,10 @@ from functools import partial, wraps from mimetypes import guess_type from os import path +from pathlib import PurePath from re import sub from time import gmtime, strftime +from typing import Union from urllib.parse import unquote from sanic.compat import stat_async @@ -110,13 +112,13 @@ async def _static_request_handler( def register( app, - uri, - file_or_directory, + uri: str, + file_or_directory: Union[str, bytes, PurePath], pattern, use_modified_since, use_content_range, stream_large_files, - name="static", + name: str = "static", host=None, strict_slashes=None, content_type=None, @@ -130,7 +132,9 @@ def register( :param app: Sanic :param file_or_directory: File or directory path to serve from + :type file_or_directory: Union[str,bytes,Path] :param uri: URL to serve from + :type uri: str :param pattern: regular expression used to match files in the URL :param use_modified_since: If true, send file modified time, and return not modified if the browser's matches the @@ -142,10 +146,19 @@ def register( If this is an integer, this represents the threshold size to switch to file_stream() :param name: user defined name used for url_for + :type name: str :param content_type: user defined content type for header :return: registered static routes :rtype: List[sanic.router.Route] """ + + if isinstance(file_or_directory, bytes): + file_or_directory = file_or_directory.decode("utf-8") + elif isinstance(file_or_directory, PurePath): + file_or_directory = str(file_or_directory) + elif not isinstance(file_or_directory, str): + raise ValueError("Invalid file path string.") + # If we're not trying to match a file directly, # serve from the folder if not path.isfile(file_or_directory): diff --git a/tests/test_static.py b/tests/test_static.py index 91635a45..23ba05d7 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -1,6 +1,6 @@ import inspect import os - +from pathlib import Path from time import gmtime, strftime import pytest @@ -76,6 +76,41 @@ def test_static_file(app, static_file_directory, file_name): assert response.body == get_file_content(static_file_directory, file_name) +@pytest.mark.parametrize( + "file_name", + ["test.file", "decode me.txt", "python.png", "symlink", "hard_link"], +) +def test_static_file_pathlib(app, static_file_directory, file_name): + file_path = Path(get_file_path(static_file_directory, file_name)) + app.static("/testing.file", file_path) + request, response = app.test_client.get("/testing.file") + assert response.status == 200 + assert response.body == get_file_content(static_file_directory, file_name) + + +@pytest.mark.parametrize( + "file_name", + [b"test.file", b"decode me.txt", b"python.png"], +) +def test_static_file_bytes(app, static_file_directory, file_name): + bsep = os.path.sep.encode('utf-8') + file_path = static_file_directory.encode('utf-8') + bsep + file_name + app.static("/testing.file", file_path) + request, response = app.test_client.get("/testing.file") + assert response.status == 200 + + +@pytest.mark.parametrize( + "file_name", + [dict(), list(), object()], +) +def test_static_file_invalid_path(app, static_file_directory, file_name): + with pytest.raises(ValueError): + app.static("/testing.file", file_name) + request, response = app.test_client.get("/testing.file") + assert response.status == 404 + + @pytest.mark.parametrize("file_name", ["test.html"]) def test_static_file_content_type(app, static_file_directory, file_name): app.static(