import asyncio
import inspect
import os
import warnings

from collections import namedtuple
from mimetypes import guess_type
from random import choice
from unittest.mock import MagicMock
from urllib.parse import unquote

import pytest

from aiofiles import os as async_os

from sanic.response import (
    HTTPResponse,
    StreamingHTTPResponse,
    empty,
    file,
    file_stream,
    json,
    raw,
    stream,
    text,
)
from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT


# Payload shared by the JSON handlers registered in the ``json_app`` fixture.
JSON_DATA = {"ok": True}


@pytest.mark.filterwarnings("ignore:Types other than str will be")
def test_response_body_not_a_string(app):
    """Test when a response body sent from the application is not a string"""
    random_num = choice(range(1000))

    @app.route("/hello")
    async def hello_route(request):
        return text(random_num)

    request, response = app.test_client.get("/hello")
    assert response.text == str(random_num)


async def sample_streaming_fn(response):
    """Streaming callback that writes ``"foo,"`` and then ``"bar"``."""
    await response.write("foo,")
    await asyncio.sleep(0.001)
    await response.write("bar")


def test_method_not_allowed(app):
    """Check the ``Allow`` header for methods a route does not handle."""

    @app.get("/")
    async def test_get(request):
        # Use the imported ``json`` helper: the name ``response`` in this
        # scope refers to the test client's response object, not to the
        # ``sanic.response`` module.
        return json({"hello": "world"})

    request, response = app.test_client.head("/")
    assert response.headers["Allow"] == "GET"

    request, response = app.test_client.post("/")
    assert response.headers["Allow"] == "GET"

    @app.post("/")
    async def test_post(request):
        return json({"hello": "world"})

    request, response = app.test_client.head("/")
    assert response.status == 405
    assert set(response.headers["Allow"].split(", ")) == {"GET", "POST"}
    assert response.headers["Content-Length"] == "0"

    request, response = app.test_client.patch("/")
    assert response.status == 405
    assert set(response.headers["Allow"].split(", ")) == {"GET", "POST"}
    assert response.headers["Content-Length"] == "0"


def test_response_header(app):
    """An upper-case ``CONTENT-TYPE`` header must not add a duplicate entry."""

    @app.get("/")
    async def test(request):
        return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"})

    request, response = app.test_client.get("/")
    assert dict(response.headers) == {
        "connection": "keep-alive",
        "keep-alive": str(app.config.KEEP_ALIVE_TIMEOUT),
        "content-length": "11",
        "content-type": "application/json",
    }


def test_response_content_length(app):
    """Both handlers must report the same ``Content-Length`` of 43."""

    @app.get("/response_with_space")
    async def response_with_space(request):
        return json(
            {"message": "Data", "details": "Some Details"},
            headers={"CONTENT-TYPE": "application/json"},
        )

    @app.get("/response_without_space")
    async def response_without_space(request):
        return json(
            {"message": "Data", "details": "Some Details"},
            headers={"CONTENT-TYPE": "application/json"},
        )

    _, response = app.test_client.get("/response_with_space")
    content_length_for_response_with_space = response.headers.get(
        "Content-Length"
    )

    _, response = app.test_client.get("/response_without_space")
    content_length_for_response_without_space = response.headers.get(
        "Content-Length"
    )

    assert (
        content_length_for_response_with_space
        == content_length_for_response_without_space
    )

    assert content_length_for_response_with_space == "43"


def test_response_content_length_with_different_data_types(app):
    """``Content-Length`` is 55 for a payload mixing bool/None/str/int."""

    @app.get("/")
    async def get_data_with_different_types(request):
        # Indentation issues in the Response is intentional. Please do not fix
        return json(
            {"bool": True, "none": None, "string": "string", "number": -1},
            headers={"CONTENT-TYPE": "application/json"},
        )

    _, response = app.test_client.get("/")
    assert response.headers.get("Content-Length") == "55"


@pytest.fixture
def json_app(app):
    """App with JSON handlers returning 200, 204 and 304 responses."""

    @app.route("/")
    async def test(request):
        return json(JSON_DATA)

    @app.get("/no-content")
    async def no_content_handler(request):
        return json(JSON_DATA, status=204)

    @app.get("/no-content/unmodified")
    async def no_content_unmodified_handler(request):
        return json(None, status=304)

    @app.get("/unmodified")
    async def unmodified_handler(request):
        return json(JSON_DATA, status=304)

    @app.delete("/")
    async def delete_handler(request):
        return json(None, status=204)

    return app


def test_json_response(json_app):
    """A plain JSON response round-trips ``JSON_DATA``."""
    from sanic.response import json_dumps

    request, response = json_app.test_client.get("/")
    assert response.status == 200
    assert response.text == json_dumps(JSON_DATA)
    assert response.json == JSON_DATA


def test_no_content(json_app):
    """204 and 304 responses carry no body and no ``Content-Length``."""
    request, response = json_app.test_client.get("/no-content")
    assert response.status == 204
    assert response.text == ""
    assert "Content-Length" not in response.headers

    request, response = json_app.test_client.get("/no-content/unmodified")
    assert response.status == 304
    assert response.text == ""
    assert "Content-Length" not in response.headers
    assert "Content-Type" not in response.headers

    request, response = json_app.test_client.get("/unmodified")
    assert response.status == 304
    assert response.text == ""
    assert "Content-Length" not in response.headers
    assert "Content-Type" not in response.headers

    request, response = json_app.test_client.delete("/")
    assert response.status == 204
    assert response.text == ""
    assert "Content-Length" not in response.headers


@pytest.fixture
def streaming_app(app):
    """App whose root route streams ``sample_streaming_fn`` as text/csv."""

    @app.route("/")
    async def test(request):
        return stream(
            sample_streaming_fn,
            headers={"Content-Length": "7"},
            content_type="text/csv",
        )

    return app


@pytest.fixture
def non_chunked_streaming_app(app):
    """Same as ``streaming_app`` but with ``chunked=False``."""

    @app.route("/")
    async def test(request):
        return stream(
            sample_streaming_fn,
            headers={"Content-Length": "7"},
            content_type="text/csv",
            chunked=False,
        )

    return app


def test_chunked_streaming_adds_correct_headers(streaming_app):
    """Chunked streaming sets Transfer-Encoding and drops Content-Length."""
    request, response = streaming_app.test_client.get("/")
    assert response.headers["Transfer-Encoding"] == "chunked"
    assert response.headers["Content-Type"] == "text/csv"
    # Content-Length is not allowed by HTTP/1.1 specification
    # when "Transfer-Encoding: chunked" is used
    assert "Content-Length" not in response.headers


def test_chunked_streaming_returns_correct_content(streaming_app):
    """The chunked stream body is the concatenation of the written parts."""
    request, response = streaming_app.test_client.get("/")
    assert response.text == "foo,bar"


def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
    """Non-chunked streaming keeps Content-Length, omits Transfer-Encoding."""
    request, response = non_chunked_streaming_app.test_client.get("/")
    assert "Transfer-Encoding" not in response.headers
    assert response.headers["Content-Type"] == "text/csv"
    assert response.headers["Content-Length"] == "7"


def test_non_chunked_streaming_returns_correct_content(
    non_chunked_streaming_app,
):
    """The non-chunked stream body is the concatenation of the parts."""
    request, response = non_chunked_streaming_app.test_client.get("/")
    assert response.text == "foo,bar"


@pytest.mark.parametrize("status", [200, 201, 400, 401])
def test_stream_response_status_returns_correct_headers(status):
    """The status line in the generated headers carries the given status."""
    response = StreamingHTTPResponse(sample_streaming_fn, status=status)
    headers = response.get_headers()
    assert b"HTTP/1.1 %s" % str(status).encode() in headers


@pytest.mark.parametrize("keep_alive_timeout", [10, 20, 30])
def test_stream_response_keep_alive_returns_correct_headers(
    keep_alive_timeout,
):
    """The Keep-Alive header carries the requested timeout."""
    response = StreamingHTTPResponse(sample_streaming_fn)
    headers = response.get_headers(
        keep_alive=True, keep_alive_timeout=keep_alive_timeout
    )

    assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers


def test_stream_response_includes_chunked_header_http11():
    """HTTP/1.1 headers include ``Transfer-Encoding: chunked``."""
    response = StreamingHTTPResponse(sample_streaming_fn)
    headers = response.get_headers(version="1.1")
    assert b"Transfer-Encoding: chunked\r\n" in headers


def test_stream_response_does_not_include_chunked_header_http10():
    """HTTP/1.0 headers omit ``Transfer-Encoding: chunked``."""
    response = StreamingHTTPResponse(sample_streaming_fn)
    headers = response.get_headers(version="1.0")
    assert b"Transfer-Encoding: chunked\r\n" not in headers


def test_stream_response_does_not_include_chunked_header_if_disabled():
    """``chunked=False`` omits the chunked header even on HTTP/1.1."""
    response = StreamingHTTPResponse(sample_streaming_fn, chunked=False)
    headers = response.get_headers(version="1.1")
    assert b"Transfer-Encoding: chunked\r\n" not in headers


def test_stream_response_writes_correct_content_to_transport_when_chunked(
    streaming_app,
):
    """Chunked streaming writes size-prefixed chunks and a terminator."""
    response = StreamingHTTPResponse(sample_streaming_fn)
    response.protocol = MagicMock(HttpProtocol)
    response.protocol.transport = MagicMock(asyncio.Transport)

    async def mock_drain():
        pass

    async def mock_push_data(data):
        # Forward to the mocked transport so writes can be inspected below.
        response.protocol.transport.write(data)

    response.protocol.push_data = mock_push_data
    response.protocol.drain = mock_drain

    @streaming_app.listener("after_server_start")
    async def run_stream(app, loop):
        await response.stream()
        # Index 0 is skipped; only the writes after it are checked here.
        assert response.protocol.transport.write.call_args_list[1][0][0] == (
            b"4\r\nfoo,\r\n"
        )

        assert response.protocol.transport.write.call_args_list[2][0][0] == (
            b"3\r\nbar\r\n"
        )

        assert response.protocol.transport.write.call_args_list[3][0][0] == (
            b"0\r\n\r\n"
        )

        assert len(response.protocol.transport.write.call_args_list) == 4

        app.stop()

    streaming_app.run(host=HOST, port=PORT)


def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
    streaming_app,
):
    """With ``version="1.0"`` the body parts are written without framing."""
    response = StreamingHTTPResponse(sample_streaming_fn)
    response.protocol = MagicMock(HttpProtocol)
    response.protocol.transport = MagicMock(asyncio.Transport)

    async def mock_drain():
        pass

    async def mock_push_data(data):
        # Forward to the mocked transport so writes can be inspected below.
        response.protocol.transport.write(data)

    response.protocol.push_data = mock_push_data
    response.protocol.drain = mock_drain

    @streaming_app.listener("after_server_start")
    async def run_stream(app, loop):
        await response.stream(version="1.0")
        # Index 0 is skipped; only the writes after it are checked here.
        assert response.protocol.transport.write.call_args_list[1][0][0] == (
            b"foo,"
        )

        assert response.protocol.transport.write.call_args_list[2][0][0] == (
            b"bar"
        )

        assert len(response.protocol.transport.write.call_args_list) == 3

        app.stop()

    streaming_app.run(host=HOST, port=PORT)


def test_stream_response_with_cookies(app):
    """The last value assigned to a cookie is the one sent to the client."""

    @app.route("/")
    async def test(request):
        response = stream(sample_streaming_fn, content_type="text/csv")
        response.cookies["test"] = "modified"
        response.cookies["test"] = "pass"
        return response

    request, response = app.test_client.get("/")
    assert response.cookies["test"] == "pass"


def test_stream_response_without_cookies(app):
    """A streamed response with no cookies set yields an empty cookie map."""

    @app.route("/")
    async def test(request):
        return stream(sample_streaming_fn, content_type="text/csv")

    request, response = app.test_client.get("/")
    assert response.cookies == {}


@pytest.fixture
def static_file_directory():
    """The static directory to serve"""
    current_file = inspect.getfile(inspect.currentframe())
    current_directory = os.path.dirname(os.path.abspath(current_file))
    static_directory = os.path.join(current_directory, "static")
    return static_directory


def get_file_content(static_file_directory, file_name):
    """The content of the static file to check"""
    # Named ``file_handle`` so the imported ``sanic.response.file`` helper
    # is not shadowed.
    path = os.path.join(static_file_directory, file_name)
    with open(path, "rb") as file_handle:
        return file_handle.read()


@pytest.mark.parametrize(
    "file_name", ["test.file", "decode me.txt", "python.png"]
)
@pytest.mark.parametrize("status", [200, 401])
def test_file_response(app, file_name, static_file_directory, status):
    """``file()`` serves the file bytes with the requested status."""

    # The ``<filename>`` path parameter feeds the handler's ``filename``
    # argument; without it the request below cannot reach the handler.
    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        return file(
            file_path,
            status=status,
            mime_type=guess_type(file_path)[0] or "text/plain",
        )

    request, response = app.test_client.get(f"/files/{file_name}")
    assert response.status == status
    assert response.body == get_file_content(static_file_directory, file_name)
    assert "Content-Disposition" not in response.headers
@pytest.mark.parametrize(
    "source,dest",
    [
        ("test.file", "my_file.txt"),
        ("decode me.txt", "readme.md"),
        ("python.png", "logo.png"),
    ],
)
def test_file_response_custom_filename(
    app, source, dest, static_file_directory
):
    """``file(..., filename=dest)`` adds an attachment Content-Disposition."""

    # The ``<filename>`` path parameter feeds the handler's ``filename``
    # argument; without it the request below cannot reach the handler.
    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        return file(file_path, filename=dest)

    request, response = app.test_client.get(f"/files/{source}")
    assert response.status == 200
    assert response.body == get_file_content(static_file_directory, source)
    assert (
        response.headers["Content-Disposition"]
        == f'attachment; filename="{dest}"'
    )


@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"])
def test_file_head_response(app, file_name, static_file_directory):
    """A HEAD request reports Accept-Ranges and the file's Content-Length."""

    @app.route("/files/<filename>", methods=["GET", "HEAD"])
    async def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        stats = await async_os.stat(file_path)
        headers = {
            "Accept-Ranges": "bytes",
            "Content-Length": str(stats.st_size),
        }
        if request.method == "HEAD":
            return HTTPResponse(
                headers=headers,
                content_type=guess_type(file_path)[0] or "text/plain",
            )
        else:
            return file(
                file_path,
                headers=headers,
                mime_type=guess_type(file_path)[0] or "text/plain",
            )

    request, response = app.test_client.head(f"/files/{file_name}")
    assert response.status == 200
    assert "Accept-Ranges" in response.headers
    assert "Content-Length" in response.headers
    assert int(response.headers["Content-Length"]) == len(
        get_file_content(static_file_directory, file_name)
    )


@pytest.mark.parametrize(
    "file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_stream_response(app, file_name, static_file_directory):
    """``file_stream()`` delivers the complete file body."""

    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        return file_stream(
            file_path,
            chunk_size=32,
            mime_type=guess_type(file_path)[0] or "text/plain",
        )

    request, response = app.test_client.get(f"/files/{file_name}")
    assert response.status == 200
    assert response.body == get_file_content(static_file_directory, file_name)
    assert "Content-Disposition" not in response.headers


@pytest.mark.parametrize(
    "source,dest",
    [
        ("test.file", "my_file.txt"),
        ("decode me.txt", "readme.md"),
        ("python.png", "logo.png"),
    ],
)
def test_file_stream_response_custom_filename(
    app, source, dest, static_file_directory
):
    """``file_stream(..., filename=dest)`` adds Content-Disposition."""

    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        return file_stream(file_path, chunk_size=32, filename=dest)

    request, response = app.test_client.get(f"/files/{source}")
    assert response.status == 200
    assert response.body == get_file_content(static_file_directory, source)
    assert (
        response.headers["Content-Disposition"]
        == f'attachment; filename="{dest}"'
    )


@pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"])
def test_file_stream_head_response(app, file_name, static_file_directory):
    """A HEAD request to a streaming route is answered without chunking."""

    @app.route("/files/<filename>", methods=["GET", "HEAD"])
    async def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        headers = {"Accept-Ranges": "bytes"}
        if request.method == "HEAD":
            # Return a normal HTTPResponse, not a
            # StreamingHTTPResponse for a HEAD request
            stats = await async_os.stat(file_path)
            headers["Content-Length"] = str(stats.st_size)
            return HTTPResponse(
                headers=headers,
                content_type=guess_type(file_path)[0] or "text/plain",
            )
        else:
            return file_stream(
                file_path,
                chunk_size=32,
                headers=headers,
                mime_type=guess_type(file_path)[0] or "text/plain",
            )

    request, response = app.test_client.head(f"/files/{file_name}")
    assert response.status == 200
    # A HEAD request should never be streamed/chunked.
    if "Transfer-Encoding" in response.headers:
        assert response.headers["Transfer-Encoding"] != "chunked"
    assert "Accept-Ranges" in response.headers
    # A HEAD request should get the Content-Length too
    assert "Content-Length" in response.headers
    assert int(response.headers["Content-Length"]) == len(
        get_file_content(static_file_directory, file_name)
    )


@pytest.mark.parametrize(
    "file_name", ["test.file", "decode me.txt", "python.png"]
)
@pytest.mark.parametrize(
    "size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)]
)
def test_file_stream_response_range(
    app, file_name, static_file_directory, size, start, end
):
    """Passing ``_range`` yields a 206 with a matching Content-Range."""
    Range = namedtuple("Range", ["size", "start", "end", "total"])
    total = len(get_file_content(static_file_directory, file_name))
    # Named ``requested_range`` to avoid shadowing the ``range`` builtin.
    requested_range = Range(size=size, start=start, end=end, total=total)

    @app.route("/files/<filename>", methods=["GET"])
    def file_route(request, filename):
        file_path = os.path.join(static_file_directory, filename)
        file_path = os.path.abspath(unquote(file_path))
        return file_stream(
            file_path,
            chunk_size=32,
            mime_type=guess_type(file_path)[0] or "text/plain",
            _range=requested_range,
        )

    request, response = app.test_client.get(f"/files/{file_name}")
    assert response.status == 206
    assert "Content-Range" in response.headers
    assert response.headers["Content-Range"] == (
        f"bytes {requested_range.start}-{requested_range.end}"
        f"/{requested_range.total}"
    )


def test_raw_response(app):
    """``raw()`` sends the bytes as application/octet-stream."""

    @app.get("/test")
    def handler(request):
        return raw(b"raw_response")

    request, response = app.test_client.get("/test")
    assert response.content_type == "application/octet-stream"
    assert response.body == b"raw_response"


def test_empty_response(app):
    """``empty()`` sends no body and no content type."""

    @app.get("/test")
    def handler(request):
        return empty()

    request, response = app.test_client.get("/test")
    assert response.content_type is None
    assert response.body == b""


def test_response_body_bytes_deprecated(app):
    """Passing ``body_bytes`` emits exactly one DeprecationWarning."""
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        HTTPResponse(body_bytes=b"bytes")

        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert (
            "Parameter `body_bytes` is deprecated, use `body` instead"
            in str(w[0].message)
        )