Merge branch 'master' into streaming

This commit is contained in:
L. Kärkkäinen 2020-03-26 08:59:33 +02:00
commit 1aac4f546b
18 changed files with 254 additions and 83 deletions

View File

@ -95,6 +95,7 @@ class Sanic:
self.sock = None self.sock = None
self.strict_slashes = strict_slashes self.strict_slashes = strict_slashes
self.listeners = defaultdict(list) self.listeners = defaultdict(list)
self.is_stopping = False
self.is_running = False self.is_running = False
self.websocket_enabled = False self.websocket_enabled = False
self.websocket_tasks = set() self.websocket_tasks = set()
@ -1181,6 +1182,7 @@ class Sanic:
try: try:
self.is_running = True self.is_running = True
self.is_stopping = False
if workers > 1 and os.name != "posix": if workers > 1 and os.name != "posix":
logger.warn( logger.warn(
f"Multiprocessing is currently not supported on {os.name}," f"Multiprocessing is currently not supported on {os.name},"
@ -1213,6 +1215,8 @@ class Sanic:
def stop(self): def stop(self):
"""This kills the Sanic""" """This kills the Sanic"""
if not self.is_stopping:
self.is_stopping = True
get_event_loop().stop() get_event_loop().stop()
async def create_server( async def create_server(

View File

@ -1,17 +1,11 @@
from asyncio import CancelledError import asyncio
import signal
from sys import argv from sys import argv
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
try:
from trio import Cancelled # type: ignore
CancelledErrors = tuple([CancelledError, Cancelled])
except ImportError:
CancelledErrors = tuple([CancelledError])
class Header(CIMultiDict): class Header(CIMultiDict):
def get_all(self, key): def get_all(self, key):
return self.getall(key, default=[]) return self.getall(key, default=[])
@ -20,15 +14,42 @@ class Header(CIMultiDict):
use_trio = argv[0].endswith("hypercorn") and "trio" in argv use_trio = argv[0].endswith("hypercorn") and "trio" in argv
if use_trio: if use_trio:
from trio import open_file as open_async, Path # type: ignore import trio # type: ignore
def stat_async(path): def stat_async(path):
return Path(path).stat() return trio.Path(path).stat()
open_async = trio.open_file
CancelledErrors = tuple([asyncio.CancelledError, trio.Cancelled])
else: else:
from aiofiles import open as aio_open # type: ignore from aiofiles import open as aio_open # type: ignore
from aiofiles.os import stat as stat_async # type: ignore # noqa: F401 from aiofiles.os import stat as stat_async # type: ignore # noqa: F401
async def open_async(file, mode="r", **kwargs): async def open_async(file, mode="r", **kwargs):
return aio_open(file, mode, **kwargs) return aio_open(file, mode, **kwargs)
CancelledErrors = tuple([asyncio.CancelledError])
def ctrlc_workaround_for_windows(app):
async def stay_active(app):
"""Asyncio wakeups to allow receiving SIGINT in Python"""
while not die:
# If someone else stopped the app, just exit
if app.is_stopping:
return
# Windows Python blocks signal handlers while the event loop is
# waiting for I/O. Frequent wakeups keep interrupts flowing.
await asyncio.sleep(0.1)
# Can't be called from signal handler, so call it from here
app.stop()
def ctrlc_handler(sig, frame):
nonlocal die
if die:
raise KeyboardInterrupt("Non-graceful Ctrl+C")
die = True
die = False
signal.signal(signal.SIGINT, ctrlc_handler)
app.add_task(stay_active)

View File

@ -84,6 +84,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self.status = status self.status = status
self.headers = Header(headers or {}) self.headers = Header(headers or {})
self._cookies = None self._cookies = None
self.protocol = None
async def write(self, data): async def write(self, data):
"""Writes a chunk of data to the streaming response. """Writes a chunk of data to the streaming response.

View File

@ -11,6 +11,7 @@ from signal import signal as signal_func
from socket import SO_REUSEADDR, SOL_SOCKET, socket from socket import SO_REUSEADDR, SOL_SOCKET, socket
from time import monotonic as current_time from time import monotonic as current_time
from sanic.compat import ctrlc_workaround_for_windows
from sanic.exceptions import RequestTimeout, ServiceUnavailable from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage from sanic.http import Http, Stage
from sanic.log import logger from sanic.log import logger
@ -25,6 +26,8 @@ try:
except ImportError: except ImportError:
pass pass
OS_IS_WINDOWS = os.name == "nt"
class Signal: class Signal:
stopped = False stopped = False
@ -530,15 +533,11 @@ def serve(
# Register signals for graceful termination # Register signals for graceful termination
if register_sys_signals: if register_sys_signals:
_singals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM) if OS_IS_WINDOWS:
for _signal in _singals: ctrlc_workaround_for_windows(app)
try: else:
loop.add_signal_handler(_signal, loop.stop) for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
except NotImplementedError: loop.add_signal_handler(_signal, app.stop)
logger.warning(
"Sanic tried to use loop.add_signal_handler "
"but it is not implemented on this platform."
)
pid = os.getpid() pid = os.getpid()
try: try:
logger.info("Starting worker [%s]", pid) logger.info("Starting worker [%s]", pid)

View File

@ -12,7 +12,7 @@ from sanic.response import text
ASGI_HOST = "mockserver" ASGI_HOST = "mockserver"
HOST = "127.0.0.1" HOST = "127.0.0.1"
PORT = 42101 PORT = None
class SanicTestClient: class SanicTestClient:
@ -99,7 +99,7 @@ class SanicTestClient:
if self.port: if self.port:
server_kwargs = dict( server_kwargs = dict(
host=host or self.host, port=self.port, **server_kwargs host=host or self.host, port=self.port, **server_kwargs,
) )
host, port = host or self.host, self.port host, port = host or self.host, self.port
else: else:
@ -107,6 +107,7 @@ class SanicTestClient:
sock.bind((host or self.host, 0)) sock.bind((host or self.host, 0))
server_kwargs = dict(sock=sock, **server_kwargs) server_kwargs = dict(sock=sock, **server_kwargs)
host, port = sock.getsockname() host, port = sock.getsockname()
self.port = port
if uri.startswith( if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
@ -118,6 +119,9 @@ class SanicTestClient:
url = "{scheme}://{host}:{port}{uri}".format( url = "{scheme}://{host}:{port}{uri}".format(
scheme=scheme, host=host, port=port, uri=uri scheme=scheme, host=host, port=port, uri=uri
) )
# Tests construct URLs using PORT = None, which means random port not
# known until this function is called, so fix that here
url = url.replace(":None/", f":{port}/")
@self.app.listener("after_server_start") @self.app.listener("after_server_start")
async def _collect_response(sanic, loop): async def _collect_response(sanic, loop):
@ -207,7 +211,7 @@ class SanicASGITestClient(httpx.AsyncClient):
self.app = app self.app = app
dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT)) dispatch = SanicASGIDispatch(app=app, client=(ASGI_HOST, PORT or 0))
super().__init__(dispatch=dispatch, base_url=base_url) super().__init__(dispatch=dispatch, base_url=base_url)
self.last_request = None self.last_request = None

View File

@ -6,6 +6,7 @@ from inspect import isawaitable
import pytest import pytest
from sanic import Sanic
from sanic.exceptions import SanicException from sanic.exceptions import SanicException
from sanic.response import text from sanic.response import text
@ -48,6 +49,7 @@ def test_asyncio_server_no_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server( asyncio_srv_coro = app.create_server(
port=43123,
return_asyncio_server=True, return_asyncio_server=True,
asyncio_server_kwargs=dict(start_serving=False), asyncio_server_kwargs=dict(start_serving=False),
) )
@ -61,6 +63,7 @@ def test_asyncio_server_start_serving(app):
if not uvloop_installed(): if not uvloop_installed():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server( asyncio_srv_coro = app.create_server(
port=43124,
return_asyncio_server=True, return_asyncio_server=True,
asyncio_server_kwargs=dict(start_serving=False), asyncio_server_kwargs=dict(start_serving=False),
) )
@ -199,10 +202,17 @@ def test_handle_request_with_nested_sanic_exception(app, monkeypatch, caplog):
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
port = request.server_port
assert port > 0
assert response.status == 500 assert response.status == 500
assert "Mock SanicException" in response.text assert "Mock SanicException" in response.text
assert ( assert (
"sanic.root", "sanic.root",
logging.ERROR, logging.ERROR,
"Exception occurred while handling uri: 'http://127.0.0.1:42101/'", f"Exception occurred while handling uri: 'http://127.0.0.1:{port}/'",
) in caplog.record_tuples ) in caplog.record_tuples
def test_app_name_required():
with pytest.deprecated_call():
Sanic()

View File

@ -221,7 +221,7 @@ async def test_request_class_custom():
class MyCustomRequest(Request): class MyCustomRequest(Request):
pass pass
app = Sanic(request_class=MyCustomRequest) app = Sanic(name=__name__, request_class=MyCustomRequest)
@app.get("/custom") @app.get("/custom")
def custom_request(request): def custom_request(request):

View File

@ -44,42 +44,42 @@ def test_load_from_object_string_exception(app):
def test_auto_load_env(): def test_auto_load_env():
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic() app = Sanic(name=__name__)
assert app.config.TEST_ANSWER == 42 assert app.config.TEST_ANSWER == 42
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_auto_load_bool_env(): def test_auto_load_bool_env():
environ["SANIC_TEST_ANSWER"] = "True" environ["SANIC_TEST_ANSWER"] = "True"
app = Sanic() app = Sanic(name=__name__)
assert app.config.TEST_ANSWER == True assert app.config.TEST_ANSWER == True
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_dont_load_env(): def test_dont_load_env():
environ["SANIC_TEST_ANSWER"] = "42" environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(load_env=False) app = Sanic(name=__name__, load_env=False)
assert getattr(app.config, "TEST_ANSWER", None) is None assert getattr(app.config, "TEST_ANSWER", None) is None
del environ["SANIC_TEST_ANSWER"] del environ["SANIC_TEST_ANSWER"]
def test_load_env_prefix(): def test_load_env_prefix():
environ["MYAPP_TEST_ANSWER"] = "42" environ["MYAPP_TEST_ANSWER"] = "42"
app = Sanic(load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_ANSWER == 42 assert app.config.TEST_ANSWER == 42
del environ["MYAPP_TEST_ANSWER"] del environ["MYAPP_TEST_ANSWER"]
def test_load_env_prefix_float_values(): def test_load_env_prefix_float_values():
environ["MYAPP_TEST_ROI"] = "2.3" environ["MYAPP_TEST_ROI"] = "2.3"
app = Sanic(load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_ROI == 2.3 assert app.config.TEST_ROI == 2.3
del environ["MYAPP_TEST_ROI"] del environ["MYAPP_TEST_ROI"]
def test_load_env_prefix_string_value(): def test_load_env_prefix_string_value():
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken" environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
app = Sanic(load_env="MYAPP_") app = Sanic(name=__name__, load_env="MYAPP_")
assert app.config.TEST_TOKEN == "somerandomtesttoken" assert app.config.TEST_TOKEN == "somerandomtesttoken"
del environ["MYAPP_TEST_TOKEN"] del environ["MYAPP_TEST_TOKEN"]

View File

@ -30,7 +30,7 @@ def test_deprecated_custom_request():
Sanic(request_class=DeprecCustomRequest) Sanic(request_class=DeprecCustomRequest)
def test_custom_request(): def test_custom_request():
app = Sanic(request_class=CustomRequest) app = Sanic(name=__name__, request_class=CustomRequest)
@app.route("/post", methods=["POST"]) @app.route("/post", methods=["POST"])
async def post_handler(request): async def post_handler(request):

View File

@ -7,12 +7,12 @@ import httpx
from sanic import Sanic, server from sanic import Sanic, server
from sanic.response import text from sanic.response import text
from sanic.testing import HOST, PORT, SanicTestClient from sanic.testing import HOST, SanicTestClient
CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}
old_conn = None old_conn = None
PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port
class ReusableSanicConnectionPool( class ReusableSanicConnectionPool(

View File

@ -1,4 +1,5 @@
import logging import logging
import os
import uuid import uuid
from importlib import reload from importlib import reload
@ -12,6 +13,7 @@ import sanic
from sanic import Sanic from sanic import Sanic
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.response import text from sanic.response import text
from sanic.testing import SanicTestClient
logging_format = """module: %(module)s; \ logging_format = """module: %(module)s; \
@ -126,7 +128,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
def test_logger(caplog): def test_logger(caplog):
rand_string = str(uuid.uuid4()) rand_string = str(uuid.uuid4())
app = Sanic() app = Sanic(name=__name__)
@app.get("/") @app.get("/")
def log_info(request): def log_info(request):
@ -136,15 +138,67 @@ def test_logger(caplog):
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
port = request.server_port
# Note: testing with random port doesn't show the banner because it doesn't
# define host and port. This test supports both modes.
if caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ http://127.0.0.1:{port}",
):
caplog.record_tuples.pop(0)
assert caplog.record_tuples[0] == ( assert caplog.record_tuples[0] == (
"sanic.root", "sanic.root",
logging.INFO, logging.INFO,
"Goin' Fast @ http://127.0.0.1:42101", f"http://127.0.0.1:{port}/",
)
assert caplog.record_tuples[1] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == (
"sanic.root",
logging.INFO,
"Server Stopped",
)
def test_logger_static_and_secure(caplog):
# Same as test_logger, except for more coverage:
# - test_client initialised separately for static port
# - using ssl
rand_string = str(uuid.uuid4())
app = Sanic(name=__name__)
@app.get("/")
def log_info(request):
logger.info(rand_string)
return text("hello")
current_dir = os.path.dirname(os.path.realpath(__file__))
ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert")
ssl_key = os.path.join(current_dir, "certs/selfsigned.key")
ssl_dict = {"cert": ssl_cert, "key": ssl_key}
test_client = SanicTestClient(app, port=42101)
with caplog.at_level(logging.INFO):
request, response = test_client.get(
f"https://127.0.0.1:{test_client.port}/",
server_kwargs=dict(ssl=ssl_dict),
)
port = test_client.port
assert caplog.record_tuples[0] == (
"sanic.root",
logging.INFO,
f"Goin' Fast @ https://127.0.0.1:{port}",
) )
assert caplog.record_tuples[1] == ( assert caplog.record_tuples[1] == (
"sanic.root", "sanic.root",
logging.INFO, logging.INFO,
"http://127.0.0.1:42101/", f"https://127.0.0.1:{port}/",
) )
assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string) assert caplog.record_tuples[2] == ("sanic.root", logging.INFO, rand_string)
assert caplog.record_tuples[-1] == ( assert caplog.record_tuples[-1] == (

View File

@ -49,10 +49,10 @@ def test_logo_false(app, caplog):
loop.run_until_complete(_server.wait_closed()) loop.run_until_complete(_server.wait_closed())
app.stop() app.stop()
banner, port = caplog.record_tuples[ROW][2].rsplit(":", 1)
assert caplog.record_tuples[ROW][1] == logging.INFO assert caplog.record_tuples[ROW][1] == logging.INFO
assert caplog.record_tuples[ROW][ assert banner == "Goin' Fast @ http://127.0.0.1"
2 assert int(port) > 0
] == f"Goin' Fast @ http://127.0.0.1:{PORT}"
def test_logo_true(app, caplog): def test_logo_true(app, caplog):

View File

@ -7,7 +7,7 @@ from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator from sanic.views import stream as stream_decorator
data = "abc" * 10000000 data = "abc" * 1_000_000
def test_request_stream_method_view(app): def test_request_stream_method_view(app):

View File

@ -12,7 +12,7 @@ from sanic import Blueprint, Sanic
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters
from sanic.response import html, json, text from sanic.response import html, json, text
from sanic.testing import ASGI_HOST, HOST, PORT from sanic.testing import ASGI_HOST, HOST, PORT, SanicTestClient
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
@ -1029,7 +1029,7 @@ def test_url_attributes_no_ssl(app, path, query, expected_url):
app.add_route(handler, path) app.add_route(handler, path)
request, response = app.test_client.get(path + f"?{query}") request, response = app.test_client.get(path + f"?{query}")
assert request.url == expected_url.format(HOST, PORT) assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url) parsed = urlparse(request.url)
@ -1086,11 +1086,12 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url):
app.add_route(handler, path) app.add_route(handler, path)
port = app.test_client.port
request, response = app.test_client.get( request, response = app.test_client.get(
f"https://{HOST}:{PORT}" + path + f"?{query}", f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": context}, server_kwargs={"ssl": context},
) )
assert request.url == expected_url.format(HOST, PORT) assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url) parsed = urlparse(request.url)
@ -1125,7 +1126,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
f"https://{HOST}:{PORT}" + path + f"?{query}", f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": ssl_dict}, server_kwargs={"ssl": ssl_dict},
) )
assert request.url == expected_url.format(HOST, PORT) assert request.url == expected_url.format(HOST, request.server_port)
parsed = urlparse(request.url) parsed = urlparse(request.url)
@ -1917,8 +1918,9 @@ def test_request_server_port(app):
def handler(request): def handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("/", headers={"Host": "my-server"}) test_client = SanicTestClient(app)
assert request.server_port == app.test_client.port request, response = test_client.get("/", headers={"Host": "my-server"})
assert request.server_port == test_client.port
def test_request_server_port_in_host_header(app): def test_request_server_port_in_host_header(app):
@ -1939,6 +1941,9 @@ def test_request_server_port_in_host_header(app):
request, response = app.test_client.get( request, response = app.test_client.get(
"/", headers={"Host": "mal_formed:5555"} "/", headers={"Host": "mal_formed:5555"}
) )
if PORT is None:
assert request.server_port != 5555
else:
assert request.server_port == app.test_client.port assert request.server_port == app.test_client.port
@ -1979,7 +1984,7 @@ def test_server_name_and_url_for(app):
request, response = app.test_client.get("/foo") request, response = app.test_client.get("/foo")
assert ( assert (
request.url_for("handler") request.url_for("handler")
== f"http://my-server:{app.test_client.port}/foo" == f"http://my-server:{request.server_port}/foo"
) )
app.config.SERVER_NAME = "https://my-server/path" app.config.SERVER_NAME = "https://my-server/path"
@ -2040,7 +2045,7 @@ async def test_request_form_invalid_content_type_asgi(app):
def test_endpoint_basic(): def test_endpoint_basic():
app = Sanic() app = Sanic(name=__name__)
@app.route("/") @app.route("/")
def my_unique_handler(request): def my_unique_handler(request):
@ -2053,7 +2058,7 @@ def test_endpoint_basic():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_endpoint_basic_asgi(): async def test_endpoint_basic_asgi():
app = Sanic() app = Sanic(name=__name__)
@app.route("/") @app.route("/")
def my_unique_handler(request): def my_unique_handler(request):
@ -2132,5 +2137,5 @@ def test_url_for_without_server_name(app):
request, response = app.test_client.get("/sample") request, response = app.test_client.get("/sample")
assert ( assert (
response.json["url"] response.json["url"]
== f"http://127.0.0.1:{app.test_client.port}/url-for" == f"http://127.0.0.1:{request.server_port}/url-for"
) )

View File

@ -6,6 +6,7 @@ from sanic import Sanic
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.response import json, text from sanic.response import json, text
from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists from sanic.router import ParameterNameConflicts, RouteDoesNotExist, RouteExists
from sanic.testing import SanicTestClient
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
@ -163,35 +164,36 @@ def test_route_optional_slash(app):
def test_route_strict_slashes_set_to_false_and_host_is_a_list(app): def test_route_strict_slashes_set_to_false_and_host_is_a_list(app):
# Part of regression test for issue #1120 # Part of regression test for issue #1120
site1 = f"127.0.0.1:{app.test_client.port}" test_client = SanicTestClient(app, port=42101)
site1 = f"127.0.0.1:{test_client.port}"
# before fix, this raises a RouteExists error # before fix, this raises a RouteExists error
@app.get("/get", host=[site1, "site2.com"], strict_slashes=False) @app.get("/get", host=[site1, "site2.com"], strict_slashes=False)
def get_handler(request): def get_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.get("http://" + site1 + "/get") request, response = test_client.get("http://" + site1 + "/get")
assert response.text == "OK" assert response.text == "OK"
@app.post("/post", host=[site1, "site2.com"], strict_slashes=False) @app.post("/post", host=[site1, "site2.com"], strict_slashes=False)
def post_handler(request): def post_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.post("http://" + site1 + "/post") request, response = test_client.post("http://" + site1 + "/post")
assert response.text == "OK" assert response.text == "OK"
@app.put("/put", host=[site1, "site2.com"], strict_slashes=False) @app.put("/put", host=[site1, "site2.com"], strict_slashes=False)
def put_handler(request): def put_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.put("http://" + site1 + "/put") request, response = test_client.put("http://" + site1 + "/put")
assert response.text == "OK" assert response.text == "OK"
@app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False) @app.delete("/delete", host=[site1, "site2.com"], strict_slashes=False)
def delete_handler(request): def delete_handler(request):
return text("OK") return text("OK")
request, response = app.test_client.delete("http://" + site1 + "/delete") request, response = test_client.delete("http://" + site1 + "/delete")
assert response.text == "OK" assert response.text == "OK"

View File

@ -1,6 +1,9 @@
import asyncio import asyncio
import signal import signal
from contextlib import closing
from socket import socket
import pytest import pytest
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
@ -118,7 +121,12 @@ def test_create_server_trigger_events(app):
app.listener("after_server_stop")(after_stop) app.listener("after_server_stop")(after_stop)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
serv_coro = app.create_server(return_asyncio_server=True)
# Use random port for tests
with closing(socket()) as sock:
sock.bind(("127.0.0.1", 0))
serv_coro = app.create_server(return_asyncio_server=True, sock=sock)
serv_task = asyncio.ensure_future(serv_coro, loop=loop) serv_task = asyncio.ensure_future(serv_coro, loop=loop)
server = loop.run_until_complete(serv_task) server = loop.run_until_complete(serv_task)
server.after_start() server.after_start()

View File

@ -1,8 +1,13 @@
import asyncio import asyncio
import os
import signal
from queue import Queue from queue import Queue
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from sanic.compat import ctrlc_workaround_for_windows
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
from sanic.testing import HOST, PORT from sanic.testing import HOST, PORT
@ -16,13 +21,21 @@ calledq = Queue()
def set_loop(app, loop): def set_loop(app, loop):
loop.add_signal_handler = MagicMock() global mock
mock = MagicMock()
if os.name == "nt":
signal.signal = mock
else:
loop.add_signal_handler = mock
def after(app, loop): def after(app, loop):
calledq.put(loop.add_signal_handler.called) calledq.put(mock.called)
@pytest.mark.skipif(
os.name == "nt", reason="May hang CI on py38/windows"
)
def test_register_system_signals(app): def test_register_system_signals(app):
"""Test if sanic register system signals""" """Test if sanic register system signals"""
@ -38,6 +51,9 @@ def test_register_system_signals(app):
assert calledq.get() is True assert calledq.get() is True
@pytest.mark.skipif(
os.name == "nt", reason="May hang CI on py38/windows"
)
def test_dont_register_system_signals(app): def test_dont_register_system_signals(app):
"""Test if sanic don't register system signals""" """Test if sanic don't register system signals"""
@ -51,3 +67,49 @@ def test_dont_register_system_signals(app):
app.run(HOST, PORT, register_sys_signals=False) app.run(HOST, PORT, register_sys_signals=False)
assert calledq.get() is False assert calledq.get() is False
@pytest.mark.skipif(
os.name == "nt", reason="windows cannot SIGINT processes"
)
def test_windows_workaround():
"""Test Windows workaround (on any other OS)"""
# At least some code coverage, even though this test doesn't work on
# Windows...
class MockApp:
def __init__(self):
self.is_stopping = False
def stop(self):
assert not self.is_stopping
self.is_stopping = True
def add_task(self, func):
loop = asyncio.get_event_loop()
self.stay_active_task = loop.create_task(func(self))
async def atest(stop_first):
app = MockApp()
ctrlc_workaround_for_windows(app)
await asyncio.sleep(0.05)
if stop_first:
app.stop()
await asyncio.sleep(0.2)
assert app.is_stopping == stop_first
# First Ctrl+C: should call app.stop() within 0.1 seconds
os.kill(os.getpid(), signal.SIGINT)
await asyncio.sleep(0.2)
assert app.is_stopping
assert app.stay_active_task.result() == None
# Second Ctrl+C should raise
with pytest.raises(KeyboardInterrupt):
os.kill(os.getpid(), signal.SIGINT)
return "OK"
# Run in our private loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
res = loop.run_until_complete(atest(False))
assert res == "OK"
res = loop.run_until_complete(atest(True))
assert res == "OK"

View File

@ -27,7 +27,8 @@ def test_test_client_port_default(app):
return json(request.transport.get_extra_info("sockname")[1]) return json(request.transport.get_extra_info("sockname")[1])
test_client = SanicTestClient(app) test_client = SanicTestClient(app)
assert test_client.port == PORT assert test_client.port == PORT # Can be None before request
request, response = test_client.get("/get") request, response = test_client.get("/get")
assert response.json == PORT assert test_client.port > 0
assert response.json == test_client.port