diff --git a/passkey/db/__init__.py b/passkey/db/__init__.py index 83cb18c..0ca910e 100644 --- a/passkey/db/__init__.py +++ b/passkey/db/__init__.py @@ -5,7 +5,8 @@ This module provides dataclasses and database abstractions for managing users, credentials, and sessions in a WebAuthn authentication system. """ -from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from dataclasses import dataclass from datetime import datetime from uuid import UUID @@ -47,4 +48,108 @@ class Session: credential_uuid: UUID | None = None -__all__ = ["User", "Credential", "Session"] +class DatabaseInterface(ABC): + """Abstract base class defining the database interface. + + This class defines the public API that database implementations should provide. + Implementations may use decorators like @with_session that modify method signatures + at runtime, so this interface focuses on the logical operations rather than + exact parameter matching. + """ + + @abstractmethod + async def init_db(self) -> None: + """Initialize database tables.""" + pass + + # User operations + @abstractmethod + async def get_user_by_user_uuid(self, user_uuid: UUID) -> User: + """Get user record by WebAuthn user UUID.""" + + @abstractmethod + async def create_user(self, user: User) -> None: + """Create a new user.""" + + # Credential operations + @abstractmethod + async def create_credential(self, credential: Credential) -> None: + """Store a credential for a user.""" + + @abstractmethod + async def get_credential_by_id(self, credential_id: bytes) -> Credential: + """Get credential by credential ID.""" + + @abstractmethod + async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]: + """Get all credential IDs for a user.""" + + @abstractmethod + async def update_credential(self, credential: Credential) -> None: + """Update the sign count, created_at, last_used, and last_verified for a credential.""" + + @abstractmethod + async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None: + """Delete a specific credential for a user.""" + + # Session operations + @abstractmethod + async def create_session( + self, + user_uuid: UUID, + key: bytes, + expires: datetime, + info: dict, + credential_uuid: UUID | None = None, + ) -> None: + """Create a new session.""" + + @abstractmethod + async def get_session(self, key: bytes) -> Session | None: + """Get session by key.""" + + @abstractmethod + async def delete_session(self, key: bytes) -> None: + """Delete session by key.""" + + @abstractmethod + async def update_session( + self, key: bytes, expires: datetime, info: dict + ) -> Session | None: + """Update session expiry and info.""" + + @abstractmethod + async def cleanup(self) -> None: + """Called periodically to clean up expired records.""" + + # Combined operations + @abstractmethod + async def login(self, user_uuid: UUID, credential: Credential) -> None: + """Update user and credential timestamps after successful login.""" + + @abstractmethod + async def create_user_and_credential( + self, user: User, credential: Credential + ) -> None: + """Create a new user and their first credential in a transaction.""" + + +# Global DB instance +database_instance: DatabaseInterface | None = None + + +def database() -> DatabaseInterface: + """Get the global database instance.""" + if database_instance is None: + raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.") + return database_instance + + +__all__ = [ + "User", + "Credential", + "Session", + "DatabaseInterface", + "database_instance", + "database", +] diff --git a/passkey/db/sql.py b/passkey/db/sql.py index 3243890..8ddccda 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode for managing users and credentials in a WebAuthn authentication system. """ +from contextlib import asynccontextmanager from datetime import datetime -from functools import wraps from uuid import UUID from sqlalchemy import ( @@ -23,24 +23,15 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from . import Credential, Session, User +from . import Credential, DatabaseInterface, Session, User DB_PATH = "sqlite+aiosqlite:///webauthn.db" -def with_session(func): - """Decorator that provides a database session with transaction to the method.""" +def init(*args, **kwargs): + from .. import db - @wraps(func) - async def wrapper(self, *args, **kwargs): - async with self.async_session_factory() as session: - async with session.begin(): - result = await func(self, session, *args, **kwargs) - await session.flush() - return result - await session.commit() - - return wrapper + db.database_instance = DB() # SQLAlchemy Models @@ -100,7 +91,7 @@ class SessionModel(Base): user: Mapped["UserModel"] = relationship("UserModel") -class DB: +class DB(DatabaseInterface): """Database class that handles its own connections.""" def __init__(self, db_path: str = DB_PATH): @@ -110,230 +101,219 @@ class DB: self.engine, expire_on_commit=False ) + @asynccontextmanager + async def session(self): + """Async context manager that provides a database session with transaction.""" + async with self.async_session_factory() as session: + async with session.begin(): + yield session + await session.flush() + await session.commit() + async def init_db(self) -> None: """Initialize database tables.""" async with self.engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - @with_session - async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User: - """Get user record by WebAuthn user UUID.""" - stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) - result = await session.execute(stmt) - user_model = result.scalar_one_or_none() + async def get_user_by_user_uuid(self, user_uuid: UUID) -> User: + async with self.session() as session: + stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) + result = await session.execute(stmt) + user_model = result.scalar_one_or_none() - if user_model: - return User( - user_uuid=UUID(bytes=user_model.user_uuid), - user_name=user_model.user_name, - created_at=user_model.created_at, - last_seen=user_model.last_seen, - visits=user_model.visits, + if user_model: + return User( + user_uuid=UUID(bytes=user_model.user_uuid), + user_name=user_model.user_name, + created_at=user_model.created_at, + last_seen=user_model.last_seen, + visits=user_model.visits, + ) + raise ValueError("User not found") + + async def create_user(self, user: User) -> None: + async with self.session() as session: + user_model = UserModel( + user_uuid=user.user_uuid.bytes, + user_name=user.user_name, + created_at=user.created_at or datetime.now(), + last_seen=user.last_seen, + visits=user.visits, ) - raise ValueError("User not found") + session.add(user_model) - @with_session - async def create_user(self, session, user: User) -> None: - """Create a new user.""" - user_model = UserModel( - user_uuid=user.user_uuid.bytes, - user_name=user.user_name, - created_at=user.created_at or datetime.now(), - last_seen=user.last_seen, - visits=user.visits, - ) - session.add(user_model) - - @with_session - async def create_credential(self, session, credential: Credential) -> None: - """Store a credential for a user.""" - credential_model = CredentialModel( - uuid=credential.uuid.bytes, - credential_id=credential.credential_id, - user_uuid=credential.user_uuid.bytes, - aaguid=credential.aaguid.bytes, - public_key=credential.public_key, - sign_count=credential.sign_count, - created_at=credential.created_at, - last_used=credential.last_used, - last_verified=credential.last_verified, - ) - session.add(credential_model) - - @with_session - async def get_credential_by_id(self, session, credential_id: bytes) -> Credential: - """Get credential by credential ID.""" - stmt = select(CredentialModel).where( - CredentialModel.credential_id == credential_id - ) - result = await session.execute(stmt) - credential_model = result.scalar_one_or_none() - - if not credential_model: - raise ValueError("Credential not registered") - return Credential( - uuid=UUID(bytes=credential_model.uuid), - credential_id=credential_model.credential_id, - user_uuid=UUID(bytes=credential_model.user_uuid), - aaguid=UUID(bytes=credential_model.aaguid), - public_key=credential_model.public_key, - sign_count=credential_model.sign_count, - created_at=credential_model.created_at, - last_used=credential_model.last_used, - last_verified=credential_model.last_verified, - ) - - @with_session - async def get_credentials_by_user_uuid( - self, session, user_uuid: UUID - ) -> list[bytes]: - """Get all credential IDs for a user.""" - stmt = select(CredentialModel.credential_id).where( - CredentialModel.user_uuid == user_uuid.bytes - ) - result = await session.execute(stmt) - return [row[0] for row in result.fetchall()] - - @with_session - async def update_credential(self, session, credential: Credential) -> None: - """Update the sign count, created_at, last_used, and last_verified for a credential.""" - stmt = ( - update(CredentialModel) - .where(CredentialModel.credential_id == credential.credential_id) - .values( + async def create_credential(self, credential: Credential) -> None: + async with self.session() as session: + credential_model = CredentialModel( + uuid=credential.uuid.bytes, + credential_id=credential.credential_id, + user_uuid=credential.user_uuid.bytes, + aaguid=credential.aaguid.bytes, + public_key=credential.public_key, sign_count=credential.sign_count, created_at=credential.created_at, last_used=credential.last_used, last_verified=credential.last_verified, ) - ) - await session.execute(stmt) + session.add(credential_model) - @with_session - async def login(self, session, user_uuid: UUID, credential: Credential) -> None: - """Update the last_seen timestamp for a user and the credential record used for logging in.""" - # Update credential - stmt = ( - update(CredentialModel) - .where(CredentialModel.credential_id == credential.credential_id) - .values( - sign_count=credential.sign_count, - created_at=credential.created_at, - last_used=credential.last_used, - last_verified=credential.last_verified, + async def get_credential_by_id(self, credential_id: bytes) -> Credential: + async with self.session() as session: + stmt = select(CredentialModel).where( + CredentialModel.credential_id == credential_id ) - ) - await session.execute(stmt) + result = await session.execute(stmt) + credential_model = result.scalar_one_or_none() - # Update user's last_seen and increment visits - stmt = ( - update(UserModel) - .where(UserModel.user_uuid == user_uuid.bytes) - .values(last_seen=credential.last_used, visits=UserModel.visits + 1) - ) - await session.execute(stmt) + if not credential_model: + raise ValueError("Credential not registered") + return Credential( + uuid=UUID(bytes=credential_model.uuid), + credential_id=credential_model.credential_id, + user_uuid=UUID(bytes=credential_model.user_uuid), + aaguid=UUID(bytes=credential_model.aaguid), + public_key=credential_model.public_key, + sign_count=credential_model.sign_count, + created_at=credential_model.created_at, + last_used=credential_model.last_used, + last_verified=credential_model.last_verified, + ) + + async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]: + async with self.session() as session: + stmt = select(CredentialModel.credential_id).where( + CredentialModel.user_uuid == user_uuid.bytes + ) + result = await session.execute(stmt) + return [row[0] for row in result.fetchall()] + + async def update_credential(self, credential: Credential) -> None: + async with self.session() as session: + stmt = ( + update(CredentialModel) + .where(CredentialModel.credential_id == credential.credential_id) + .values( + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, + ) + ) + await session.execute(stmt) + + async def login(self, user_uuid: UUID, credential: Credential) -> None: + async with self.session() as session: + # Update credential + stmt = ( + update(CredentialModel) + .where(CredentialModel.credential_id == credential.credential_id) + .values( + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, + ) + ) + await session.execute(stmt) + + # Update user's last_seen and increment visits + stmt = ( + update(UserModel) + .where(UserModel.user_uuid == user_uuid.bytes) + .values(last_seen=credential.last_used, visits=UserModel.visits + 1) + ) + await session.execute(stmt) - @with_session async def create_user_and_credential( - self, session, user: User, credential: Credential + self, user: User, credential: Credential ) -> None: - """Create a new user and their first credential in a single transaction.""" - # Set visits to 1 for the new user since they're creating their first session - user.visits = 1 + async with self.session() as session: + # Set visits to 1 for the new user since they're creating their first session + user.visits = 1 - # Create user - user_model = UserModel( - user_uuid=user.user_uuid.bytes, - user_name=user.user_name, - created_at=user.created_at or datetime.now(), - last_seen=user.last_seen, - visits=user.visits, - ) - session.add(user_model) + # Create user + user_model = UserModel( + user_uuid=user.user_uuid.bytes, + user_name=user.user_name, + created_at=user.created_at or datetime.now(), + last_seen=user.last_seen, + visits=user.visits, + ) + session.add(user_model) - # Create credential - credential_model = CredentialModel( - uuid=credential.uuid.bytes, - credential_id=credential.credential_id, - user_uuid=credential.user_uuid.bytes, - aaguid=credential.aaguid.bytes, - public_key=credential.public_key, - sign_count=credential.sign_count, - created_at=credential.created_at, - last_used=credential.last_used, - last_verified=credential.last_verified, - ) - session.add(credential_model) + # Create credential + credential_model = CredentialModel( + uuid=credential.uuid.bytes, + credential_id=credential.credential_id, + user_uuid=credential.user_uuid.bytes, + aaguid=credential.aaguid.bytes, + public_key=credential.public_key, + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, + ) + session.add(credential_model) - @with_session - async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None: - """Delete a credential by its ID.""" - stmt = ( - delete(CredentialModel) - .where(CredentialModel.uuid == uuid.bytes) - .where(CredentialModel.user_uuid == user_uuid.bytes) - ) - await session.execute(stmt) + async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None: + async with self.session() as session: + stmt = ( + delete(CredentialModel) + .where(CredentialModel.uuid == uuid.bytes) + .where(CredentialModel.user_uuid == user_uuid.bytes) + ) + await session.execute(stmt) - @with_session async def create_session( self, - session, user_uuid: UUID, key: bytes, expires: datetime, info: dict, credential_uuid: UUID | None = None, - ) -> bytes: - """Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential.""" - session_model = SessionModel( - key=key, - user_uuid=user_uuid.bytes, - credential_uuid=credential_uuid.bytes if credential_uuid else None, - expires=expires, - info=info, - ) - session.add(session_model) - return key - - @with_session - async def get_session(self, session, key: bytes) -> Session | None: - """Get session by 16-byte key.""" - stmt = select(SessionModel).where(SessionModel.key == key) - result = await session.execute(stmt) - session_model = result.scalar_one_or_none() - - if session_model: - return Session( - key=session_model.key, - user_uuid=UUID(bytes=session_model.user_uuid), - credential_uuid=UUID(bytes=session_model.credential_uuid) - if session_model.credential_uuid - else None, - expires=session_model.expires, - info=session_model.info or {}, - ) - return None - - @with_session - async def delete_session(self, session, key: bytes) -> None: - """Delete a session by 16-byte key.""" - await session.execute(delete(SessionModel).where(SessionModel.key == key)) - - @with_session - async def update_session( - self, session, key: bytes, expires: datetime, info: dict ) -> None: - """Update session expiration time and/or info.""" - await session.execute( - update(SessionModel) - .where(SessionModel.key == key) - .values(expires=expires, info=info) - ) + async with self.session() as session: + session_model = SessionModel( + key=key, + user_uuid=user_uuid.bytes, + credential_uuid=credential_uuid.bytes if credential_uuid else None, + expires=expires, + info=info, + ) + session.add(session_model) - @with_session - async def cleanup_expired_sessions(self, session) -> None: - """Remove expired sessions.""" - current_time = datetime.now() - stmt = delete(SessionModel).where(SessionModel.expires < current_time) - await session.execute(stmt) + async def get_session(self, key: bytes) -> Session | None: + async with self.session() as session: + stmt = select(SessionModel).where(SessionModel.key == key) + result = await session.execute(stmt) + session_model = result.scalar_one_or_none() + + if session_model: + return Session( + key=session_model.key, + user_uuid=UUID(bytes=session_model.user_uuid), + credential_uuid=UUID(bytes=session_model.credential_uuid) + if session_model.credential_uuid + else None, + expires=session_model.expires, + info=session_model.info or {}, + ) + return None + + async def delete_session(self, key: bytes) -> None: + async with self.session() as session: + await session.execute(delete(SessionModel).where(SessionModel.key == key)) + + async def update_session(self, key: bytes, expires: datetime, info: dict) -> None: + async with self.session() as session: + await session.execute( + update(SessionModel) + .where(SessionModel.key == key) + .values(expires=expires, info=info) + ) + + async def cleanup(self) -> None: + async with self.session() as session: + current_time = datetime.now() + stmt = delete(SessionModel).where(SessionModel.expires < current_time) + await session.execute(stmt) diff --git a/passkey/fastapi/api.py b/passkey/fastapi/api.py index b38bf98..8217a9a 100644 --- a/passkey/fastapi/api.py +++ b/passkey/fastapi/api.py @@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response from fastapi.security import HTTPBearer from .. import aaguid -from ..db import sql +from ..db import database from ..util.tokens import session_key from . import session @@ -38,19 +38,19 @@ def register_api_routes(app: FastAPI): return {"status": "error", "valid": False} @app.post("/auth/user-info") - async def api_user_info(request: Request, response: Response, auth=Cookie(None)): + async def api_user_info(auth=Cookie(None)): """Get full user information for the authenticated user.""" try: s = await session.get_session(auth, reset_allowed=True) - u = await sql.get_user_by_uuid(s.user_uuid) + u = await database().get_user_by_user_uuid(s.user_uuid) # Get all credentials for the user - credential_ids = await sql.get_user_credentials(s.user_uuid) + credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid) credentials = [] user_aaguids = set() for cred_id in credential_ids: - c = await sql.get_credential_by_id(cred_id) + c = await database().get_credential_by_id(cred_id) # Convert AAGUID to string format aaguid_str = str(c.aaguid) @@ -102,14 +102,12 @@ def register_api_routes(app: FastAPI): """Log out the current user by clearing the session cookie and deleting from database.""" if not auth: return {"status": "success", "message": "Already logged out"} - await sql.delete_session(session_key(auth)) + await database().delete_session(session_key(auth)) response.delete_cookie("auth") return {"status": "success", "message": "Logged out successfully"} @app.post("/auth/set-session") - async def api_set_session( - request: Request, response: Response, auth=Depends(bearer_auth) - ): + async def api_set_session(response: Response, auth=Depends(bearer_auth)): """Set session cookie from Authorization header. Fetched after login by WebSocket.""" try: user = await session.get_session(auth.credentials) diff --git a/passkey/fastapi/main.py b/passkey/fastapi/main.py index 8f20d86..6f27c59 100644 --- a/passkey/fastapi/main.py +++ b/passkey/fastapi/main.py @@ -20,7 +20,7 @@ from fastapi.responses import ( ) from fastapi.staticfiles import StaticFiles -from ..db import sql +from ..db import close_db, init_db from . import session, ws from .api import register_api_routes from .reset import register_reset_routes @@ -30,8 +30,9 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build" @asynccontextmanager async def lifespan(app: FastAPI): - await sql.init_database() + await init_db() yield + await close_db() app = FastAPI(lifespan=lifespan) diff --git a/passkey/fastapi/reset.py b/passkey/fastapi/reset.py index a5cbd15..45dd553 100644 --- a/passkey/fastapi/reset.py +++ b/passkey/fastapi/reset.py @@ -3,7 +3,7 @@ import logging from fastapi import Cookie, HTTPException, Request from fastapi.responses import RedirectResponse -from ..db import sql +from ..db import database from ..util import passphrase, tokens from . import session @@ -20,7 +20,7 @@ def register_reset_routes(app): # Generate a human-readable token token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke" - await sql.create_session( + await database().create_session( user_uuid=s.user_uuid, key=tokens.reset_key(token), expires=session.expires(), @@ -56,7 +56,7 @@ def register_reset_routes(app): try: # Get session token to validate it exists and get user_id key = tokens.reset_key(reset_token) - sess = await sql.get_session(key) + sess = await database().get_session(key) if not sess: raise ValueError("Invalid or expired registration token") diff --git a/passkey/fastapi/session.py b/passkey/fastapi/session.py index 66d1996..6350a74 100644 --- a/passkey/fastapi/session.py +++ b/passkey/fastapi/session.py @@ -14,7 +14,7 @@ from uuid import UUID from fastapi import Request, Response, WebSocket -from ..db import Session, sql +from ..db import Session, database from ..util import passphrase from ..util.tokens import create_token, reset_key, session_key @@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict: async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str: """Create a new session and return a session token.""" token = create_token() - await sql.create_session( + await database().create_session( user_uuid=user_uuid, key=session_key(token), expires=datetime.now() + EXPIRES, @@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session: else: key = session_key(token) - session = await sql.get_session(key) + session = await database().get_session(key) if not session: raise ValueError("Invalid or expired session token") return session @@ -65,7 +65,9 @@ async def get_session(token: str, reset_allowed=False) -> Session: async def refresh_session_token(token: str): """Refresh a session extending its expiry.""" # Get the current session - s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {}) + s = await database().update_session( + session_key(token), datetime.now() + EXPIRES, {} + ) if not s: raise ValueError("Session not found or expired") @@ -86,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None: async def delete_credential(credential_uuid: UUID, auth: str): """Delete a specific credential for the current user.""" s = await get_session(auth) - await sql.delete_credential(credential_uuid, s.user_uuid) + await database().delete_credential(credential_uuid, s.user_uuid) diff --git a/passkey/fastapi/ws.py b/passkey/fastapi/ws.py index 79be5c7..31dd51e 100644 --- a/passkey/fastapi/ws.py +++ b/passkey/fastapi/ws.py @@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse from passkey.fastapi import session -from ..db import User, sql +from ..db import User, database from ..sansio import Passkey from ..util.tokens import create_token, session_key from .session import create_session, infodict @@ -65,13 +65,13 @@ async def websocket_register_new( credential = await register_chat(ws, user_uuid, user_name, origin=origin) # Store the user and credential in the database - await sql.create_user_and_credential( + await database().create_user_and_credential( User(user_uuid, user_name, created_at=datetime.now()), credential, ) # Create a session token for the new user token = create_token() - await sql.create_session( + await database().create_session( user_uuid=user_uuid, key=session_key(token), expires=datetime.now() + session.EXPIRES, @@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)): user_uuid = s.user_uuid # Get user information to get the user_name - user = await sql.get_user_by_uuid(user_uuid) + user = await database().get_user_by_user_uuid(user_uuid) user_name = user.user_name - challenge_ids = await sql.get_user_credentials(user_uuid) + challenge_ids = await database().get_credentials_by_user_uuid(user_uuid) # WebAuthn registration credential = await register_chat( ws, user_uuid, user_name, challenge_ids, origin ) # Store the new credential in the database - await sql.create_credential_for_user(credential) + await database().create_credential(credential) await ws.send_json( { @@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket): # Wait for the client to use his authenticator to authenticate credential = passkey.auth_parse(await ws.receive_json()) # Fetch from the database by credential ID - stored_cred = await sql.get_credential_by_id(credential.raw_id) + stored_cred = await database().get_credential_by_id(credential.raw_id) # Verify the credential matches the stored data passkey.auth_verify(credential, challenge, stored_cred, origin=origin) # Update both credential and user's last_seen timestamp - await sql.login_user(stored_cred.user_uuid, stored_cred) + await database().login(stored_cred.user_uuid, stored_cred) # Create a session token for the authenticated user assert stored_cred.uuid is not None