Globals restructured to their own module. Origin and RP definition.

This commit is contained in:
Leo Vasanko 2025-08-06 13:23:35 -06:00
parent 5a129220aa
commit dcca3e3fbd
11 changed files with 120 additions and 67 deletions

View File

@ -11,7 +11,8 @@ independent of any web framework:
from datetime import datetime, timedelta
from uuid import UUID
from .db import Session, db
from .db import Session
from .globals import db
from .util.tokens import create_token, reset_key, session_key
EXPIRES = timedelta(hours=24)

View File

@ -11,7 +11,8 @@ from datetime import datetime
import uuid7
from .authsession import expires
from .db import Org, Permission, User, db
from .db import Org, Permission, User
from .globals import db
from .util import passphrase, tokens

View File

@ -232,28 +232,6 @@ class DatabaseInterface(ABC):
"""Create a new user and their first credential in a transaction."""
class DatabaseManager:
"""Manager for the global database instance."""
def __init__(self):
self._instance: DatabaseInterface | None = None
@property
def instance(self) -> DatabaseInterface:
if self._instance is None:
raise RuntimeError(
"Database not initialized. Call e.g. db.sql.init() first."
)
return self._instance
@instance.setter
def instance(self, instance: DatabaseInterface) -> None:
self._instance = instance
db = DatabaseManager()
__all__ = [
"User",
"Credential",
@ -261,5 +239,4 @@ __all__ = [
"Org",
"Permission",
"DatabaseInterface",
"db",
]

View File

@ -23,7 +23,8 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from . import Credential, DatabaseInterface, Org, Permission, Session, User, db
from ..globals import db
from . import Credential, DatabaseInterface, Org, Permission, Session, User
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"

View File

@ -1,9 +1,13 @@
import argparse
import asyncio
import logging
import uvicorn
def main():
# Configure logging to remove the "ERROR:root:" prefix
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
parser = argparse.ArgumentParser(
description="Run the passkey authentication server"
)
@ -16,9 +20,25 @@ def main():
parser.add_argument(
"--dev", action="store_true", help="Enable development mode with auto-reload"
)
parser.add_argument(
"--rp-id", default="localhost", help="Relying Party ID (default: localhost)"
)
parser.add_argument("--rp-name", help="Relying Party name (default: same as rp-id)")
parser.add_argument("--origin", help="Origin URL (default: https://<rp-id>)")
args = parser.parse_args()
# Initialize the application
try:
from .. import globals
asyncio.run(
globals.init(rp_id=args.rp_id, rp_name=args.rp_name, origin=args.origin)
)
except ValueError as e:
logging.error(f"⚠️ {e}")
return
uvicorn.run(
"passkey.fastapi:app",
host=args.host,

View File

@ -17,7 +17,7 @@ from passkey.util import passphrase
from .. import aaguid
from ..authsession import delete_credential, get_reset, get_session
from ..db import db
from ..globals import db
from ..util.tokens import session_key
from . import session

View File

@ -1,6 +1,5 @@
import contextlib
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import Cookie, FastAPI, Request, Response
@ -8,7 +7,6 @@ from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from ..authsession import get_session
from ..db import db
from . import ws
from .api import register_api_routes
from .reset import register_reset_routes
@ -16,25 +14,7 @@ from .reset import register_reset_routes
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@asynccontextmanager
async def lifespan(app: FastAPI):
# Test if we have a database already initialized, otherwise use SQL
try:
db.instance
except RuntimeError:
from ..db import sql
await sql.init()
# Bootstrap system if needed
from ..bootstrap import bootstrap_if_needed
await bootstrap_if_needed()
yield
app = FastAPI(lifespan=lifespan)
app = FastAPI()
# Global exception handlers

View File

@ -4,7 +4,7 @@ from fastapi import Cookie, HTTPException, Request, Response
from fastapi.responses import RedirectResponse
from ..authsession import expires, get_session
from ..db import db
from ..globals import db
from ..util import passphrase, tokens
from . import session

View File

@ -8,8 +8,8 @@ from fastapi import Cookie, FastAPI, Query, WebSocket, WebSocketDisconnect
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from ..authsession import EXPIRES, create_session, get_reset, get_session
from ..db import User, db
from ..sansio import Passkey
from ..db import User
from ..globals import db, passkey
from ..util import passphrase
from ..util.tokens import create_token, session_key
from .session import infodict
@ -36,12 +36,6 @@ def websocket_error_handler(func):
# Create a FastAPI subapp for WebSocket endpoints
app = FastAPI()
# Initialize the passkey instance
passkey = Passkey(
rp_id="localhost",
rp_name="Passkey Auth",
)
async def register_chat(
ws: WebSocket,
@ -51,7 +45,7 @@ async def register_chat(
origin: str | None = None,
):
"""Generate registration options and send them to the client."""
options, challenge = passkey.reg_generate_options(
options, challenge = passkey.instance.reg_generate_options(
user_id=user_uuid,
user_name=user_name,
credential_ids=credential_ids,
@ -59,7 +53,7 @@ async def register_chat(
)
await ws.send_json(options)
response = await ws.receive_json()
return passkey.reg_verify(response, challenge, user_uuid, origin=origin)
return passkey.instance.reg_verify(response, challenge, user_uuid, origin=origin)
@app.websocket("/register")
@ -139,14 +133,14 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
@websocket_error_handler
async def websocket_authenticate(ws: WebSocket):
origin = ws.headers["origin"]
options, challenge = passkey.auth_generate_options()
options, challenge = passkey.instance.auth_generate_options()
await ws.send_json(options)
# Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json())
credential = passkey.instance.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID
stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
passkey.instance.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp
await db.instance.login(stored_cred.user_uuid, stored_cred)

56
passkey/globals.py Normal file
View File

@ -0,0 +1,56 @@
from typing import Generic, TypeVar
from .db import DatabaseInterface
from .sansio import Passkey
T = TypeVar("T")
class Manager(Generic[T]):
"""Generic manager for global instances."""
def __init__(self, name: str):
self._instance: T | None = None
self._name = name
@property
def instance(self) -> T:
if self._instance is None:
raise RuntimeError(
f"{self._name} not initialized. Call globals.init() first."
)
return self._instance
@instance.setter
def instance(self, instance: T) -> None:
self._instance = instance
async def init(
rp_id: str = "localhost", rp_name: str | None = None, origin: str | None = None
) -> None:
"""Initialize the global database, passkey instance, and bootstrap the system if needed."""
# Initialize passkey instance with provided parameters
passkey.instance = Passkey(
rp_id=rp_id,
rp_name=rp_name or rp_id,
origin=origin,
)
# Test if we have a database already initialized, otherwise use SQL
try:
db.instance
except RuntimeError:
from .db import sql
await sql.init()
# Bootstrap system if needed
from .bootstrap import bootstrap_if_needed
await bootstrap_if_needed()
# Global instances
passkey = Manager[Passkey]("Passkey")
db = Manager[DatabaseInterface]("Database")

View File

@ -9,6 +9,7 @@ This module provides a unified interface for WebAuthn operations including:
import json
from datetime import datetime
from urllib.parse import urlparse
from uuid import UUID
import uuid7
@ -45,7 +46,7 @@ class Passkey:
def __init__(
self,
rp_id: str,
rp_name: str,
rp_name: str | None = None,
origin: str | None = None,
supported_pub_key_algs: list[COSEAlgorithmIdentifier] | None = None,
):
@ -54,19 +55,41 @@ class Passkey:
Args:
rp_id: Your security domain (e.g. "example.com")
rp_name: The relying party name (e.g., "My Application" - visible to users)
origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included.
rp_name: The relying party display name (e.g. "Example App"). May be shown in authenticators.
origin: The origin URL of the application (e.g. "https://app.example.com").
If no scheme is provided, "https://" will be prepended.
Must be a subdomain or same as rp_id, with port and scheme but no path included.
supported_pub_key_algs: List of supported COSE algorithms (default is EDDSA, ECDSA_SHA_256, RSASSA_PKCS1_v1_5_SHA_256).
Raises:
ValueError: If the origin domain doesn't match or isn't a subdomain of rp_id.
"""
self.rp_id = rp_id
self.rp_name = rp_name
self.origin = origin or f"https://{rp_id}"
self.rp_name = rp_name or rp_id
self.origin = self._normalize_and_validate_origin(origin, rp_id)
self.supported_pub_key_algs = supported_pub_key_algs or [
COSEAlgorithmIdentifier.EDDSA,
COSEAlgorithmIdentifier.ECDSA_SHA_256,
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
]
def _normalize_and_validate_origin(self, origin: str | None, rp_id: str) -> str:
if origin is None:
origin = f"https://{rp_id}"
elif "://" not in origin:
origin = f"https://{origin}"
hostname = urlparse(origin).hostname
if not hostname:
raise ValueError(f"Invalid origin URL: no hostname found in '{origin}'")
if hostname == rp_id or hostname.endswith(f".{rp_id}"):
return origin
raise ValueError(
f"Origin domain '{hostname}' must be the same as or a subdomain of rp_id '{rp_id}'"
)
### Registration Methods ###
def reg_generate_options(