From 00693c56fa31999d4d8b397dd758a3a1fc167d47 Mon Sep 17 00:00:00 2001 From: Leo Vasanko Date: Tue, 5 Aug 2025 06:41:07 -0600 Subject: [PATCH] DB refactor (currently broken) --- passkey/db/sql.py | 266 +++++++++++++++++++++------------------------- 1 file changed, 120 insertions(+), 146 deletions(-) diff --git a/passkey/db/sql.py b/passkey/db/sql.py index 3f0d671..3243890 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode for managing users and credentials in a WebAuthn authentication system. """ -from contextlib import asynccontextmanager from datetime import datetime +from functools import wraps from uuid import UUID from sqlalchemy import ( @@ -20,7 +20,7 @@ from sqlalchemy import ( update, ) from sqlalchemy.dialects.sqlite import BLOB, JSON -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from . import Credential, Session, User @@ -28,6 +28,21 @@ from . import Credential, Session, User DB_PATH = "sqlite+aiosqlite:///webauthn.db" +def with_session(func): + """Decorator that provides a database session with transaction to the method.""" + + @wraps(func) + async def wrapper(self, *args, **kwargs): + async with self.async_session_factory() as session: + async with session.begin(): + result = await func(self, session, *args, **kwargs) + await session.flush() + return result + await session.commit() + + return wrapper + + # SQLAlchemy Models class Base(DeclarativeBase): pass @@ -85,32 +100,26 @@ class SessionModel(Base): user: Mapped["UserModel"] = relationship("UserModel") -# Global engine and session factory -engine = create_async_engine(DB_PATH, echo=False) -async_session_factory = async_sessionmaker(engine, expire_on_commit=False) - - -@asynccontextmanager -async def connect(): - """Context manager for database connections.""" - async with async_session_factory() as session: - yield DB(session) - await session.commit() - - class DB: - def __init__(self, session: AsyncSession): - self.session = session + """Database class that handles its own connections.""" + + def __init__(self, db_path: str = DB_PATH): + """Initialize with database path.""" + self.engine = create_async_engine(db_path, echo=False) + self.async_session_factory = async_sessionmaker( + self.engine, expire_on_commit=False + ) async def init_db(self) -> None: """Initialize database tables.""" - async with engine.begin() as conn: + async with self.engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - async def get_user_by_user_uuid(self, user_uuid: UUID) -> User: + @with_session + async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User: """Get user record by WebAuthn user UUID.""" stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) - result = await self.session.execute(stmt) + result = await session.execute(stmt) user_model = result.scalar_one_or_none() if user_model: @@ -123,7 +132,8 @@ class DB: ) raise ValueError("User not found") - async def create_user(self, user: User) -> None: + @with_session + async def create_user(self, session, user: User) -> None: """Create a new user.""" user_model = UserModel( user_uuid=user.user_uuid.bytes, @@ -132,10 +142,10 @@ class DB: last_seen=user.last_seen, visits=user.visits, ) - self.session.add(user_model) - await self.session.flush() + session.add(user_model) - async def create_credential(self, credential: Credential) -> None: + @with_session + async def create_credential(self, session, credential: Credential) -> None: """Store a credential for a user.""" credential_model = CredentialModel( uuid=credential.uuid.bytes, @@ -148,15 +158,15 @@ class DB: last_used=credential.last_used, last_verified=credential.last_verified, ) - self.session.add(credential_model) - await self.session.flush() + session.add(credential_model) - async def get_credential_by_id(self, credential_id: bytes) -> Credential: + @with_session + async def get_credential_by_id(self, session, credential_id: bytes) -> Credential: """Get credential by credential ID.""" stmt = select(CredentialModel).where( CredentialModel.credential_id == credential_id ) - result = await self.session.execute(stmt) + result = await session.execute(stmt) credential_model = result.scalar_one_or_none() if not credential_model: @@ -173,15 +183,19 @@ class DB: last_verified=credential_model.last_verified, ) - async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]: + @with_session + async def get_credentials_by_user_uuid( + self, session, user_uuid: UUID + ) -> list[bytes]: """Get all credential IDs for a user.""" stmt = select(CredentialModel.credential_id).where( CredentialModel.user_uuid == user_uuid.bytes ) - result = await self.session.execute(stmt) + result = await session.execute(stmt) return [row[0] for row in result.fetchall()] - async def update_credential(self, credential: Credential) -> None: + @with_session + async def update_credential(self, session, credential: Credential) -> None: """Update the sign count, created_at, last_used, and last_verified for a credential.""" stmt = ( update(CredentialModel) @@ -193,34 +207,78 @@ class DB: last_verified=credential.last_verified, ) ) - await self.session.execute(stmt) + await session.execute(stmt) - async def login(self, user_uuid: UUID, credential: Credential) -> None: + @with_session + async def login(self, session, user_uuid: UUID, credential: Credential) -> None: """Update the last_seen timestamp for a user and the credential record used for logging in.""" - async with self.session.begin(): - # Update credential - await self.update_credential(credential) - - # Update user's last_seen and increment visits - stmt = ( - update(UserModel) - .where(UserModel.user_uuid == user_uuid.bytes) - .values(last_seen=credential.last_used, visits=UserModel.visits + 1) + # Update credential + stmt = ( + update(CredentialModel) + .where(CredentialModel.credential_id == credential.credential_id) + .values( + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, ) - await self.session.execute(stmt) + ) + await session.execute(stmt) - async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None: + # Update user's last_seen and increment visits + stmt = ( + update(UserModel) + .where(UserModel.user_uuid == user_uuid.bytes) + .values(last_seen=credential.last_used, visits=UserModel.visits + 1) + ) + await session.execute(stmt) + + @with_session + async def create_user_and_credential( + self, session, user: User, credential: Credential + ) -> None: + """Create a new user and their first credential in a single transaction.""" + # Set visits to 1 for the new user since they're creating their first session + user.visits = 1 + + # Create user + user_model = UserModel( + user_uuid=user.user_uuid.bytes, + user_name=user.user_name, + created_at=user.created_at or datetime.now(), + last_seen=user.last_seen, + visits=user.visits, + ) + session.add(user_model) + + # Create credential + credential_model = CredentialModel( + uuid=credential.uuid.bytes, + credential_id=credential.credential_id, + user_uuid=credential.user_uuid.bytes, + aaguid=credential.aaguid.bytes, + public_key=credential.public_key, + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, + ) + session.add(credential_model) + + @with_session + async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None: """Delete a credential by its ID.""" stmt = ( delete(CredentialModel) .where(CredentialModel.uuid == uuid.bytes) .where(CredentialModel.user_uuid == user_uuid.bytes) ) - await self.session.execute(stmt) - await self.session.commit() + await session.execute(stmt) + @with_session async def create_session( self, + session, user_uuid: UUID, key: bytes, expires: datetime, @@ -235,14 +293,14 @@ class DB: expires=expires, info=info, ) - self.session.add(session_model) - await self.session.flush() + session.add(session_model) return key - async def get_session(self, key: bytes) -> Session | None: + @with_session + async def get_session(self, session, key: bytes) -> Session | None: """Get session by 16-byte key.""" stmt = select(SessionModel).where(SessionModel.key == key) - result = await self.session.execute(stmt) + result = await session.execute(stmt) session_model = result.scalar_one_or_none() if session_model: @@ -253,113 +311,29 @@ class DB: if session_model.credential_uuid else None, expires=session_model.expires, - info=session_model.info, + info=session_model.info or {}, ) return None - async def delete_session(self, key: bytes) -> None: + @with_session + async def delete_session(self, session, key: bytes) -> None: """Delete a session by 16-byte key.""" - await self.session.execute(delete(SessionModel).where(SessionModel.key == key)) + await session.execute(delete(SessionModel).where(SessionModel.key == key)) - async def update_session(self, key: bytes, expires: datetime, info: dict) -> None: + @with_session + async def update_session( + self, session, key: bytes, expires: datetime, info: dict + ) -> None: """Update session expiration time and/or info.""" - await self.session.execute( + await session.execute( update(SessionModel) .where(SessionModel.key == key) .values(expires=expires, info=info) ) - async def cleanup_expired_sessions(self) -> None: + @with_session + async def cleanup_expired_sessions(self, session) -> None: """Remove expired sessions.""" current_time = datetime.now() stmt = delete(SessionModel).where(SessionModel.expires < current_time) - await self.session.execute(stmt) - - -# Standalone functions that handle database connections internally -async def init_database() -> None: - """Initialize database tables.""" - async with connect() as db: - await db.init_db() - - -async def create_user_and_credential(user: User, credential: Credential) -> None: - """Create a new user and their first credential in a single transaction.""" - async with connect() as db: - await db.session.begin() - # Set visits to 1 for the new user since they're creating their first session - user.visits = 1 - await db.create_user(user) - await db.create_credential(credential) - - -async def get_user_by_uuid(user_uuid: UUID) -> User: - """Get user record by WebAuthn user UUID.""" - async with connect() as db: - return await db.get_user_by_user_uuid(user_uuid) - - -async def create_credential_for_user(credential: Credential) -> None: - """Store a credential for an existing user.""" - async with connect() as db: - await db.create_credential(credential) - - -async def get_credential_by_id(credential_id: bytes) -> Credential: - """Get credential by credential ID.""" - async with connect() as db: - return await db.get_credential_by_id(credential_id) - - -async def get_user_credentials(user_uuid: UUID) -> list[bytes]: - """Get all credential IDs for a user.""" - async with connect() as db: - return await db.get_credentials_by_user_uuid(user_uuid) - - -async def login_user(user_uuid: UUID, credential: Credential) -> None: - """Update the last_seen timestamp for a user and the credential record used for logging in.""" - async with connect() as db: - await db.login(user_uuid, credential) - - -async def delete_credential(uuid: UUID, user_uuid: UUID) -> None: - """Delete a credential by its ID.""" - async with connect() as db: - await db.delete_credential(uuid, user_uuid) - - -async def create_session( - user_uuid: UUID, - key: bytes, - expires: datetime, - info: dict, - credential_uuid: UUID | None = None, -) -> bytes: - """Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential.""" - async with connect() as db: - return await db.create_session(user_uuid, key, expires, info, credential_uuid) - - -async def get_session(key: bytes) -> Session | None: - """Get session by 16-byte key.""" - async with connect() as db: - return await db.get_session(key) - - -async def delete_session(key: bytes) -> None: - """Delete a session by 16-byte key.""" - async with connect() as db: - await db.delete_session(key) - - -async def update_session(key: bytes, expires: datetime, info: dict) -> None: - """Update session expiration time and/or info.""" - async with connect() as db: - await db.update_session(key, expires, info) - - -async def cleanup_expired_sessions() -> None: - """Remove expired sessions.""" - async with connect() as db: - await db.cleanup_expired_sessions() + await session.execute(stmt)