Add some test coverage

This commit is contained in:
Adam Hopkins 2021-02-15 21:50:20 +02:00
parent 55a5ab4be1
commit 7f63ad5484
12 changed files with 378 additions and 201 deletions

View File

@ -49,7 +49,6 @@ from sanic.server import (
serve,
serve_multiple,
)
from sanic.static import register as static_register
from sanic.websocket import ConnectionClosed, WebSocketProtocol
@ -242,7 +241,7 @@ class Sanic(BaseSanic):
return self.router.add(**params)
def _apply_static(self, static: FutureStatic) -> Route:
return static_register(self, static)
return self._register_static(static)
def _apply_middleware(
self,

View File

@ -17,7 +17,7 @@ class Header(CIMultiDict):
use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio:
if use_trio: # pragma: no cover
import trio # type: ignore
def stat_async(path):

View File

@ -8,7 +8,7 @@ class ExceptionMixin:
self._future_exceptions: Set[FutureException] = set()
def _apply_exception_handler(self, handler: FutureException):
raise NotImplementedError
raise NotImplementedError # noqa
def exception(self, *exceptions, apply=True):
"""

View File

@ -20,7 +20,7 @@ class ListenerMixin:
self._future_listeners: List[FutureListener] = list()
def _apply_listener(self, listener: FutureListener):
raise NotImplementedError
raise NotImplementedError # noqa
def listener(self, listener_or_event, event_or_none=None, apply=True):
"""Create a listener from a decorated function.

View File

@ -9,7 +9,7 @@ class MiddlewareMixin:
self._future_middleware: List[FutureMiddleware] = list()
def _apply_middleware(self, middleware: FutureMiddleware):
raise NotImplementedError
raise NotImplementedError # noqa
def middleware(
self, middleware_or_request, attach_to="request", apply=True
@ -42,8 +42,14 @@ class MiddlewareMixin:
register_middleware, attach_to=middleware_or_request
)
def on_request(self, middleware):
def on_request(self, middleware=None):
if callable(middleware):
return self.middleware(middleware, "request")
else:
return partial(self.middleware, attach_to="request")
def on_response(self, middleware):
def on_response(self, middleware=None):
if callable(middleware):
return self.middleware(middleware, "response")
else:
return partial(self.middleware, attach_to="response")

View File

@ -1,11 +1,27 @@
from functools import partial, wraps
from inspect import signature
from mimetypes import guess_type
from os import path
from pathlib import PurePath
from re import sub
from time import gmtime, strftime
from typing import Set, Union
from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore
from sanic.compat import stat_async
from sanic.constants import HTTP_METHODS
from sanic.exceptions import (
ContentRangeError,
FileNotFound,
HeaderNotFound,
InvalidUsage,
)
from sanic.handlers import ContentRangeHandler
from sanic.log import error_logger
from sanic.models.futures import FutureRoute, FutureStatic
from sanic.response import HTTPResponse, file, file_stream
from sanic.views import CompositionView
@ -17,10 +33,10 @@ class RouteMixin:
self.strict_slashes = False
def _apply_route(self, route: FutureRoute) -> Route:
raise NotImplementedError
raise NotImplementedError # noqa
def _apply_static(self, static: FutureStatic) -> Route:
raise NotImplementedError
raise NotImplementedError # noqa
def route(
self,
@ -555,10 +571,191 @@ class RouteMixin:
else:
break
if not name:
raise Exception("...")
if not name: # noq
raise ValueError("Could not generate a name for handler")
if not name.startswith(f"{self.name}."):
name = f"{self.name}.{name}"
return name
async def _static_request_handler(
self,
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
file_uri=None,
):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if file_uri and "../" in file_uri:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory
if file_uri:
file_path = path.join(
file_or_directory, sub("^[/]*", "", file_uri)
)
# URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
try:
headers = {}
# Check if the client has been sent this file before
# and it has not been modified since
stats = None
if use_modified_since:
stats = await stat_async(file_path)
modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
)
if request.headers.get("If-Modified-Since") == modified_since:
return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since
_range = None
if use_content_range:
_range = None
if not stats:
stats = await stat_async(file_path)
headers["Accept-Ranges"] = "bytes"
headers["Content-Length"] = str(stats.st_size)
if request.method != "HEAD":
try:
_range = ContentRangeHandler(request, stats)
except HeaderNotFound:
pass
else:
del headers["Content-Length"]
for key, value in _range.headers.items():
headers[key] = value
if "content-type" not in headers:
content_type = (
content_type
or guess_type(file_path)[0]
or "application/octet-stream"
)
if "charset=" not in content_type and (
content_type.startswith("text/")
or content_type == "application/javascript"
):
content_type += "; charset=utf-8"
headers["Content-Type"] = content_type
if request.method == "HEAD":
return HTTPResponse(headers=headers)
else:
if stream_large_files:
if type(stream_large_files) == int:
threshold = stream_large_files
else:
threshold = 1024 * 1024
if not stats:
stats = await stat_async(file_path)
if stats.st_size >= threshold:
return await file_stream(
file_path, headers=headers, _range=_range
)
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
def _register_static(
self,
static: FutureStatic,
):
# TODO: Though sanic is not a file server, I feel like we should
# at least make a good effort here. Modified-since is nice, but
# we could also look into etags, expires, and caching
"""
Register a static directory handler with Sanic by adding a route to the
router and registering a handler.
:param app: Sanic
:param file_or_directory: File or directory path to serve from
:type file_or_directory: Union[str,bytes,Path]
:param uri: URL to serve from
:type uri: str
:param pattern: regular expression used to match files in the URL
:param use_modified_since: If true, send file modified time, and return
not modified if the browser's matches the
server's
:param use_content_range: If true, process header for range requests
and sends the file part that is requested
:param stream_large_files: If true, use the file_stream() handler
rather than the file() handler to send the file
If this is an integer, this represents the
threshold size to switch to file_stream()
:param name: user defined name used for url_for
:type name: str
:param content_type: user defined content type for header
:return: registered static routes
:rtype: List[sanic.router.Route]
"""
if isinstance(static.file_or_directory, bytes):
file_or_directory = static.file_or_directory.decode("utf-8")
elif isinstance(static.file_or_directory, PurePath):
file_or_directory = str(static.file_or_directory)
elif not isinstance(static.file_or_directory, str):
raise ValueError("Invalid file path string.")
else:
file_or_directory = static.file_or_directory
uri = static.uri
name = static.name
# If we're not trying to match a file directly,
# serve from the folder
if not path.isfile(file_or_directory):
uri += "/<file_uri>"
# special prefix for static files
# if not static.name.startswith("_static_"):
# name = f"_static_{static.name}"
_handler = wraps(self._static_request_handler)(
partial(
self._static_request_handler,
file_or_directory,
static.use_modified_since,
static.use_content_range,
static.stream_large_files,
content_type=static.content_type,
)
)
route, _ = self.route(
uri=uri,
methods=["GET", "HEAD"],
name=name,
host=static.host,
strict_slashes=static.strict_slashes,
static=True,
)(_handler)
return route

View File

@ -1,186 +0,0 @@
from functools import partial, wraps
from mimetypes import guess_type
from os import path
from pathlib import PurePath
from re import sub
from time import gmtime, strftime
from urllib.parse import unquote
from sanic.compat import stat_async
from sanic.exceptions import (
ContentRangeError,
FileNotFound,
HeaderNotFound,
InvalidUsage,
)
from sanic.handlers import ContentRangeHandler
from sanic.log import error_logger
from sanic.models.futures import FutureStatic
from sanic.response import HTTPResponse, file, file_stream
async def _static_request_handler(
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
file_uri=None,
):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if file_uri and "../" in file_uri:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory
if file_uri:
file_path = path.join(file_or_directory, sub("^[/]*", "", file_uri))
# URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
try:
headers = {}
# Check if the client has been sent this file before
# and it has not been modified since
stats = None
if use_modified_since:
stats = await stat_async(file_path)
modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
)
if request.headers.get("If-Modified-Since") == modified_since:
return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since
_range = None
if use_content_range:
_range = None
if not stats:
stats = await stat_async(file_path)
headers["Accept-Ranges"] = "bytes"
headers["Content-Length"] = str(stats.st_size)
if request.method != "HEAD":
try:
_range = ContentRangeHandler(request, stats)
except HeaderNotFound:
pass
else:
del headers["Content-Length"]
for key, value in _range.headers.items():
headers[key] = value
headers["Content-Type"] = (
content_type or guess_type(file_path)[0] or "text/plain"
)
if request.method == "HEAD":
return HTTPResponse(headers=headers)
else:
if stream_large_files:
if type(stream_large_files) == int:
threshold = stream_large_files
else:
threshold = 1024 * 1024
if not stats:
stats = await stat_async(file_path)
if stats.st_size >= threshold:
return await file_stream(
file_path, headers=headers, _range=_range
)
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
def register(
app,
static: FutureStatic,
):
# TODO: Though sanic is not a file server, I feel like we should at least
# make a good effort here. Modified-since is nice, but we could
# also look into etags, expires, and caching
"""
Register a static directory handler with Sanic by adding a route to the
router and registering a handler.
:param app: Sanic
:param file_or_directory: File or directory path to serve from
:type file_or_directory: Union[str,bytes,Path]
:param uri: URL to serve from
:type uri: str
:param pattern: regular expression used to match files in the URL
:param use_modified_since: If true, send file modified time, and return
not modified if the browser's matches the
server's
:param use_content_range: If true, process header for range requests
and sends the file part that is requested
:param stream_large_files: If true, use the file_stream() handler rather
than the file() handler to send the file
If this is an integer, this represents the
threshold size to switch to file_stream()
:param name: user defined name used for url_for
:type name: str
:param content_type: user defined content type for header
:return: registered static routes
:rtype: List[sanic.router.Route]
"""
if isinstance(static.file_or_directory, bytes):
file_or_directory = static.file_or_directory.decode("utf-8")
elif isinstance(static.file_or_directory, PurePath):
file_or_directory = str(static.file_or_directory)
elif not isinstance(static.file_or_directory, str):
raise ValueError("Invalid file path string.")
else:
file_or_directory = static.file_or_directory
uri = static.uri
name = static.name
# If we're not trying to match a file directly,
# serve from the folder
if not path.isfile(file_or_directory):
uri += "/<file_uri>"
# special prefix for static files
# if not static.name.startswith("_static_"):
# name = f"_static_{static.name}"
_handler = wraps(_static_request_handler)(
partial(
_static_request_handler,
file_or_directory,
static.use_modified_since,
static.use_content_range,
static.stream_large_files,
content_type=static.content_type,
)
)
route, _ = app.route(
uri=uri,
methods=["GET", "HEAD"],
name=name,
host=static.host,
strict_slashes=static.strict_slashes,
static=True,
)(_handler)
return route

View File

@ -110,6 +110,11 @@ def test_bp_group(app: Sanic):
global MIDDLEWARE_INVOKE_COUNTER
MIDDLEWARE_INVOKE_COUNTER["request"] += 1
@blueprint_group_1.middleware
def blueprint_group_1_middleware_not_called(request):
global MIDDLEWARE_INVOKE_COUNTER
MIDDLEWARE_INVOKE_COUNTER["request"] += 1
@blueprint_3.route("/")
def blueprint_3_default_route(request):
return text("BP3_OK")
@ -142,7 +147,7 @@ def test_bp_group(app: Sanic):
assert response.text == "BP3_OK"
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 2
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
def test_bp_group_list_operations(app: Sanic):
@ -179,3 +184,19 @@ def test_bp_group_list_operations(app: Sanic):
assert len(blueprint_group_1) == 2
assert blueprint_group_1.url_prefix == "/bp"
def test_bp_group_as_list():
blueprint_1 = Blueprint("blueprint_1", url_prefix="/bp1")
blueprint_2 = Blueprint("blueprint_2", url_prefix="/bp2")
blueprint_group_1 = Blueprint.group([blueprint_1, blueprint_2])
assert len(blueprint_group_1) == 2
def test_bp_group_as_nested_group():
blueprint_1 = Blueprint("blueprint_1", url_prefix="/bp1")
blueprint_2 = Blueprint("blueprint_2", url_prefix="/bp2")
blueprint_group_1 = Blueprint.group(
Blueprint.group(blueprint_1, blueprint_2)
)
assert len(blueprint_group_1) == 2

View File

@ -30,6 +30,23 @@ def test_middleware_request(app):
assert type(results[0]) is Request
def test_middleware_request_as_convenience(app):
results = []
@app.on_request
async def handler1(request):
results.append(request)
@app.route("/")
async def handler2(request):
return text("OK")
request, response = app.test_client.get("/")
assert response.text == "OK"
assert type(results[0]) is Request
def test_middleware_response(app):
results = []
@ -54,6 +71,54 @@ def test_middleware_response(app):
assert isinstance(results[2], HTTPResponse)
def test_middleware_response_as_convenience(app):
results = []
@app.on_request
async def process_request(request):
results.append(request)
@app.on_response
async def process_response(request, response):
results.append(request)
results.append(response)
@app.route("/")
async def handler(request):
return text("OK")
request, response = app.test_client.get("/")
assert response.text == "OK"
assert type(results[0]) is Request
assert type(results[1]) is Request
assert isinstance(results[2], HTTPResponse)
def test_middleware_response_as_convenience_called(app):
results = []
@app.on_request()
async def process_request(request):
results.append(request)
@app.on_response()
async def process_response(request, response):
results.append(request)
results.append(response)
@app.route("/")
async def handler(request):
return text("OK")
request, response = app.test_client.get("/")
assert response.text == "OK"
assert type(results[0]) is Request
assert type(results[1]) is Request
assert isinstance(results[2], HTTPResponse)
def test_middleware_response_exception(app):
result = {"status_code": "middleware not run"}

View File

@ -633,6 +633,19 @@ def test_websocket_route(app, url):
assert ev.is_set()
def test_websocket_route_invalid_handler(app):
with pytest.raises(ValueError) as e:
@app.websocket("/")
async def handler():
...
assert e.match(
r"Required parameter `request` and/or `ws` missing in the "
r"handler\(\) route\?"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url):

View File

@ -8,6 +8,8 @@ import pytest
from sanic_testing.testing import HOST, PORT
from sanic.exceptions import InvalidUsage
AVAILABLE_LISTENERS = [
"before_server_start",
@ -80,6 +82,18 @@ def test_all_listeners(app):
assert app.name + listener_name == output.pop()
@skipif_no_alarm
def test_all_listeners_as_convenience(app):
output = []
for listener_name in AVAILABLE_LISTENERS:
listener = create_listener(listener_name, output)
method = getattr(app, listener_name)
method(listener)
start_stop_app(app)
for listener_name in AVAILABLE_LISTENERS:
assert app.name + listener_name == output.pop()
@pytest.mark.asyncio
async def test_trigger_before_events_create_server(app):
class MySanicDb:
@ -95,6 +109,20 @@ async def test_trigger_before_events_create_server(app):
assert isinstance(app.db, MySanicDb)
@pytest.mark.asyncio
async def test_trigger_before_events_create_server_missing_event(app):
class MySanicDb:
pass
with pytest.raises(InvalidUsage):
@app.listener
async def init_db(app, loop):
app.db = MySanicDb()
assert not hasattr(app, "db")
def test_create_server_trigger_events(app):
"""Test if create_server can trigger server events"""

View File

@ -127,6 +127,40 @@ def test_static_file_content_type(app, static_file_directory, file_name):
assert response.headers["Content-Type"] == "text/html; charset=utf-8"
@pytest.mark.parametrize(
"file_name,expected",
[
("test.html", "text/html; charset=utf-8"),
("decode me.txt", "text/plain; charset=utf-8"),
("test.file", "application/octet-stream"),
],
)
def test_static_file_content_type_guessed(
app, static_file_directory, file_name, expected
):
app.static(
"/testing.file",
get_file_path(static_file_directory, file_name),
)
request, response = app.test_client.get("/testing.file")
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
assert response.headers["Content-Type"] == expected
def test_static_file_content_type_with_charset(app, static_file_directory):
app.static(
"/testing.file",
get_file_path(static_file_directory, "decode me.txt"),
content_type="text/plain;charset=ISO-8859-1",
)
request, response = app.test_client.get("/testing.file")
assert response.status == 200
assert response.headers["Content-Type"] == "text/plain;charset=ISO-8859-1"
@pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "symlink", "hard_link"]
)