DB refactor (currently broken)

This commit is contained in:
Leo Vasanko 2025-08-05 06:41:07 -06:00
parent a5af644404
commit 00693c56fa

View File

@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system. for managing users and credentials in a WebAuthn authentication system.
""" """
from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from functools import wraps
from uuid import UUID from uuid import UUID
from sqlalchemy import ( from sqlalchemy import (
@ -20,7 +20,7 @@ from sqlalchemy import (
update, update,
) )
from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from . import Credential, Session, User from . import Credential, Session, User
@ -28,6 +28,21 @@ from . import Credential, Session, User
DB_PATH = "sqlite+aiosqlite:///webauthn.db" DB_PATH = "sqlite+aiosqlite:///webauthn.db"
def with_session(func):
"""Decorator that provides a database session with transaction to the method."""
@wraps(func)
async def wrapper(self, *args, **kwargs):
async with self.async_session_factory() as session:
async with session.begin():
result = await func(self, session, *args, **kwargs)
await session.flush()
return result
await session.commit()
return wrapper
# SQLAlchemy Models # SQLAlchemy Models
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass pass
@ -85,32 +100,26 @@ class SessionModel(Base):
user: Mapped["UserModel"] = relationship("UserModel") user: Mapped["UserModel"] = relationship("UserModel")
# Global engine and session factory
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def connect():
"""Context manager for database connections."""
async with async_session_factory() as session:
yield DB(session)
await session.commit()
class DB: class DB:
def __init__(self, session: AsyncSession): """Database class that handles its own connections."""
self.session = session
def __init__(self, db_path: str = DB_PATH):
"""Initialize with database path."""
self.engine = create_async_engine(db_path, echo=False)
self.async_session_factory = async_sessionmaker(
self.engine, expire_on_commit=False
)
async def init_db(self) -> None: async def init_db(self) -> None:
"""Initialize database tables.""" """Initialize database tables."""
async with engine.begin() as conn: async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User: @with_session
async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID.""" """Get user record by WebAuthn user UUID."""
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
result = await self.session.execute(stmt) result = await session.execute(stmt)
user_model = result.scalar_one_or_none() user_model = result.scalar_one_or_none()
if user_model: if user_model:
@ -123,7 +132,8 @@ class DB:
) )
raise ValueError("User not found") raise ValueError("User not found")
async def create_user(self, user: User) -> None: @with_session
async def create_user(self, session, user: User) -> None:
"""Create a new user.""" """Create a new user."""
user_model = UserModel( user_model = UserModel(
user_uuid=user.user_uuid.bytes, user_uuid=user.user_uuid.bytes,
@ -132,10 +142,10 @@ class DB:
last_seen=user.last_seen, last_seen=user.last_seen,
visits=user.visits, visits=user.visits,
) )
self.session.add(user_model) session.add(user_model)
await self.session.flush()
async def create_credential(self, credential: Credential) -> None: @with_session
async def create_credential(self, session, credential: Credential) -> None:
"""Store a credential for a user.""" """Store a credential for a user."""
credential_model = CredentialModel( credential_model = CredentialModel(
uuid=credential.uuid.bytes, uuid=credential.uuid.bytes,
@ -148,15 +158,15 @@ class DB:
last_used=credential.last_used, last_used=credential.last_used,
last_verified=credential.last_verified, last_verified=credential.last_verified,
) )
self.session.add(credential_model) session.add(credential_model)
await self.session.flush()
async def get_credential_by_id(self, credential_id: bytes) -> Credential: @with_session
async def get_credential_by_id(self, session, credential_id: bytes) -> Credential:
"""Get credential by credential ID.""" """Get credential by credential ID."""
stmt = select(CredentialModel).where( stmt = select(CredentialModel).where(
CredentialModel.credential_id == credential_id CredentialModel.credential_id == credential_id
) )
result = await self.session.execute(stmt) result = await session.execute(stmt)
credential_model = result.scalar_one_or_none() credential_model = result.scalar_one_or_none()
if not credential_model: if not credential_model:
@ -173,15 +183,19 @@ class DB:
last_verified=credential_model.last_verified, last_verified=credential_model.last_verified,
) )
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]: @with_session
async def get_credentials_by_user_uuid(
self, session, user_uuid: UUID
) -> list[bytes]:
"""Get all credential IDs for a user.""" """Get all credential IDs for a user."""
stmt = select(CredentialModel.credential_id).where( stmt = select(CredentialModel.credential_id).where(
CredentialModel.user_uuid == user_uuid.bytes CredentialModel.user_uuid == user_uuid.bytes
) )
result = await self.session.execute(stmt) result = await session.execute(stmt)
return [row[0] for row in result.fetchall()] return [row[0] for row in result.fetchall()]
async def update_credential(self, credential: Credential) -> None: @with_session
async def update_credential(self, session, credential: Credential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential.""" """Update the sign count, created_at, last_used, and last_verified for a credential."""
stmt = ( stmt = (
update(CredentialModel) update(CredentialModel)
@ -193,34 +207,78 @@ class DB:
last_verified=credential.last_verified, last_verified=credential.last_verified,
) )
) )
await self.session.execute(stmt) await session.execute(stmt)
async def login(self, user_uuid: UUID, credential: Credential) -> None: @with_session
async def login(self, session, user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in.""" """Update the last_seen timestamp for a user and the credential record used for logging in."""
async with self.session.begin(): # Update credential
# Update credential stmt = (
await self.update_credential(credential) update(CredentialModel)
.where(CredentialModel.credential_id == credential.credential_id)
# Update user's last_seen and increment visits .values(
stmt = ( sign_count=credential.sign_count,
update(UserModel) created_at=credential.created_at,
.where(UserModel.user_uuid == user_uuid.bytes) last_used=credential.last_used,
.values(last_seen=credential.last_used, visits=UserModel.visits + 1) last_verified=credential.last_verified,
) )
await self.session.execute(stmt) )
await session.execute(stmt)
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None: # Update user's last_seen and increment visits
stmt = (
update(UserModel)
.where(UserModel.user_uuid == user_uuid.bytes)
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
)
await session.execute(stmt)
@with_session
async def create_user_and_credential(
self, session, user: User, credential: Credential
) -> None:
"""Create a new user and their first credential in a single transaction."""
# Set visits to 1 for the new user since they're creating their first session
user.visits = 1
# Create user
user_model = UserModel(
user_uuid=user.user_uuid.bytes,
user_name=user.user_name,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
visits=user.visits,
)
session.add(user_model)
# Create credential
credential_model = CredentialModel(
uuid=credential.uuid.bytes,
credential_id=credential.credential_id,
user_uuid=credential.user_uuid.bytes,
aaguid=credential.aaguid.bytes,
public_key=credential.public_key,
sign_count=credential.sign_count,
created_at=credential.created_at,
last_used=credential.last_used,
last_verified=credential.last_verified,
)
session.add(credential_model)
@with_session
async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID.""" """Delete a credential by its ID."""
stmt = ( stmt = (
delete(CredentialModel) delete(CredentialModel)
.where(CredentialModel.uuid == uuid.bytes) .where(CredentialModel.uuid == uuid.bytes)
.where(CredentialModel.user_uuid == user_uuid.bytes) .where(CredentialModel.user_uuid == user_uuid.bytes)
) )
await self.session.execute(stmt) await session.execute(stmt)
await self.session.commit()
@with_session
async def create_session( async def create_session(
self, self,
session,
user_uuid: UUID, user_uuid: UUID,
key: bytes, key: bytes,
expires: datetime, expires: datetime,
@ -235,14 +293,14 @@ class DB:
expires=expires, expires=expires,
info=info, info=info,
) )
self.session.add(session_model) session.add(session_model)
await self.session.flush()
return key return key
async def get_session(self, key: bytes) -> Session | None: @with_session
async def get_session(self, session, key: bytes) -> Session | None:
"""Get session by 16-byte key.""" """Get session by 16-byte key."""
stmt = select(SessionModel).where(SessionModel.key == key) stmt = select(SessionModel).where(SessionModel.key == key)
result = await self.session.execute(stmt) result = await session.execute(stmt)
session_model = result.scalar_one_or_none() session_model = result.scalar_one_or_none()
if session_model: if session_model:
@ -253,113 +311,29 @@ class DB:
if session_model.credential_uuid if session_model.credential_uuid
else None, else None,
expires=session_model.expires, expires=session_model.expires,
info=session_model.info, info=session_model.info or {},
) )
return None return None
async def delete_session(self, key: bytes) -> None: @with_session
async def delete_session(self, session, key: bytes) -> None:
"""Delete a session by 16-byte key.""" """Delete a session by 16-byte key."""
await self.session.execute(delete(SessionModel).where(SessionModel.key == key)) await session.execute(delete(SessionModel).where(SessionModel.key == key))
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None: @with_session
async def update_session(
self, session, key: bytes, expires: datetime, info: dict
) -> None:
"""Update session expiration time and/or info.""" """Update session expiration time and/or info."""
await self.session.execute( await session.execute(
update(SessionModel) update(SessionModel)
.where(SessionModel.key == key) .where(SessionModel.key == key)
.values(expires=expires, info=info) .values(expires=expires, info=info)
) )
async def cleanup_expired_sessions(self) -> None: @with_session
async def cleanup_expired_sessions(self, session) -> None:
"""Remove expired sessions.""" """Remove expired sessions."""
current_time = datetime.now() current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time) stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await self.session.execute(stmt) await session.execute(stmt)
# Standalone functions that handle database connections internally
async def init_database() -> None:
"""Initialize database tables."""
async with connect() as db:
await db.init_db()
async def create_user_and_credential(user: User, credential: Credential) -> None:
"""Create a new user and their first credential in a single transaction."""
async with connect() as db:
await db.session.begin()
# Set visits to 1 for the new user since they're creating their first session
user.visits = 1
await db.create_user(user)
await db.create_credential(credential)
async def get_user_by_uuid(user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
async with connect() as db:
return await db.get_user_by_user_uuid(user_uuid)
async def create_credential_for_user(credential: Credential) -> None:
"""Store a credential for an existing user."""
async with connect() as db:
await db.create_credential(credential)
async def get_credential_by_id(credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
async with connect() as db:
return await db.get_credential_by_id(credential_id)
async def get_user_credentials(user_uuid: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
async with connect() as db:
return await db.get_credentials_by_user_uuid(user_uuid)
async def login_user(user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with connect() as db:
await db.login(user_uuid, credential)
async def delete_credential(uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID."""
async with connect() as db:
await db.delete_credential(uuid, user_uuid)
async def create_session(
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
) -> bytes:
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
async with connect() as db:
return await db.create_session(user_uuid, key, expires, info, credential_uuid)
async def get_session(key: bytes) -> Session | None:
"""Get session by 16-byte key."""
async with connect() as db:
return await db.get_session(key)
async def delete_session(key: bytes) -> None:
"""Delete a session by 16-byte key."""
async with connect() as db:
await db.delete_session(key)
async def update_session(key: bytes, expires: datetime, info: dict) -> None:
"""Update session expiration time and/or info."""
async with connect() as db:
await db.update_session(key, expires, info)
async def cleanup_expired_sessions() -> None:
"""Remove expired sessions."""
async with connect() as db:
await db.cleanup_expired_sessions()