286 lines
9.5 KiB
Python
286 lines
9.5 KiB
Python
import argparse
|
|
import asyncio
|
|
import ipaddress
|
|
import logging
|
|
import os
|
|
from urllib.parse import urlparse
|
|
|
|
import uvicorn
|
|
|
|
from passkey.util import frontend
|
|
|
|
# Endpoint defaults used when the CLI "hostport" argument omits a host or port.
DEFAULT_HOST = "localhost"
# Default ports for the "serve" (production) and "dev" (auto-reload) subcommands.
DEFAULT_SERVE_PORT, DEFAULT_DEV_PORT = 4401, 4402
|
|
|
|
|
|
def is_subdomain(sub: str, domain: str) -> bool:
    """Return True if *sub* is *domain* itself or a host nested beneath it.

    Matching is case-insensitive and compares whole dot-separated labels, so
    ``auth.example.com`` matches ``example.com`` but ``badexample.com`` does not.
    """
    wanted = domain.lower().split(".")
    labels = sub.lower().split(".")
    depth = len(wanted)
    # A shorter host can never end with all of the domain's labels.
    if depth > len(labels):
        return False
    return labels[len(labels) - depth :] == wanted
|
|
|
|
|
|
def validate_auth_host(auth_host: str, rp_id: str) -> None:
    """Ensure *auth_host* names a host equal to or beneath *rp_id*.

    *auth_host* may be a bare host (``auth.example.com``), carry a port, or be a
    full URL (``https://auth.example.com``); only its hostname is compared.

    Raises:
        SystemExit: if no hostname can be extracted (including unparseable
            input such as an unbalanced IPv6 bracket), or if the hostname is
            not a subdomain of *rp_id*.
    """
    # Prefixing "//" makes urlparse treat a scheme-less value as a netloc.
    try:
        parsed = urlparse(auth_host if "://" in auth_host else f"//{auth_host}")
    except ValueError:
        # urlparse raises "Invalid IPv6 URL" for e.g. "[::1"; report it like any
        # other bad value instead of crashing with a traceback.
        raise SystemExit(f"Invalid auth-host: '{auth_host}'") from None
    host = parsed.hostname or parsed.path
    if not host:
        raise SystemExit(f"Invalid auth-host: '{auth_host}'")
    if not is_subdomain(host, rp_id):
        raise SystemExit(
            f"auth-host '{auth_host}' is not a subdomain of rp-id '{rp_id}'"
        )
|
|
|
|
|
|
def parse_endpoint(
|
|
value: str | None, default_port: int
|
|
) -> tuple[str | None, int | None, str | None, bool]:
|
|
"""Parse an endpoint using stdlib (urllib.parse, ipaddress).
|
|
|
|
Returns (host, port, uds_path). If uds_path is not None, host/port are None.
|
|
|
|
Supported forms:
|
|
- host[:port]
|
|
- :port (uses default host)
|
|
- [ipv6][:port] (bracketed for port usage)
|
|
- ipv6 (unbracketed, no port allowed -> default port)
|
|
- unix:/path/to/socket.sock
|
|
- None -> defaults (localhost:4401)
|
|
|
|
Notes:
|
|
- For IPv6 with an explicit port you MUST use brackets (e.g. [::1]:8080)
|
|
- Unbracketed IPv6 like ::1 implies the default port.
|
|
"""
|
|
if not value:
|
|
return DEFAULT_HOST, default_port, None, False
|
|
|
|
# Port only (numeric) -> localhost:port
|
|
if value.isdigit():
|
|
try:
|
|
port_only = int(value)
|
|
except ValueError: # pragma: no cover (isdigit guards)
|
|
raise SystemExit(f"Invalid port '{value}'")
|
|
return DEFAULT_HOST, port_only, None, False
|
|
|
|
# Leading colon :port -> bind all interfaces (0.0.0.0 + ::)
|
|
if value.startswith(":") and value != ":":
|
|
port_part = value[1:]
|
|
if not port_part.isdigit():
|
|
raise SystemExit(f"Invalid port in '{value}'")
|
|
return None, int(port_part), None, True
|
|
|
|
# UNIX domain socket
|
|
if value.startswith("unix:"):
|
|
uds_path = value[5:] or None
|
|
if uds_path is None:
|
|
raise SystemExit("unix: path must not be empty")
|
|
return None, None, uds_path, False
|
|
|
|
# Unbracketed IPv6 (cannot safely contain a port) -> detect by multiple colons
|
|
if value.count(":") > 1 and not value.startswith("["):
|
|
try:
|
|
ipaddress.IPv6Address(value)
|
|
except ValueError as e: # pragma: no cover
|
|
raise SystemExit(f"Invalid IPv6 address '{value}': {e}")
|
|
return value, default_port, None, False
|
|
|
|
# Use urllib.parse for everything else (host[:port], :port, [ipv6][:port])
|
|
parsed = urlparse(f"//{value}") # // prefix lets urlparse treat it as netloc
|
|
host = parsed.hostname
|
|
port = parsed.port
|
|
|
|
# Host may be None if empty (e.g. ':5500')
|
|
if not host:
|
|
host = DEFAULT_HOST
|
|
if port is None:
|
|
port = default_port
|
|
|
|
# Validate IP literals (optional; hostname passes through)
|
|
try:
|
|
# Strip brackets if somehow present (urlparse removes them already)
|
|
ipaddress.ip_address(host)
|
|
except ValueError:
|
|
# Not an IP address -> treat as hostname; no action
|
|
pass
|
|
|
|
return host, port, None, False
|
|
|
|
|
|
def add_common_options(p: argparse.ArgumentParser) -> None:
    """Register the relying-party options shared by every subcommand on *p*."""
    auth_host_help = (
        "Dedicated host (optionally with scheme/port) to serve the auth UI at the root,"
        " e.g. auth.example.com or https://auth.example.com"
    )
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        (
            "--rp-id",
            {"default": "localhost", "help": "Relying Party ID (default: localhost)"},
        ),
        ("--rp-name", {"help": "Relying Party name (default: same as rp-id)"}),
        ("--origin", {"help": "Origin URL (default: https://<rp-id>)"}),
        ("--auth-host", {"help": auth_host_help}),
    )
    for flag, options in option_specs:
        p.add_argument(flag, **options)
|
|
|
|
|
|
def main() -> None:
    """Entry point for the ``passkey-auth`` CLI.

    Parses the ``serve``, ``dev`` and ``reset`` subcommands, exports the
    relying-party configuration to ``PASSKEY_*`` environment variables, runs
    one-time bootstrap initialization via ``passkey.globals.init``, and then
    either runs the reset command (exiting with its code) or starts uvicorn.

    Raises:
        SystemExit: from argument parsing/validation, or with the reset
            command's exit code.
    """
    # Configure logging to remove the "ERROR:root:" prefix
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    parser = argparse.ArgumentParser(
        prog="passkey-auth", description="Passkey authentication server"
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # serve subcommand
    serve = sub.add_parser(
        "serve", help="Run the server (production style, no auto-reload)"
    )
    serve.add_argument(
        "hostport",
        nargs="?",
        help=(
            "Endpoint (default: localhost:4401). Forms: host[:port] | :port | "
            "[ipv6][:port] | ipv6 | unix:/path.sock"
        ),
    )
    add_common_options(serve)

    # dev subcommand
    dev = sub.add_parser("dev", help="Run the server in development (auto-reload)")
    dev.add_argument(
        "hostport",
        nargs="?",
        help=(
            "Endpoint (default: localhost:4402). Forms: host[:port] | :port | "
            "[ipv6][:port] | ipv6 | unix:/path.sock"
        ),
    )
    add_common_options(dev)

    # reset subcommand
    reset = sub.add_parser(
        "reset",
        help=(
            "Create a credential reset link for a user. Provide part of the display name or UUID. "
            "If omitted, targets the master admin (first Administration role user in an auth:admin org)."
        ),
    )
    reset.add_argument(
        "query",
        nargs="?",
        help="User UUID (full) or case-insensitive substring of display name. If omitted, master admin is used.",
    )
    add_common_options(reset)

    args = parser.parse_args()

    if args.command in {"serve", "dev"}:
        default_port = DEFAULT_DEV_PORT if args.command == "dev" else DEFAULT_SERVE_PORT
        # uds is set for "unix:" endpoints; all_ifaces is True for ":port".
        host, port, uds, all_ifaces = parse_endpoint(args.hostport, default_port)
        devmode = args.command == "dev"
    else:
        # reset starts no server, so there is no endpoint to parse.
        host = port = uds = all_ifaces = None  # type: ignore
        devmode = False

    # Determine origin (dev mode default override)
    origin = args.origin
    # NOTE(review): add_common_options gives --rp-id a default of "localhost",
    # so `not args.rp_id` is False unless an empty string is passed explicitly;
    # this override therefore effectively never runs. Confirm the intended
    # condition (e.g. rp-id left at its default).
    if devmode and not args.origin and not args.rp_id:
        # Dev mode: Vite runs on another port, override:
        origin = "http://localhost:4403"

    # Export configuration via environment for lifespan initialization in each process
    # NOTE(review): setdefault keeps a pre-existing PASSKEY_RP_ID even when
    # --rp-id is given, while _globals.init below receives args.rp_id, so the
    # bootstrap and the environment-driven initialization can disagree.
    os.environ.setdefault("PASSKEY_RP_ID", args.rp_id)
    if args.rp_name:
        os.environ["PASSKEY_RP_NAME"] = args.rp_name
    if origin:
        os.environ["PASSKEY_ORIGIN"] = origin
    if getattr(args, "auth_host", None):
        os.environ["PASSKEY_AUTH_HOST"] = args.auth_host
    else:
        # Preserve pre-set env variable if CLI option omitted
        args.auth_host = os.environ.get("PASSKEY_AUTH_HOST")

    if args.auth_host:
        validate_auth_host(args.auth_host, args.rp_id)
        from passkey.util import hostutil as _hostutil  # local import

        # Presumably re-reads PASSKEY_AUTH_HOST exported above -- see
        # passkey.util.hostutil to confirm.
        _hostutil.reload_config()

    # One-time initialization + bootstrap before starting any server processes.
    # Lifespan in worker processes will call globals.init with bootstrap disabled.
    from passkey import globals as _globals  # local import

    asyncio.run(
        _globals.init(
            rp_id=args.rp_id,
            rp_name=args.rp_name,
            origin=origin,
            # Empty environment values are normalized to None.
            default_admin=os.getenv("PASSKEY_DEFAULT_ADMIN") or None,
            default_org=os.getenv("PASSKEY_DEFAULT_ORG") or None,
            bootstrap=True,
        )
    )

    # Handle the reset command (no server start); exits with its return code.
    if args.command == "reset":
        from passkey.fastapi import reset as reset_cmd  # local import

        exit_code = reset_cmd.run(getattr(args, "query", None))
        raise SystemExit(exit_code)

    if args.command in {"serve", "dev"}:
        run_kwargs: dict = {
            "reload": devmode,
            "log_level": "info",
        }
        if uds:
            run_kwargs["uds"] = uds
        else:
            # For ":port" (all_ifaces) the bind address is chosen further below.
            if not all_ifaces:
                run_kwargs["host"] = host
                run_kwargs["port"] = port

        if devmode:
            # Start the frontend dev server only once: the marker variable is
            # set before run_dev(), and child processes inherit os.environ.
            if os.environ.get("PASSKEY_BUN_PARENT") != "1":
                os.environ["PASSKEY_BUN_PARENT"] = "1"
                frontend.run_dev()

        if all_ifaces and not uds:
            if devmode:
                # Dev keeps the import-string form of uvicorn.run (uvicorn needs
                # it for reload) and binds "::" only; whether that also accepts
                # IPv4 depends on the OS dual-stack setting.
                run_kwargs["host"] = "::"
                run_kwargs["port"] = port
                uvicorn.run("passkey.fastapi:app", **run_kwargs)
            else:
                # Production: run one uvicorn Server per address family
                # (0.0.0.0 and ::) concurrently in a single event loop.
                from uvicorn import Config, Server  # noqa: E402 local import

                from passkey.fastapi import (
                    app as fastapi_app,  # noqa: E402 local import
                )

                async def serve_both():
                    servers = []
                    # parse_endpoint always returns a port for ":port".
                    assert port is not None
                    for h in ("0.0.0.0", "::"):
                        try:
                            cfg = Config(
                                app=fastapi_app,
                                host=h,
                                port=port,
                                log_level="info",
                            )
                            servers.append(Server(cfg))
                        except Exception as e:  # pragma: no cover
                            # Best effort: skip an address family that cannot
                            # be configured and serve the other one.
                            logging.warning(f"Failed to configure server for {h}: {e}")
                    tasks = [asyncio.create_task(s.serve()) for s in servers]
                    await asyncio.gather(*tasks)

                asyncio.run(serve_both())
        else:
            uvicorn.run("passkey.fastapi:app", **run_kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|