diff --git a/sanic/__main__.py b/sanic/__main__.py index c9fa2e52..6619705c 100644 --- a/sanic/__main__.py +++ b/sanic/__main__.py @@ -1,28 +1,83 @@ import os import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, RawDescriptionHelpFormatter from importlib import import_module from typing import Any, Dict, Optional +from sanic import __version__ from sanic.app import Sanic +from sanic.config import BASE_LOGO from sanic.log import logger +class SanicArgumentParser(ArgumentParser): + def add_bool_arguments(self, *args, **kwargs): + group = self.add_mutually_exclusive_group() + group.add_argument(*args, action="store_true", **kwargs) + kwargs["help"] = "no " + kwargs["help"] + group.add_argument( + "--no-" + args[0][2:], *args[1:], action="store_false", **kwargs + ) + + def main(): - parser = ArgumentParser(prog="sanic") - parser.add_argument("--host", dest="host", type=str, default="127.0.0.1") - parser.add_argument("--port", dest="port", type=int, default=8000) - parser.add_argument("--unix", dest="unix", type=str, default="") + parser = SanicArgumentParser( + prog="sanic", + description=BASE_LOGO, + formatter_class=RawDescriptionHelpFormatter, + ) + parser.add_argument( + "-H", + "--host", + dest="host", + type=str, + default="127.0.0.1", + help="host address [default 127.0.0.1]", + ) + parser.add_argument( + "-p", + "--port", + dest="port", + type=int, + default=8000, + help="port to serve on [default 8000]", + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + type=str, + default="", + help="location of unix socket", + ) parser.add_argument( "--cert", dest="cert", type=str, help="location of certificate for SSL" ) parser.add_argument( "--key", dest="key", type=str, help="location of keyfile for SSL." ) - parser.add_argument("--workers", dest="workers", type=int, default=1) + parser.add_argument( + "-w", + "--workers", + dest="workers", + type=int, + default=1, + help="number of worker processes [default 1]", + ) parser.add_argument("--debug", dest="debug", action="store_true") - parser.add_argument("module") + parser.add_bool_arguments( + "--access-logs", dest="access_log", help="display access logs" + ) + parser.add_argument( + "-v", + "--version", + action="version", + version=f"Sanic {__version__}", + ) + parser.add_argument( + "module", help="path to your Sanic app. Example: path.to.server:app" + ) args = parser.parse_args() try: @@ -30,9 +85,12 @@ def main(): if module_path not in sys.path: sys.path.append(module_path) - module_parts = args.module.split(".") - module_name = ".".join(module_parts[:-1]) - app_name = module_parts[-1] + if ":" in args.module: + module_name, app_name = args.module.rsplit(":", 1) + else: + module_parts = args.module.split(".") + module_name = ".".join(module_parts[:-1]) + app_name = module_parts[-1] module = import_module(module_name) app = getattr(module, app_name, None) @@ -57,6 +115,7 @@ def main(): unix=args.unix, workers=args.workers, debug=args.debug, + access_log=args.access_log, ssl=ssl, ) except ImportError as e: diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index 6e9fb20f..60bc8221 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -735,6 +735,7 @@ def test_static_blueprint_name(app: Sanic, static_file_directory, file_name): _, response = app.test_client.get("/static/test.file/") assert response.status == 200 + @pytest.mark.parametrize("file_name", ["test.file"]) def test_static_blueprintp_mw(app: Sanic, static_file_directory, file_name): current_file = inspect.getfile(inspect.currentframe()) @@ -745,7 +746,7 @@ def test_static_blueprintp_mw(app: Sanic, static_file_directory, file_name): bp = Blueprint(name="test_mw", url_prefix="") - @bp.middleware('request') + @bp.middleware("request") def bp_mw1(request): nonlocal triggered triggered = True @@ -754,7 +755,7 @@ def test_static_blueprintp_mw(app: Sanic, static_file_directory, file_name): "/test.file", get_file_path(static_file_directory, file_name), strict_slashes=True, - name="static" + name="static", ) app.blueprint(bp) diff --git a/tests/test_load_module_from_file_location.py b/tests/test_load_module_from_file_location.py index 5dc42d5c..c47913dd 100644 --- a/tests/test_load_module_from_file_location.py +++ b/tests/test_load_module_from_file_location.py @@ -20,7 +20,9 @@ def test_load_module_from_file_location(loaded_module_from_file_location): @pytest.mark.dependency(depends=["test_load_module_from_file_location"]) -def test_loaded_module_from_file_location_name(loaded_module_from_file_location,): +def test_loaded_module_from_file_location_name( + loaded_module_from_file_location, +): name = loaded_module_from_file_location.__name__ if "C:\\" in name: name = name.split("\\")[-1]