diff --git a/passkey/authsession.py b/passkey/authsession.py index f848a34..27abac2 100644 --- a/passkey/authsession.py +++ b/passkey/authsession.py @@ -11,7 +11,8 @@ independent of any web framework: from datetime import datetime, timedelta from uuid import UUID -from .db import Session, db +from .db import Session +from .globals import db from .util.tokens import create_token, reset_key, session_key EXPIRES = timedelta(hours=24) diff --git a/passkey/bootstrap.py b/passkey/bootstrap.py index 41cdd51..9bfa654 100644 --- a/passkey/bootstrap.py +++ b/passkey/bootstrap.py @@ -11,7 +11,8 @@ from datetime import datetime import uuid7 from .authsession import expires -from .db import Org, Permission, User, db +from .db import Org, Permission, User +from .globals import db from .util import passphrase, tokens diff --git a/passkey/db/__init__.py b/passkey/db/__init__.py index 73d8647..72a106a 100644 --- a/passkey/db/__init__.py +++ b/passkey/db/__init__.py @@ -232,28 +232,6 @@ class DatabaseInterface(ABC): """Create a new user and their first credential in a transaction.""" -class DatabaseManager: - """Manager for the global database instance.""" - - def __init__(self): - self._instance: DatabaseInterface | None = None - - @property - def instance(self) -> DatabaseInterface: - if self._instance is None: - raise RuntimeError( - "Database not initialized. Call e.g. db.sql.init() first." - ) - return self._instance - - @instance.setter - def instance(self, instance: DatabaseInterface) -> None: - self._instance = instance - - -db = DatabaseManager() - - __all__ = [ "User", "Credential", @@ -261,5 +239,4 @@ __all__ = [ "Org", "Permission", "DatabaseInterface", - "db", ] diff --git a/passkey/db/sql.py b/passkey/db/sql.py index 2ccfcbc..695efd2 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -23,7 +23,8 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from . import Credential, DatabaseInterface, Org, Permission, Session, User, db +from ..globals import db +from . import Credential, DatabaseInterface, Org, Permission, Session, User DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite" diff --git a/passkey/fastapi/__main__.py b/passkey/fastapi/__main__.py index 1fb7aa7..5bdab3a 100644 --- a/passkey/fastapi/__main__.py +++ b/passkey/fastapi/__main__.py @@ -1,9 +1,13 @@ import argparse +import asyncio +import logging import uvicorn def main(): + # Configure logging to remove the "ERROR:root:" prefix + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) parser = argparse.ArgumentParser( description="Run the passkey authentication server" ) @@ -16,9 +20,25 @@ def main(): parser.add_argument( "--dev", action="store_true", help="Enable development mode with auto-reload" ) + parser.add_argument( + "--rp-id", default="localhost", help="Relying Party ID (default: localhost)" + ) + parser.add_argument("--rp-name", help="Relying Party name (default: same as rp-id)") + parser.add_argument("--origin", help="Origin URL (default: https://)") args = parser.parse_args() + # Initialize the application + try: + from .. import globals + + asyncio.run( + globals.init(rp_id=args.rp_id, rp_name=args.rp_name, origin=args.origin) + ) + except ValueError as e: + logging.error(f"⚠️ {e}") + return + uvicorn.run( "passkey.fastapi:app", host=args.host, diff --git a/passkey/fastapi/api.py b/passkey/fastapi/api.py index 0e262fb..7164a41 100644 --- a/passkey/fastapi/api.py +++ b/passkey/fastapi/api.py @@ -17,7 +17,7 @@ from passkey.util import passphrase from .. import aaguid from ..authsession import delete_credential, get_reset, get_session -from ..db import db +from ..globals import db from ..util.tokens import session_key from . import session diff --git a/passkey/fastapi/mainapp.py b/passkey/fastapi/mainapp.py index cb96db9..9c5aece 100644 --- a/passkey/fastapi/mainapp.py +++ b/passkey/fastapi/mainapp.py @@ -1,6 +1,5 @@ import contextlib import logging -from contextlib import asynccontextmanager from pathlib import Path from fastapi import Cookie, FastAPI, Request, Response @@ -8,7 +7,6 @@ from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from ..authsession import get_session -from ..db import db from . import ws from .api import register_api_routes from .reset import register_reset_routes @@ -16,25 +14,7 @@ from .reset import register_reset_routes STATIC_DIR = Path(__file__).parent.parent / "frontend-build" -@asynccontextmanager -async def lifespan(app: FastAPI): - # Test if we have a database already initialized, otherwise use SQL - try: - db.instance - except RuntimeError: - from ..db import sql - - await sql.init() - - # Bootstrap system if needed - from ..bootstrap import bootstrap_if_needed - - await bootstrap_if_needed() - - yield - - -app = FastAPI(lifespan=lifespan) +app = FastAPI() # Global exception handlers diff --git a/passkey/fastapi/reset.py b/passkey/fastapi/reset.py index 60ff5b3..01e0dbd 100644 --- a/passkey/fastapi/reset.py +++ b/passkey/fastapi/reset.py @@ -4,7 +4,7 @@ from fastapi import Cookie, HTTPException, Request, Response from fastapi.responses import RedirectResponse from ..authsession import expires, get_session -from ..db import db +from ..globals import db from ..util import passphrase, tokens from . import session diff --git a/passkey/fastapi/ws.py b/passkey/fastapi/ws.py index 4ec49b7..ba97e81 100644 --- a/passkey/fastapi/ws.py +++ b/passkey/fastapi/ws.py @@ -8,8 +8,8 @@ from fastapi import Cookie, FastAPI, Query, WebSocket, WebSocketDisconnect from webauthn.helpers.exceptions import InvalidAuthenticationResponse from ..authsession import EXPIRES, create_session, get_reset, get_session -from ..db import User, db -from ..sansio import Passkey +from ..db import User +from ..globals import db, passkey from ..util import passphrase from ..util.tokens import create_token, session_key from .session import infodict @@ -36,12 +36,6 @@ def websocket_error_handler(func): # Create a FastAPI subapp for WebSocket endpoints app = FastAPI() -# Initialize the passkey instance -passkey = Passkey( - rp_id="localhost", - rp_name="Passkey Auth", -) - async def register_chat( ws: WebSocket, @@ -51,7 +45,7 @@ async def register_chat( origin: str | None = None, ): """Generate registration options and send them to the client.""" - options, challenge = passkey.reg_generate_options( + options, challenge = passkey.instance.reg_generate_options( user_id=user_uuid, user_name=user_name, credential_ids=credential_ids, @@ -59,7 +53,7 @@ async def register_chat( ) await ws.send_json(options) response = await ws.receive_json() - return passkey.reg_verify(response, challenge, user_uuid, origin=origin) + return passkey.instance.reg_verify(response, challenge, user_uuid, origin=origin) @app.websocket("/register") @@ -139,14 +133,14 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)): @websocket_error_handler async def websocket_authenticate(ws: WebSocket): origin = ws.headers["origin"] - options, challenge = passkey.auth_generate_options() + options, challenge = passkey.instance.auth_generate_options() await ws.send_json(options) # Wait for the client to use his authenticator to authenticate - credential = passkey.auth_parse(await ws.receive_json()) + credential = passkey.instance.auth_parse(await ws.receive_json()) # Fetch from the database by credential ID stored_cred = await db.instance.get_credential_by_id(credential.raw_id) # Verify the credential matches the stored data - passkey.auth_verify(credential, challenge, stored_cred, origin=origin) + passkey.instance.auth_verify(credential, challenge, stored_cred, origin=origin) # Update both credential and user's last_seen timestamp await db.instance.login(stored_cred.user_uuid, stored_cred) diff --git a/passkey/globals.py b/passkey/globals.py new file mode 100644 index 0000000..c8c381d --- /dev/null +++ b/passkey/globals.py @@ -0,0 +1,56 @@ +from typing import Generic, TypeVar + +from .db import DatabaseInterface +from .sansio import Passkey + +T = TypeVar("T") + + +class Manager(Generic[T]): + """Generic manager for global instances.""" + + def __init__(self, name: str): + self._instance: T | None = None + self._name = name + + @property + def instance(self) -> T: + if self._instance is None: + raise RuntimeError( + f"{self._name} not initialized. Call globals.init() first." + ) + return self._instance + + @instance.setter + def instance(self, instance: T) -> None: + self._instance = instance + + +async def init( + rp_id: str = "localhost", rp_name: str | None = None, origin: str | None = None +) -> None: + """Initialize the global database, passkey instance, and bootstrap the system if needed.""" + # Initialize passkey instance with provided parameters + passkey.instance = Passkey( + rp_id=rp_id, + rp_name=rp_name or rp_id, + origin=origin, + ) + + # Test if we have a database already initialized, otherwise use SQL + try: + db.instance + except RuntimeError: + from .db import sql + + await sql.init() + + # Bootstrap system if needed + from .bootstrap import bootstrap_if_needed + + await bootstrap_if_needed() + + +# Global instances +passkey = Manager[Passkey]("Passkey") +db = Manager[DatabaseInterface]("Database") diff --git a/passkey/sansio.py b/passkey/sansio.py index 0471658..28904e6 100644 --- a/passkey/sansio.py +++ b/passkey/sansio.py @@ -9,6 +9,7 @@ This module provides a unified interface for WebAuthn operations including: import json from datetime import datetime +from urllib.parse import urlparse from uuid import UUID import uuid7 @@ -45,7 +46,7 @@ class Passkey: def __init__( self, rp_id: str, - rp_name: str, + rp_name: str | None = None, origin: str | None = None, supported_pub_key_algs: list[COSEAlgorithmIdentifier] | None = None, ): @@ -54,19 +55,41 @@ class Passkey: Args: rp_id: Your security domain (e.g. "example.com") - rp_name: The relying party name (e.g., "My Application" - visible to users) - origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included. + rp_name: The relying party display name (e.g. "Example App"). May be shown in authenticators. + origin: The origin URL of the application (e.g. "https://app.example.com"). + If no scheme is provided, "https://" will be prepended. + Must be a subdomain or same as rp_id, with port and scheme but no path included. supported_pub_key_algs: List of supported COSE algorithms (default is EDDSA, ECDSA_SHA_256, RSASSA_PKCS1_v1_5_SHA_256). + + Raises: + ValueError: If the origin domain doesn't match or isn't a subdomain of rp_id. """ self.rp_id = rp_id - self.rp_name = rp_name - self.origin = origin or f"https://{rp_id}" + self.rp_name = rp_name or rp_id + self.origin = self._normalize_and_validate_origin(origin, rp_id) self.supported_pub_key_algs = supported_pub_key_algs or [ COSEAlgorithmIdentifier.EDDSA, COSEAlgorithmIdentifier.ECDSA_SHA_256, COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, ] + def _normalize_and_validate_origin(self, origin: str | None, rp_id: str) -> str: + if origin is None: + origin = f"https://{rp_id}" + elif "://" not in origin: + origin = f"https://{origin}" + + hostname = urlparse(origin).hostname + if not hostname: + raise ValueError(f"Invalid origin URL: no hostname found in '{origin}'") + + if hostname == rp_id or hostname.endswith(f".{rp_id}"): + return origin + + raise ValueError( + f"Origin domain '{hostname}' must be the same as or a subdomain of rp_id '{rp_id}'" + ) + ### Registration Methods ### def reg_generate_options(