87 lines
2.4 KiB
Python
87 lines
2.4 KiB
Python
"""Utilities for determining the auth UI host and base URLs."""
|
|
|
|
import os
|
|
from functools import lru_cache
|
|
from urllib.parse import urlparse, urlsplit
|
|
|
|
from ..globals import passkey as global_passkey
|
|
|
|
# Environment variable naming a dedicated auth host. When it holds a usable
# value the auth UI runs in "root mode" and is served from "/" on that host.
_AUTH_HOST_ENV: str = "PASSKEY_AUTH_HOST"
|
|
|
|
|
|
def _default_origin_scheme() -> str:
    """Return the URL scheme of the passkey origin, defaulting to ``"https"``.

    Used as the last-resort scheme when neither the environment config nor the
    caller specifies one.
    """
    origin_scheme = urlparse(global_passkey.instance.origin).scheme
    # An origin without a scheme (e.g. a bare host) falls back to https.
    return origin_scheme if origin_scheme else "https"
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_config() -> tuple[str | None, str] | None:
    """Parse the ``PASSKEY_AUTH_HOST`` environment variable into ``(scheme, host)``.

    Accepted forms are a bare ``host[:port]`` or a URL such as
    ``https://host[:port]/``. Any path component is discarded.

    The result is computed once and cached. Call ``reload_config()`` to
    re-read the environment.

    Returns:
        ``(scheme, host)``, where *scheme* is None if the value had none.
        Returns None when the variable is unset, blank, or contains no
        usable host.
    """
    raw = os.getenv(_AUTH_HOST_ENV)
    if not raw:
        return None

    candidate = raw.strip()
    if not candidate:
        return None

    # urlparse only recognises a network location after "//"; give bare
    # "host[:port]" values an empty-scheme prefix so the host lands in netloc.
    parsed = urlparse(candidate if "://" in candidate else f"//{candidate}")
    netloc = (parsed.netloc or parsed.path).strip("/")

    # Check emptiness *after* stripping: values like "/" or "https:///" have
    # no host left. Treat them as unset rather than enabling root mode with
    # an empty host, which would produce "scheme:///" URLs.
    if not netloc:
        return None

    return (parsed.scheme or None, netloc)
|
|
|
|
|
|
def configured_auth_host() -> str | None:
    """Return the host (with optional port) configured via the environment.

    Returns None when no auth host is configured.
    """
    config = _load_config()
    if config is None:
        return None
    _, host = config
    return host
|
|
|
|
|
|
def is_root_mode() -> bool:
    """Report whether a dedicated auth host is configured ("root mode")."""
    config = _load_config()
    return config is not None
|
|
|
|
|
|
def ui_base_path() -> str:
    """Return the URL path prefix under which the auth UI is served.

    Returns ``"/"`` in root mode (dedicated auth host configured) and
    ``"/auth/"`` otherwise. The result always starts and ends with ``/``.
    """
    if is_root_mode():
        return "/"
    return "/auth/"
|
|
|
|
|
|
def auth_site_base_url(scheme: str | None = None, host: str | None = None) -> str:
    """Return the absolute base URL of the auth UI, always ending in ``/``.

    The host is resolved in this order:

    1. The host configured through the environment (root mode). Its own
       scheme wins if it carried one; otherwise *scheme*, then the passkey
       origin's scheme (``"https"`` if the origin has none).
    2. The caller-supplied *host*, with surrounding whitespace and slashes
       ignored. The scheme is *scheme*, else the origin's scheme.
    3. The passkey instance origin itself.

    The path component is ``ui_base_path()``: ``/`` in root mode, ``/auth/``
    otherwise.

    Args:
        scheme: Fallback URL scheme such as ``"http"`` or ``"https"``.
        host: Fallback ``host[:port]``, used only when no host is configured
            in the environment.

    Returns:
        e.g. ``"https://auth.example.com/"`` or ``"https://example.com/auth/"``.
    """
    cfg = _load_config()
    if cfg:
        cfg_scheme, netloc = cfg
        scheme_to_use = cfg_scheme or scheme or _default_origin_scheme()
    else:
        # A host that is blank or only slashes (e.g. "/") would produce a
        # malformed "scheme:///..." URL, so treat it as if no host was given.
        netloc = host.strip().strip("/") if host else ""
        if not netloc:
            origin = global_passkey.instance.origin.rstrip("/")
            return f"{origin}{ui_base_path()}"
        scheme_to_use = scheme or _default_origin_scheme()

    base = f"{scheme_to_use}://{netloc}".rstrip("/")
    path = ui_base_path().lstrip("/")
    return f"{base}/{path}" if path else f"{base}/"
|
|
|
|
|
|
def reset_link_url(token: str, scheme: str | None = None,
                   host: str | None = None) -> str:
    """Build the absolute reset-link URL for *token*.

    The token is appended verbatim (not URL-encoded) to
    ``auth_site_base_url(scheme, host)``, which ends in ``/``.
    """
    return f"{auth_site_base_url(scheme, host)}{token}"
|
|
|
|
|
|
def reload_config() -> None:
    """Forget the cached auth-host configuration.

    The next call to ``_load_config()`` (and everything built on it) re-reads
    the environment variable.
    """
    _load_config.cache_clear()
|
|
|
|
|
|
def normalize_host(raw_host: str | None) -> str | None:
|
|
"""Normalize a Host header or hostname by stripping port and lowercasing."""
|
|
if not raw_host:
|
|
return None
|
|
candidate = raw_host.strip()
|
|
if not candidate:
|
|
return None
|
|
# Ensure urlsplit can parse bare hosts (prepend //)
|
|
parsed = urlsplit(candidate if "//" in candidate else f"//{candidate}")
|
|
host = parsed.hostname or parsed.path or ""
|
|
host = host.strip("[]") # Remove IPv6 brackets if present
|
|
return host.lower() if host else None
|