Add some test coverage

This commit is contained in:
Adam Hopkins 2021-02-15 21:50:20 +02:00
parent 55a5ab4be1
commit 7f63ad5484
12 changed files with 378 additions and 201 deletions

View File

@ -49,7 +49,6 @@ from sanic.server import (
serve, serve,
serve_multiple, serve_multiple,
) )
from sanic.static import register as static_register
from sanic.websocket import ConnectionClosed, WebSocketProtocol from sanic.websocket import ConnectionClosed, WebSocketProtocol
@ -242,7 +241,7 @@ class Sanic(BaseSanic):
return self.router.add(**params) return self.router.add(**params)
def _apply_static(self, static: FutureStatic) -> Route: def _apply_static(self, static: FutureStatic) -> Route:
return static_register(self, static) return self._register_static(static)
def _apply_middleware( def _apply_middleware(
self, self,

View File

@ -17,7 +17,7 @@ class Header(CIMultiDict):
use_trio = argv[0].endswith("hypercorn") and "trio" in argv use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio: if use_trio: # pragma: no cover
import trio # type: ignore import trio # type: ignore
def stat_async(path): def stat_async(path):

View File

@ -8,7 +8,7 @@ class ExceptionMixin:
self._future_exceptions: Set[FutureException] = set() self._future_exceptions: Set[FutureException] = set()
def _apply_exception_handler(self, handler: FutureException): def _apply_exception_handler(self, handler: FutureException):
raise NotImplementedError raise NotImplementedError # noqa
def exception(self, *exceptions, apply=True): def exception(self, *exceptions, apply=True):
""" """

View File

@ -20,7 +20,7 @@ class ListenerMixin:
self._future_listeners: List[FutureListener] = list() self._future_listeners: List[FutureListener] = list()
def _apply_listener(self, listener: FutureListener): def _apply_listener(self, listener: FutureListener):
raise NotImplementedError raise NotImplementedError # noqa
def listener(self, listener_or_event, event_or_none=None, apply=True): def listener(self, listener_or_event, event_or_none=None, apply=True):
"""Create a listener from a decorated function. """Create a listener from a decorated function.

View File

@ -9,7 +9,7 @@ class MiddlewareMixin:
self._future_middleware: List[FutureMiddleware] = list() self._future_middleware: List[FutureMiddleware] = list()
def _apply_middleware(self, middleware: FutureMiddleware): def _apply_middleware(self, middleware: FutureMiddleware):
raise NotImplementedError raise NotImplementedError # noqa
def middleware( def middleware(
self, middleware_or_request, attach_to="request", apply=True self, middleware_or_request, attach_to="request", apply=True
@ -42,8 +42,14 @@ class MiddlewareMixin:
register_middleware, attach_to=middleware_or_request register_middleware, attach_to=middleware_or_request
) )
def on_request(self, middleware): def on_request(self, middleware=None):
return self.middleware(middleware, "request") if callable(middleware):
return self.middleware(middleware, "request")
else:
return partial(self.middleware, attach_to="request")
def on_response(self, middleware): def on_response(self, middleware=None):
return self.middleware(middleware, "response") if callable(middleware):
return self.middleware(middleware, "response")
else:
return partial(self.middleware, attach_to="response")

View File

@ -1,11 +1,27 @@
from functools import partial, wraps
from inspect import signature from inspect import signature
from mimetypes import guess_type
from os import path
from pathlib import PurePath from pathlib import PurePath
from re import sub
from time import gmtime, strftime
from typing import Set, Union from typing import Set, Union
from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route # type: ignore
from sanic.compat import stat_async
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.exceptions import (
ContentRangeError,
FileNotFound,
HeaderNotFound,
InvalidUsage,
)
from sanic.handlers import ContentRangeHandler
from sanic.log import error_logger
from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.futures import FutureRoute, FutureStatic
from sanic.response import HTTPResponse, file, file_stream
from sanic.views import CompositionView from sanic.views import CompositionView
@ -17,10 +33,10 @@ class RouteMixin:
self.strict_slashes = False self.strict_slashes = False
def _apply_route(self, route: FutureRoute) -> Route: def _apply_route(self, route: FutureRoute) -> Route:
raise NotImplementedError raise NotImplementedError # noqa
def _apply_static(self, static: FutureStatic) -> Route: def _apply_static(self, static: FutureStatic) -> Route:
raise NotImplementedError raise NotImplementedError # noqa
def route( def route(
self, self,
@ -555,10 +571,191 @@ class RouteMixin:
else: else:
break break
if not name: if not name: # noq
raise Exception("...") raise ValueError("Could not generate a name for handler")
if not name.startswith(f"{self.name}."): if not name.startswith(f"{self.name}."):
name = f"{self.name}.{name}" name = f"{self.name}.{name}"
return name return name
async def _static_request_handler(
self,
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
file_uri=None,
):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if file_uri and "../" in file_uri:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory
if file_uri:
file_path = path.join(
file_or_directory, sub("^[/]*", "", file_uri)
)
# URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
try:
headers = {}
# Check if the client has been sent this file before
# and it has not been modified since
stats = None
if use_modified_since:
stats = await stat_async(file_path)
modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
)
if request.headers.get("If-Modified-Since") == modified_since:
return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since
_range = None
if use_content_range:
_range = None
if not stats:
stats = await stat_async(file_path)
headers["Accept-Ranges"] = "bytes"
headers["Content-Length"] = str(stats.st_size)
if request.method != "HEAD":
try:
_range = ContentRangeHandler(request, stats)
except HeaderNotFound:
pass
else:
del headers["Content-Length"]
for key, value in _range.headers.items():
headers[key] = value
if "content-type" not in headers:
content_type = (
content_type
or guess_type(file_path)[0]
or "application/octet-stream"
)
if "charset=" not in content_type and (
content_type.startswith("text/")
or content_type == "application/javascript"
):
content_type += "; charset=utf-8"
headers["Content-Type"] = content_type
if request.method == "HEAD":
return HTTPResponse(headers=headers)
else:
if stream_large_files:
if type(stream_large_files) == int:
threshold = stream_large_files
else:
threshold = 1024 * 1024
if not stats:
stats = await stat_async(file_path)
if stats.st_size >= threshold:
return await file_stream(
file_path, headers=headers, _range=_range
)
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
def _register_static(
self,
static: FutureStatic,
):
# TODO: Though sanic is not a file server, I feel like we should
# at least make a good effort here. Modified-since is nice, but
# we could also look into etags, expires, and caching
"""
Register a static directory handler with Sanic by adding a route to the
router and registering a handler.
:param app: Sanic
:param file_or_directory: File or directory path to serve from
:type file_or_directory: Union[str,bytes,Path]
:param uri: URL to serve from
:type uri: str
:param pattern: regular expression used to match files in the URL
:param use_modified_since: If true, send file modified time, and return
not modified if the browser's matches the
server's
:param use_content_range: If true, process header for range requests
and sends the file part that is requested
:param stream_large_files: If true, use the file_stream() handler
rather than the file() handler to send the file
If this is an integer, this represents the
threshold size to switch to file_stream()
:param name: user defined name used for url_for
:type name: str
:param content_type: user defined content type for header
:return: registered static routes
:rtype: List[sanic.router.Route]
"""
if isinstance(static.file_or_directory, bytes):
file_or_directory = static.file_or_directory.decode("utf-8")
elif isinstance(static.file_or_directory, PurePath):
file_or_directory = str(static.file_or_directory)
elif not isinstance(static.file_or_directory, str):
raise ValueError("Invalid file path string.")
else:
file_or_directory = static.file_or_directory
uri = static.uri
name = static.name
# If we're not trying to match a file directly,
# serve from the folder
if not path.isfile(file_or_directory):
uri += "/<file_uri>"
# special prefix for static files
# if not static.name.startswith("_static_"):
# name = f"_static_{static.name}"
_handler = wraps(self._static_request_handler)(
partial(
self._static_request_handler,
file_or_directory,
static.use_modified_since,
static.use_content_range,
static.stream_large_files,
content_type=static.content_type,
)
)
route, _ = self.route(
uri=uri,
methods=["GET", "HEAD"],
name=name,
host=static.host,
strict_slashes=static.strict_slashes,
static=True,
)(_handler)
return route

View File

@ -1,186 +0,0 @@
from functools import partial, wraps
from mimetypes import guess_type
from os import path
from pathlib import PurePath
from re import sub
from time import gmtime, strftime
from urllib.parse import unquote
from sanic.compat import stat_async
from sanic.exceptions import (
ContentRangeError,
FileNotFound,
HeaderNotFound,
InvalidUsage,
)
from sanic.handlers import ContentRangeHandler
from sanic.log import error_logger
from sanic.models.futures import FutureStatic
from sanic.response import HTTPResponse, file, file_stream
async def _static_request_handler(
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
file_uri=None,
):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if file_uri and "../" in file_uri:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory
if file_uri:
file_path = path.join(file_or_directory, sub("^[/]*", "", file_uri))
# URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
try:
headers = {}
# Check if the client has been sent this file before
# and it has not been modified since
stats = None
if use_modified_since:
stats = await stat_async(file_path)
modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
)
if request.headers.get("If-Modified-Since") == modified_since:
return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since
_range = None
if use_content_range:
_range = None
if not stats:
stats = await stat_async(file_path)
headers["Accept-Ranges"] = "bytes"
headers["Content-Length"] = str(stats.st_size)
if request.method != "HEAD":
try:
_range = ContentRangeHandler(request, stats)
except HeaderNotFound:
pass
else:
del headers["Content-Length"]
for key, value in _range.headers.items():
headers[key] = value
headers["Content-Type"] = (
content_type or guess_type(file_path)[0] or "text/plain"
)
if request.method == "HEAD":
return HTTPResponse(headers=headers)
else:
if stream_large_files:
if type(stream_large_files) == int:
threshold = stream_large_files
else:
threshold = 1024 * 1024
if not stats:
stats = await stat_async(file_path)
if stats.st_size >= threshold:
return await file_stream(
file_path, headers=headers, _range=_range
)
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
def register(
app,
static: FutureStatic,
):
# TODO: Though sanic is not a file server, I feel like we should at least
# make a good effort here. Modified-since is nice, but we could
# also look into etags, expires, and caching
"""
Register a static directory handler with Sanic by adding a route to the
router and registering a handler.
:param app: Sanic
:param file_or_directory: File or directory path to serve from
:type file_or_directory: Union[str,bytes,Path]
:param uri: URL to serve from
:type uri: str
:param pattern: regular expression used to match files in the URL
:param use_modified_since: If true, send file modified time, and return
not modified if the browser's matches the
server's
:param use_content_range: If true, process header for range requests
and sends the file part that is requested
:param stream_large_files: If true, use the file_stream() handler rather
than the file() handler to send the file
If this is an integer, this represents the
threshold size to switch to file_stream()
:param name: user defined name used for url_for
:type name: str
:param content_type: user defined content type for header
:return: registered static routes
:rtype: List[sanic.router.Route]
"""
if isinstance(static.file_or_directory, bytes):
file_or_directory = static.file_or_directory.decode("utf-8")
elif isinstance(static.file_or_directory, PurePath):
file_or_directory = str(static.file_or_directory)
elif not isinstance(static.file_or_directory, str):
raise ValueError("Invalid file path string.")
else:
file_or_directory = static.file_or_directory
uri = static.uri
name = static.name
# If we're not trying to match a file directly,
# serve from the folder
if not path.isfile(file_or_directory):
uri += "/<file_uri>"
# special prefix for static files
# if not static.name.startswith("_static_"):
# name = f"_static_{static.name}"
_handler = wraps(_static_request_handler)(
partial(
_static_request_handler,
file_or_directory,
static.use_modified_since,
static.use_content_range,
static.stream_large_files,
content_type=static.content_type,
)
)
route, _ = app.route(
uri=uri,
methods=["GET", "HEAD"],
name=name,
host=static.host,
strict_slashes=static.strict_slashes,
static=True,
)(_handler)
return route

View File

@ -110,6 +110,11 @@ def test_bp_group(app: Sanic):
global MIDDLEWARE_INVOKE_COUNTER global MIDDLEWARE_INVOKE_COUNTER
MIDDLEWARE_INVOKE_COUNTER["request"] += 1 MIDDLEWARE_INVOKE_COUNTER["request"] += 1
@blueprint_group_1.middleware
def blueprint_group_1_middleware_not_called(request):
global MIDDLEWARE_INVOKE_COUNTER
MIDDLEWARE_INVOKE_COUNTER["request"] += 1
@blueprint_3.route("/") @blueprint_3.route("/")
def blueprint_3_default_route(request): def blueprint_3_default_route(request):
return text("BP3_OK") return text("BP3_OK")
@ -142,7 +147,7 @@ def test_bp_group(app: Sanic):
assert response.text == "BP3_OK" assert response.text == "BP3_OK"
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3 assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 2 assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
def test_bp_group_list_operations(app: Sanic): def test_bp_group_list_operations(app: Sanic):
@ -179,3 +184,19 @@ def test_bp_group_list_operations(app: Sanic):
assert len(blueprint_group_1) == 2 assert len(blueprint_group_1) == 2
assert blueprint_group_1.url_prefix == "/bp" assert blueprint_group_1.url_prefix == "/bp"
def test_bp_group_as_list():
blueprint_1 = Blueprint("blueprint_1", url_prefix="/bp1")
blueprint_2 = Blueprint("blueprint_2", url_prefix="/bp2")
blueprint_group_1 = Blueprint.group([blueprint_1, blueprint_2])
assert len(blueprint_group_1) == 2
def test_bp_group_as_nested_group():
blueprint_1 = Blueprint("blueprint_1", url_prefix="/bp1")
blueprint_2 = Blueprint("blueprint_2", url_prefix="/bp2")
blueprint_group_1 = Blueprint.group(
Blueprint.group(blueprint_1, blueprint_2)
)
assert len(blueprint_group_1) == 2

View File

@ -30,6 +30,23 @@ def test_middleware_request(app):
assert type(results[0]) is Request assert type(results[0]) is Request
def test_middleware_request_as_convenience(app):
results = []
@app.on_request
async def handler1(request):
results.append(request)
@app.route("/")
async def handler2(request):
return text("OK")
request, response = app.test_client.get("/")
assert response.text == "OK"
assert type(results[0]) is Request
def test_middleware_response(app): def test_middleware_response(app):
results = [] results = []
@ -54,6 +71,54 @@ def test_middleware_response(app):
assert isinstance(results[2], HTTPResponse) assert isinstance(results[2], HTTPResponse)
def test_middleware_response_as_convenience(app):
results = []
@app.on_request
async def process_request(request):
results.append(request)
@app.on_response
async def process_response(request, response):
results.append(request)
results.append(response)
@app.route("/")
async def handler(request):
return text("OK")
request, response = app.test_client.get("/")
assert response.text == "OK"
assert type(results[0]) is Request
assert type(results[1]) is Request
assert isinstance(results[2], HTTPResponse)
def test_middleware_response_as_convenience_called(app):
results = []
@app.on_request()
async def process_request(request):
results.append(request)
@app.on_response()
async def process_response(request, response):
results.append(request)
results.append(response)
@app.route("/")
async def handler(request):
return text("OK")
request, response = app.test_client.get("/")
assert response.text == "OK"
assert type(results[0]) is Request
assert type(results[1]) is Request
assert isinstance(results[2], HTTPResponse)
def test_middleware_response_exception(app): def test_middleware_response_exception(app):
result = {"status_code": "middleware not run"} result = {"status_code": "middleware not run"}

View File

@ -633,6 +633,19 @@ def test_websocket_route(app, url):
assert ev.is_set() assert ev.is_set()
def test_websocket_route_invalid_handler(app):
with pytest.raises(ValueError) as e:
@app.websocket("/")
async def handler():
...
assert e.match(
r"Required parameter `request` and/or `ws` missing in the "
r"handler\(\) route\?"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("url", ["/ws", "ws"]) @pytest.mark.parametrize("url", ["/ws", "ws"])
async def test_websocket_route_asgi(app, url): async def test_websocket_route_asgi(app, url):

View File

@ -8,6 +8,8 @@ import pytest
from sanic_testing.testing import HOST, PORT from sanic_testing.testing import HOST, PORT
from sanic.exceptions import InvalidUsage
AVAILABLE_LISTENERS = [ AVAILABLE_LISTENERS = [
"before_server_start", "before_server_start",
@ -80,6 +82,18 @@ def test_all_listeners(app):
assert app.name + listener_name == output.pop() assert app.name + listener_name == output.pop()
@skipif_no_alarm
def test_all_listeners_as_convenience(app):
output = []
for listener_name in AVAILABLE_LISTENERS:
listener = create_listener(listener_name, output)
method = getattr(app, listener_name)
method(listener)
start_stop_app(app)
for listener_name in AVAILABLE_LISTENERS:
assert app.name + listener_name == output.pop()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_before_events_create_server(app): async def test_trigger_before_events_create_server(app):
class MySanicDb: class MySanicDb:
@ -95,6 +109,20 @@ async def test_trigger_before_events_create_server(app):
assert isinstance(app.db, MySanicDb) assert isinstance(app.db, MySanicDb)
@pytest.mark.asyncio
async def test_trigger_before_events_create_server_missing_event(app):
class MySanicDb:
pass
with pytest.raises(InvalidUsage):
@app.listener
async def init_db(app, loop):
app.db = MySanicDb()
assert not hasattr(app, "db")
def test_create_server_trigger_events(app): def test_create_server_trigger_events(app):
"""Test if create_server can trigger server events""" """Test if create_server can trigger server events"""

View File

@ -127,6 +127,40 @@ def test_static_file_content_type(app, static_file_directory, file_name):
assert response.headers["Content-Type"] == "text/html; charset=utf-8" assert response.headers["Content-Type"] == "text/html; charset=utf-8"
@pytest.mark.parametrize(
"file_name,expected",
[
("test.html", "text/html; charset=utf-8"),
("decode me.txt", "text/plain; charset=utf-8"),
("test.file", "application/octet-stream"),
],
)
def test_static_file_content_type_guessed(
app, static_file_directory, file_name, expected
):
app.static(
"/testing.file",
get_file_path(static_file_directory, file_name),
)
request, response = app.test_client.get("/testing.file")
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
assert response.headers["Content-Type"] == expected
def test_static_file_content_type_with_charset(app, static_file_directory):
app.static(
"/testing.file",
get_file_path(static_file_directory, "decode me.txt"),
content_type="text/plain;charset=ISO-8859-1",
)
request, response = app.test_client.get("/testing.file")
assert response.status == 200
assert response.headers["Content-Type"] == "text/plain;charset=ISO-8859-1"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "symlink", "hard_link"] "file_name", ["test.file", "decode me.txt", "symlink", "hard_link"]
) )