diff --git a/passkey/db/__init__.py b/passkey/db/__init__.py index 0ca910e..24655a1 100644 --- a/passkey/db/__init__.py +++ b/passkey/db/__init__.py @@ -134,15 +134,26 @@ class DatabaseInterface(ABC): """Create a new user and their first credential in a transaction.""" -# Global DB instance -database_instance: DatabaseInterface | None = None +class DatabaseManager: + """Manager for the global database instance.""" + + def __init__(self): + self._instance: DatabaseInterface | None = None + + @property + def instance(self) -> DatabaseInterface: + if self._instance is None: + raise RuntimeError( + "Database not initialized. Call e.g. db.sql.init() first." + ) + return self._instance + + @instance.setter + def instance(self, instance: DatabaseInterface) -> None: + self._instance = instance -def database() -> DatabaseInterface: - """Get the global database instance.""" - if database_instance is None: - raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.") - return database_instance +db = DatabaseManager() __all__ = [ @@ -150,6 +161,5 @@ __all__ = [ "Credential", "Session", "DatabaseInterface", - "database_instance", - "database", + "db", ] diff --git a/passkey/db/sql.py b/passkey/db/sql.py index 8ddccda..6061310 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -23,15 +23,14 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from . import Credential, DatabaseInterface, Session, User +from . import Credential, DatabaseInterface, Session, User, db DB_PATH = "sqlite+aiosqlite:///webauthn.db" -def init(*args, **kwargs): - from .. import db - - db.database_instance = DB() +async def init(*args, **kwargs): + db.instance = DB() + await db.instance.init_db() # SQLAlchemy Models diff --git a/passkey/fastapi/api.py b/passkey/fastapi/api.py index 8217a9a..c0f0acb 100644 --- a/passkey/fastapi/api.py +++ b/passkey/fastapi/api.py @@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response from fastapi.security import HTTPBearer from .. import aaguid -from ..db import database +from ..db import db from ..util.tokens import session_key from . import session @@ -42,15 +42,15 @@ def register_api_routes(app: FastAPI): """Get full user information for the authenticated user.""" try: s = await session.get_session(auth, reset_allowed=True) - u = await database().get_user_by_user_uuid(s.user_uuid) + u = await db.instance.get_user_by_user_uuid(s.user_uuid) # Get all credentials for the user - credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid) + credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid) credentials = [] user_aaguids = set() for cred_id in credential_ids: - c = await database().get_credential_by_id(cred_id) + c = await db.instance.get_credential_by_id(cred_id) # Convert AAGUID to string format aaguid_str = str(c.aaguid) @@ -102,7 +102,7 @@ def register_api_routes(app: FastAPI): """Log out the current user by clearing the session cookie and deleting from database.""" if not auth: return {"status": "success", "message": "Already logged out"} - await database().delete_session(session_key(auth)) + await db.instance.delete_session(session_key(auth)) response.delete_cookie("auth") return {"status": "success", "message": "Logged out successfully"} diff --git a/passkey/fastapi/main.py b/passkey/fastapi/main.py index 6f27c59..448c986 100644 --- a/passkey/fastapi/main.py +++ b/passkey/fastapi/main.py @@ -1,14 +1,3 @@ -""" -Minimal FastAPI WebAuthn server with WebSocket support for passkey registration and authentication. - -This module provides a simple WebAuthn implementation that: -- Uses WebSocket for real-time communication -- Supports Resident Keys (discoverable credentials) for passwordless authentication -- Maintains challenges locally per connection -- Uses async SQLite database for persistent storage of users and credentials -- Enables true passwordless authentication where users don't need to enter a user_name -""" - import contextlib import logging from contextlib import asynccontextmanager @@ -20,7 +9,6 @@ from fastapi.responses import ( ) from fastapi.staticfiles import StaticFiles -from ..db import close_db, init_db from . import session, ws from .api import register_api_routes from .reset import register_reset_routes @@ -30,9 +18,10 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build" @asynccontextmanager async def lifespan(app: FastAPI): - await init_db() + from ..db import sql + + await sql.init() yield - await close_db() app = FastAPI(lifespan=lifespan) diff --git a/passkey/fastapi/reset.py b/passkey/fastapi/reset.py index 45dd553..f26a4aa 100644 --- a/passkey/fastapi/reset.py +++ b/passkey/fastapi/reset.py @@ -3,7 +3,7 @@ import logging from fastapi import Cookie, HTTPException, Request from fastapi.responses import RedirectResponse -from ..db import database +from ..db import db from ..util import passphrase, tokens from . import session @@ -20,7 +20,7 @@ def register_reset_routes(app): # Generate a human-readable token token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke" - await database().create_session( + await db.instance.create_session( user_uuid=s.user_uuid, key=tokens.reset_key(token), expires=session.expires(), @@ -56,7 +56,7 @@ def register_reset_routes(app): try: # Get session token to validate it exists and get user_id key = tokens.reset_key(reset_token) - sess = await database().get_session(key) + sess = await db.instance.get_session(key) if not sess: raise ValueError("Invalid or expired registration token") diff --git a/passkey/fastapi/session.py b/passkey/fastapi/session.py index 6350a74..0ab9926 100644 --- a/passkey/fastapi/session.py +++ b/passkey/fastapi/session.py @@ -14,7 +14,7 @@ from uuid import UUID from fastapi import Request, Response, WebSocket -from ..db import Session, database +from ..db import Session, db from ..util import passphrase from ..util.tokens import create_token, reset_key, session_key @@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict: async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str: """Create a new session and return a session token.""" token = create_token() - await database().create_session( + await db.instance.create_session( user_uuid=user_uuid, key=session_key(token), expires=datetime.now() + EXPIRES, @@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session: else: key = session_key(token) - session = await database().get_session(key) + session = await db.instance.get_session(key) if not session: raise ValueError("Invalid or expired session token") return session @@ -65,7 +65,7 @@ async def get_session(token: str, reset_allowed=False) -> Session: async def refresh_session_token(token: str): """Refresh a session extending its expiry.""" # Get the current session - s = await database().update_session( + s = await db.instance.update_session( session_key(token), datetime.now() + EXPIRES, {} ) @@ -88,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None: async def delete_credential(credential_uuid: UUID, auth: str): """Delete a specific credential for the current user.""" s = await get_session(auth) - await database().delete_credential(credential_uuid, s.user_uuid) + await db.instance.delete_credential(credential_uuid, s.user_uuid) diff --git a/passkey/fastapi/ws.py b/passkey/fastapi/ws.py index 31dd51e..c7025c8 100644 --- a/passkey/fastapi/ws.py +++ b/passkey/fastapi/ws.py @@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse from passkey.fastapi import session -from ..db import User, database +from ..db import User, db from ..sansio import Passkey from ..util.tokens import create_token, session_key from .session import create_session, infodict @@ -65,13 +65,13 @@ async def websocket_register_new( credential = await register_chat(ws, user_uuid, user_name, origin=origin) # Store the user and credential in the database - await database().create_user_and_credential( + await db.instance.create_user_and_credential( User(user_uuid, user_name, created_at=datetime.now()), credential, ) # Create a session token for the new user token = create_token() - await database().create_session( + await db.instance.create_session( user_uuid=user_uuid, key=session_key(token), expires=datetime.now() + session.EXPIRES, @@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)): user_uuid = s.user_uuid # Get user information to get the user_name - user = await database().get_user_by_user_uuid(user_uuid) + user = await db.instance.get_user_by_user_uuid(user_uuid) user_name = user.user_name - challenge_ids = await database().get_credentials_by_user_uuid(user_uuid) + challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid) # WebAuthn registration credential = await register_chat( ws, user_uuid, user_name, challenge_ids, origin ) # Store the new credential in the database - await database().create_credential(credential) + await db.instance.create_credential(credential) await ws.send_json( { @@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket): # Wait for the client to use his authenticator to authenticate credential = passkey.auth_parse(await ws.receive_json()) # Fetch from the database by credential ID - stored_cred = await database().get_credential_by_id(credential.raw_id) + stored_cred = await db.instance.get_credential_by_id(credential.raw_id) # Verify the credential matches the stored data passkey.auth_verify(credential, challenge, stored_cred, origin=origin) # Update both credential and user's last_seen timestamp - await database().login(stored_cred.user_uuid, stored_cred) + await db.instance.login(stored_cred.user_uuid, stored_cred) # Create a session token for the authenticated user assert stored_cred.uuid is not None