Merge branch 'master' into streaming

This commit is contained in:
L. Kärkkäinen 2020-03-26 08:59:33 +02:00
commit 1aac4f546b
18 changed files with 254 additions and 83 deletions

View File

@ -95,6 +95,7 @@ class Sanic:
self.sock = None
self.strict_slashes = strict_slashes
self.listeners = defaultdict(list)
self.is_stopping = False
self.is_running = False
self.websocket_enabled = False
self.websocket_tasks = set()
@ -1181,6 +1182,7 @@ class Sanic:
try:
self.is_running = True
self.is_stopping = False
if workers > 1 and os.name != "posix":
logger.warn(
f"Multiprocessing is currently not supported on {os.name},"
@ -1213,6 +1215,8 @@ class Sanic:
def stop(self):
"""This kills the Sanic"""
if not self.is_stopping:
self.is_stopping = True
get_event_loop().stop()
async def create_server(

View File

@ -1,17 +1,11 @@
from asyncio import CancelledError
import asyncio
import signal
from sys import argv
from multidict import CIMultiDict # type: ignore
try:
from trio import Cancelled # type: ignore
CancelledErrors = tuple([CancelledError, Cancelled])
except ImportError:
CancelledErrors = tuple([CancelledError])
class Header(CIMultiDict):
def get_all(self, key):
return self.getall(key, default=[])
@ -20,15 +14,42 @@ class Header(CIMultiDict):
use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio:
from trio import open_file as open_async, Path # type: ignore
import trio # type: ignore
def stat_async(path):
return Path(path).stat()
return trio.Path(path).stat()
open_async = trio.open_file
CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled])
else:
from aiofiles import open as aio_open # type: ignore
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
async def open_async(file, mode="r", **kwargs):
return aio_open(file, mode, **kwargs)
CancelledErrors = tuple([asyncio.CancelledError])
def ctrlc_workaround_for_windows(app):
async def stay_active(app):
"""Asyncio wakeups to allow receiving SIGINT in Python"""
while not die:
# If someone else stopped the app, just exit
if app.is_stopping:
return
# Windows Python blocks signal handlers while the event loop is
# waiting for I/O. Frequent wakeups keep interrupts flowing.
await asyncio.sleep(0.1)
# Can't be called from signal handler, so call it from here
app.stop()
def ctrlc_handler(sig, frame):
nonlocal die
if die:
raise KeyboardInterrupt("Non-graceful Ctrl+C")
die = True
die = False
signal.signal(signal.SIGINT, ctrlc_handler)
app.add_task(stay_active)

View File

@ -84,6 +84,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self.status = status
self.headers = Header(headers or {})
self._cookies = None
self.protocol = None
async def write(self, data):
"""Writes a chunk of data to the streaming response.

View File

@ -11,6 +11,7 @@ from signal import signal as signal_func
from socket import SO_REUSEADDR, SOL_SOCKET, socket
from time import monotonic as current_time
from sanic.compat import ctrlc_workaround_for_windows
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import logger
@ -25,6 +26,8 @@ try:
except ImportError:
pass
OS_IS_WINDOWS = os.name == "nt"
class Signal:
stopped = False
@ -530,15 +533,11 @@ def serve(
# Register signals for graceful termination
if register_sys_signals:
_singals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM)
for _signal in _singals:
try:
loop.add_signal_handler(_signal, loop.stop)
except NotImplementedError:
logger.warning(
"Sanic tried to use loop.add_signal_handler "
"but it is not implemented on this platform."
)
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)

View File

@ -12,7 +12,7 @@ from sanic.response import text
ASGI_HOST = "mockserver"
HOST = "127.0.0.1"
PORT = 42101
PORT = None
class SanicTestClient:
@ -99,7 +99,7 @@ class SanicTestClient:
if self.port:
server_kwargs = dict(
host=host or self.host, port=self.port, **server_kwargs
host=host or self.host, port=self.port, **server_kwargs,
)
host, port = host or self.host, self.port
else:
@ -107,6 +107,7 @@ class SanicTestClient:
sock.bind((host or self.host, 0))
server_kwargs = dict(sock=sock, **server_kwargs)
host, port = sock.getsockname()
self.port = port
if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
@ -118,6 +119,9 @@ class SanicTestClient:
url = "{scheme}://{host}:{port}{uri}".format(
scheme=scheme, host=host, port=port, uri=uri
)
# Tests construct URLs using PORT = None, which means random port not
# known until this function is called, so fix that here
url = url.replace(":None/", f":{port}/")
@self.app.listener("after_server_start")
async def _collect_response(sanic, loop):
@ -207,7 +211,7 @@ class SanicASGITestClient(httpx.AsyncClient):
self.app = app
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT))
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0))
super().__init__(dispatch=dispatch, base_url=base_url)
self.last_request = None

View File

@ -6,6 +6,7 @@ from inspect import isawaitable
import pytest
from sanic import Sanic
from sanic.exceptions import SanicException
from sanic.response import text
@ -48,6 +49,7 @@ def test_asyncio_server_no_start_serving(app):
if not uvloop_installed():
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
port=43123,
return_asyncio_server=True,
asyncio_server_kwargs=dict(start_serving=False),
)
@ -61,6 +63,7 @@ def test_asyncio_server_start_serving(app):
if not uvloop_installed():
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
port=43124,
return_asyncio_server=True,
asyncio_server_kwargs=dict(start_serving=False),
)
@ -199,10 +202,17 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
with caplog.at_level(logging.ERROR):
request, response = app.test_client.get("/")
port = request.server_port
assert port > 0
assert response.status == 500
assert "Mock SanicException" in response.text
assert (
"sanic.root",
logging.ERROR,
"Exception occurred while handling uri: 'http://127.0.0.1:42101/'",
f"Exception occurred while handling uri: 'http://127.0.0.1:{port}/'",
) in caplog.record_tuples
def test_app_name_required():
with pytest.deprecated_call():
Sanic()

View File

@ -221,7 +221,7 @@ async def test_request_class_custom():
class MyCustomRequest(Request):
pass
app = Sanic(request_class=MyCustomRequest)
app = Sanic(name=__name__, request_class=MyCustomRequest)
@app.get("/custom")
def custom_request(request):

View File

@ -44,42 +44,42 @@ def test_load_from_object_string_exception(app):
def test_auto_load_env():
environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic()
app = Sanic(name=__name__)
assert app.config.TEST_ANSWER == 42
del environ["SANIC_TEST_ANSWER"]
def test_auto_load_bool_env():
environ["SANIC_TEST_ANSWER"] = "True"
app = Sanic()
app = Sanic(name=__name__)
assert app.config.TEST_ANSWER == True
del environ["SANIC_TEST_ANSWER"]
def test_dont_load_env():
environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(load_env=False)
app = Sanic(name=__name__, load_env=False)
assert getattr(app.config, "TEST_ANSWER", None) is None
del environ["SANIC_TEST_ANSWER"]
def test_load_env_prefix():
environ["MYAPP_TEST_ANSWER"] = "42"
app = Sanic(load_env="MYAPP_")
app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_ANSWER == 42
del environ["MYAPP_TEST_ANSWER"]
def test_load_env_prefix_float_values():
environ["MYAPP_TEST_ROI"] = "2.3"
app = Sanic(load_env="MYAPP_")
app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_ROI == 2.3
del environ["MYAPP_TEST_ROI"]
def test_load_env_prefix_string_value():
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
app = Sanic(load_env="MYAPP_")
app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_TOKEN == "somerandomtesttoken"
del environ["MYAPP_TEST_TOKEN"]

View File

@ -30,7 +30,7 @@ def test_deprecated_custom_request():
Sanic(request_class=DeprecCustomRequest)
def test_custom_request():
app = Sanic(request_class=CustomRequest)
app = Sanic(name=__name__, request_class=CustomRequest)
@app.route("/post", methods=["POST"])
async def post_handler(request):

View File

@ -7,12 +7,12 @@ import httpx
from sanic import Sanic, server
from sanic.response import text
from sanic.testing import HOST, PORT, SanicTestClient
from sanic.testing import HOST, SanicTestClient
CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
old_conn = None
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
class ReusableSanicConnectionPool(

View File

@ -1,4 +1,5 @@
import logging
import os
import uuid
from importlib import reload
@ -12,6 +13,7 @@ import sanic
from sanic import Sanic
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.response import text
from sanic.testing import SanicTestClient
logging_format = """module: %(module)s; \
@ -126,7 +128,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
def test_logger(caplog):
rand_string = str(uuid.uuid4())
app = Sanic()
app = Sanic(name=__name__)
@app.get("/")
def log_info(request):
@ -136,15 +138,67 @@ def test_logger(caplog):
with caplog.at_level(logging.INFO):
request, response = app.test_client.get("/")
port = request.server_port
# Note: testing with random port doesn't show the banner because it doesn't
# define host and port. This test supports both modes.
if caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ http://127.0.0.1:{port}",
):
caplog.record_tuples.pop(0)
assert caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
"Goin' Fast @ http://127.0.0.1:42101",
f"http://127.0.0.1:{port}/",
)
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == (
"sanic.root",
logging.INFO,
"Server Stopped",
)
def test_logger_static_and_secure(caplog):
# Same as test_logger, except for more coverage:
# - test_client initialised separately for static port
# - using ssl
rand_string = str(uuid.uuid4())
app = Sanic(name=__name__)
@app.get("/")
def log_info(request):
logger.info(rand_string)
return text("hello")
current_dir = os.path.dirname(os.path.realpath(__file__))
ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert")
ssl_key = os.path.join(current_dir, "certs/selfsigned.key")
ssl_dict = {"cert": ssl_cert, "key": ssl_key}
test_client = SanicTestClient(app, port=42101)
with caplog.at_level(logging.INFO):
request, response = test_client.get(
f"https://127.0.0.1:{test_client.port}/",
server_kwargs=dict(ssl=ssl_dict),
)
port = test_client.port
assert caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ https://127.0.0.1:{port}",
)
assert caplog.record_tuples[1] == (
"sanic.root",
logging.INFO,
"http://127.0.0.1:42101/",
f"https://127.0.0.1:{port}/",
)
assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == (

View File

@ -49,10 +49,10 @@ def test_logo_false(app, caplog):
loop.run_until_complete(_server.wait_closed())
app.stop()
banner, port = caplog.record_tuples[ROW][2].rsplit(":", 1)
assert caplog.record_tuples[ROW][1] == logging.INFO
assert caplog.record_tuples[ROW][
2
] == f"Goin' Fast @ http://127.0.0.1:{PORT}"
assert banner == "Goin' Fast @ http://127.0.0.1"
assert int(port) > 0
def test_logo_true(app, caplog):

View File

@ -7,7 +7,7 @@ from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator
data = "abc" * 10000000
data = "abc" * 1_000_000
def test_request_stream_method_view(app):

View File

@ -12,7 +12,7 @@ from sanic import Blueprint, Sanic
from sanic.exceptions import ServerError
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters
from sanic.response import html, json, text
from sanic.testing import ASGI_HOST, HOST, PORT
from sanic.testing import ASGI_HOST, HOST, PORT, SanicTestClient
# ------------------------------------------------------------ #
@ -1029,7 +1029,7 @@ def test_url_attributes_no_ssl(app, path, query, expected_url):
app.add_route(handler, path)
request, response = app.test_client.get(path + f"?{query}")
assert request.url == expected_url.format(HOST, PORT)
assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url)
@ -1086,11 +1086,12 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url):
app.add_route(handler, path)
port = app.test_client.port
request, response = app.test_client.get(
f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": context},
)
assert request.url == expected_url.format(HOST, PORT)
assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url)
@ -1125,7 +1126,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": ssl_dict},
)
assert request.url == expected_url.format(HOST, PORT)
assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url)
@ -1917,8 +1918,9 @@ def test_request_server_port(app):
def handler(request):
return text("OK")
request, response = app.test_client.get("/", headers={"Host": "my-server"})
assert request.server_port == app.test_client.port
test_client = SanicTestClient(app)
request, response = test_client.get("/", headers={"Host": "my-server"})
assert request.server_port == test_client.port
def test_request_server_port_in_host_header(app):
@ -1939,6 +1941,9 @@ def test_request_server_port_in_host_header(app):
request, response = app.test_client.get(
"/", headers={"Host": "mal_formed:5555"}
)
if PORT is None:
assert request.server_port != 5555
else:
assert request.server_port == app.test_client.port
@ -1979,7 +1984,7 @@ def test_server_name_and_url_for(app):
request, response = app.test_client.get("/foo")
assert (
request.url_for("handler")
== f"http://my-server:{app.test_client.port}/foo"
== f"http://my-server:{request.server_port}/foo"
)
app.config.SERVER_NAME = "https://my-server/path"
@ -2040,7 +2045,7 @@ async def test_request_form_invalid_content_type_asgi(app):
def test_endpoint_basic():
app = Sanic()
app = Sanic(name=__name__)
@app.route("/")
def my_unique_handler(request):
@ -2053,7 +2058,7 @@ def test_endpoint_basic():
@pytest.mark.asyncio
async def test_endpoint_basic_asgi():
app = Sanic()
app = Sanic(name=__name__)
@app.route("/")
def my_unique_handler(request):
@ -2132,5 +2137,5 @@ def test_url_for_without_server_name(app):
request, response = app.test_client.get("/sample")
assert (
response.json["url"]
== f"http://127.0.0.1:{app.test_client.port}/url-for"
== f"http://127.0.0.1:{request.server_port}/url-for"
)

View File

@ -6,6 +6,7 @@ from sanic import Sanic
from sanic.constants import HTTP_METHODS
from sanic.response import json, text
from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists
from sanic.testing import SanicTestClient
# ------------------------------------------------------------ #
@ -163,35 +164,36 @@ def test_route_optional_slash(app):
def test_route_strict_slashes_set_to_false_and_host_is_a_list(app):
# Part of regression test for issue #1120
site1 = f"127.0.0.1:{app.test_client.port}"
test_client = SanicTestClient(app, port=42101)
site1 = f"127.0.0.1:{test_client.port}"
# before fix, this raises a RouteExists error
@app.get("/get", host=[site1, "site2.com"], strict_slashes=False)
def get_handler(request):
return text("OK")
request, response = app.test_client.get("http://" + site1 + "/get")
request, response = test_client.get("http://" + site1 + "/get")
assert response.text == "OK"
@app.post("/post", host=[site1, "site2.com"], strict_slashes=False)
def post_handler(request):
return text("OK")
request, response = app.test_client.post("http://" + site1 + "/post")
request, response = test_client.post("http://" + site1 + "/post")
assert response.text == "OK"
@app.put("/put", host=[site1, "site2.com"], strict_slashes=False)
def put_handler(request):
return text("OK")
request, response = app.test_client.put("http://" + site1 + "/put")
request, response = test_client.put("http://" + site1 + "/put")
assert response.text == "OK"
@app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False)
def delete_handler(request):
return text("OK")
request, response = app.test_client.delete("http://" + site1 + "/delete")
request, response = test_client.delete("http://" + site1 + "/delete")
assert response.text == "OK"

View File

@ -1,6 +1,9 @@
import asyncio
import signal
from contextlib import closing
from socket import socket
import pytest
from sanic.testing import HOST, PORT
@ -118,7 +121,12 @@ def test_create_server_trigger_events(app):
app.listener("after_server_stop")(after_stop)
loop = asyncio.get_event_loop()
serv_coro = app.create_server(return_asyncio_server=True)
# Use random port for tests
with closing(socket()) as sock:
sock.bind(("127.0.0.1", 0))
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task)
server.after_start()

View File

@ -1,8 +1,13 @@
import asyncio
import os
import signal
from queue import Queue
from unittest.mock import MagicMock
import pytest
from sanic.compat import ctrlc_workaround_for_windows
from sanic.response import HTTPResponse
from sanic.testing import HOST, PORT
@ -16,13 +21,21 @@ calledq = Queue()
def set_loop(app, loop):
loop.add_signal_handler = MagicMock()
global mock
mock = MagicMock()
if os.name == "nt":
signal.signal = mock
else:
loop.add_signal_handler = mock
def after(app, loop):
calledq.put(loop.add_signal_handler.called)
calledq.put(mock.called)
@pytest.mark.skipif(
os.name == "nt", reason="May hang CI on py38/windows"
)
def test_register_system_signals(app):
"""Test if sanic register system signals"""
@ -38,6 +51,9 @@ def test_register_system_signals(app):
assert calledq.get() is True
@pytest.mark.skipif(
os.name == "nt", reason="May hang CI on py38/windows"
)
def test_dont_register_system_signals(app):
"""Test if sanic don't register system signals"""
@ -51,3 +67,49 @@ def test_dont_register_system_signals(app):
app.run(HOST, PORT, register_sys_signals=False)
assert calledq.get() is False
@pytest.mark.skipif(
os.name == "nt", reason="windows cannot SIGINT processes"
)
def test_windows_workaround():
"""Test Windows workaround (on any other OS)"""
# At least some code coverage, even though this test doesn't work on
# Windows...
class MockApp:
def __init__(self):
self.is_stopping = False
def stop(self):
assert not self.is_stopping
self.is_stopping = True
def add_task(self, func):
loop = asyncio.get_event_loop()
self.stay_active_task = loop.create_task(func(self))
async def atest(stop_first):
app = MockApp()
ctrlc_workaround_for_windows(app)
await asyncio.sleep(0.05)
if stop_first:
app.stop()
await asyncio.sleep(0.2)
assert app.is_stopping == stop_first
# First Ctrl+C: should call app.stop() within 0.1 seconds
os.kill(os.getpid(), signal.SIGINT)
await asyncio.sleep(0.2)
assert app.is_stopping
assert app.stay_active_task.result() == None
# Second Ctrl+C should raise
with pytest.raises(KeyboardInterrupt):
os.kill(os.getpid(), signal.SIGINT)
return "OK"
# Run in our private loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
res = loop.run_until_complete(atest(False))
assert res == "OK"
res = loop.run_until_complete(atest(True))
assert res == "OK"

View File

@ -27,7 +27,8 @@ def test_test_client_port_default(app):
return json(request.transport.get_extra_info("sockname")[1])
test_client = SanicTestClient(app)
assert test_client.port == PORT
assert test_client.port == PORT # Can be None before request
request, response = test_client.get("/get")
assert response.json == PORT
assert test_client.port > 0
assert response.json == test_client.port