DB refactor (currently broken)

This commit is contained in:
Leo Vasanko 2025-08-05 06:41:07 -06:00
parent a5af644404
commit 00693c56fa

View File

@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system.
"""
from contextlib import asynccontextmanager
from datetime import datetime
from functools import wraps
from uuid import UUID
from sqlalchemy import (
@ -20,7 +20,7 @@ from sqlalchemy import (
update,
)
from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from . import Credential, Session, User
@ -28,6 +28,21 @@ from . import Credential, Session, User
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
def with_session(func):
"""Decorator that provides a database session with transaction to the method."""
@wraps(func)
async def wrapper(self, *args, **kwargs):
async with self.async_session_factory() as session:
async with session.begin():
result = await func(self, session, *args, **kwargs)
await session.flush()
return result
await session.commit()
return wrapper
# SQLAlchemy Models
class Base(DeclarativeBase):
pass
@ -85,32 +100,26 @@ class SessionModel(Base):
user: Mapped["UserModel"] = relationship("UserModel")
# Global engine and session factory
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def connect():
"""Context manager for database connections."""
async with async_session_factory() as session:
yield DB(session)
await session.commit()
class DB:
def __init__(self, session: AsyncSession):
self.session = session
"""Database class that handles its own connections."""
def __init__(self, db_path: str = DB_PATH):
"""Initialize with database path."""
self.engine = create_async_engine(db_path, echo=False)
self.async_session_factory = async_sessionmaker(
self.engine, expire_on_commit=False
)
async def init_db(self) -> None:
"""Initialize database tables."""
async with engine.begin() as conn:
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
@with_session
async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
result = await self.session.execute(stmt)
result = await session.execute(stmt)
user_model = result.scalar_one_or_none()
if user_model:
@ -123,7 +132,8 @@ class DB:
)
raise ValueError("User not found")
async def create_user(self, user: User) -> None:
@with_session
async def create_user(self, session, user: User) -> None:
"""Create a new user."""
user_model = UserModel(
user_uuid=user.user_uuid.bytes,
@ -132,10 +142,10 @@ class DB:
last_seen=user.last_seen,
visits=user.visits,
)
self.session.add(user_model)
await self.session.flush()
session.add(user_model)
async def create_credential(self, credential: Credential) -> None:
@with_session
async def create_credential(self, session, credential: Credential) -> None:
"""Store a credential for a user."""
credential_model = CredentialModel(
uuid=credential.uuid.bytes,
@ -148,15 +158,15 @@ class DB:
last_used=credential.last_used,
last_verified=credential.last_verified,
)
self.session.add(credential_model)
await self.session.flush()
session.add(credential_model)
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
@with_session
async def get_credential_by_id(self, session, credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
stmt = select(CredentialModel).where(
CredentialModel.credential_id == credential_id
)
result = await self.session.execute(stmt)
result = await session.execute(stmt)
credential_model = result.scalar_one_or_none()
if not credential_model:
@ -173,15 +183,19 @@ class DB:
last_verified=credential_model.last_verified,
)
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
@with_session
async def get_credentials_by_user_uuid(
self, session, user_uuid: UUID
) -> list[bytes]:
"""Get all credential IDs for a user."""
stmt = select(CredentialModel.credential_id).where(
CredentialModel.user_uuid == user_uuid.bytes
)
result = await self.session.execute(stmt)
result = await session.execute(stmt)
return [row[0] for row in result.fetchall()]
async def update_credential(self, credential: Credential) -> None:
@with_session
async def update_credential(self, session, credential: Credential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
stmt = (
update(CredentialModel)
@ -193,34 +207,78 @@ class DB:
last_verified=credential.last_verified,
)
)
await self.session.execute(stmt)
await session.execute(stmt)
async def login(self, user_uuid: UUID, credential: Credential) -> None:
@with_session
async def login(self, session, user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with self.session.begin():
# Update credential
await self.update_credential(credential)
# Update user's last_seen and increment visits
stmt = (
update(UserModel)
.where(UserModel.user_uuid == user_uuid.bytes)
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
# Update credential
stmt = (
update(CredentialModel)
.where(CredentialModel.credential_id == credential.credential_id)
.values(
sign_count=credential.sign_count,
created_at=credential.created_at,
last_used=credential.last_used,
last_verified=credential.last_verified,
)
await self.session.execute(stmt)
)
await session.execute(stmt)
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
# Update user's last_seen and increment visits
stmt = (
update(UserModel)
.where(UserModel.user_uuid == user_uuid.bytes)
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
)
await session.execute(stmt)
@with_session
async def create_user_and_credential(
self, session, user: User, credential: Credential
) -> None:
"""Create a new user and their first credential in a single transaction."""
# Set visits to 1 for the new user since they're creating their first session
user.visits = 1
# Create user
user_model = UserModel(
user_uuid=user.user_uuid.bytes,
user_name=user.user_name,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
visits=user.visits,
)
session.add(user_model)
# Create credential
credential_model = CredentialModel(
uuid=credential.uuid.bytes,
credential_id=credential.credential_id,
user_uuid=credential.user_uuid.bytes,
aaguid=credential.aaguid.bytes,
public_key=credential.public_key,
sign_count=credential.sign_count,
created_at=credential.created_at,
last_used=credential.last_used,
last_verified=credential.last_verified,
)
session.add(credential_model)
@with_session
async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID."""
stmt = (
delete(CredentialModel)
.where(CredentialModel.uuid == uuid.bytes)
.where(CredentialModel.user_uuid == user_uuid.bytes)
)
await self.session.execute(stmt)
await self.session.commit()
await session.execute(stmt)
@with_session
async def create_session(
self,
session,
user_uuid: UUID,
key: bytes,
expires: datetime,
@ -235,14 +293,14 @@ class DB:
expires=expires,
info=info,
)
self.session.add(session_model)
await self.session.flush()
session.add(session_model)
return key
async def get_session(self, key: bytes) -> Session | None:
@with_session
async def get_session(self, session, key: bytes) -> Session | None:
"""Get session by 16-byte key."""
stmt = select(SessionModel).where(SessionModel.key == key)
result = await self.session.execute(stmt)
result = await session.execute(stmt)
session_model = result.scalar_one_or_none()
if session_model:
@ -253,113 +311,29 @@ class DB:
if session_model.credential_uuid
else None,
expires=session_model.expires,
info=session_model.info,
info=session_model.info or {},
)
return None
async def delete_session(self, key: bytes) -> None:
@with_session
async def delete_session(self, session, key: bytes) -> None:
"""Delete a session by 16-byte key."""
await self.session.execute(delete(SessionModel).where(SessionModel.key == key))
await session.execute(delete(SessionModel).where(SessionModel.key == key))
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
@with_session
async def update_session(
self, session, key: bytes, expires: datetime, info: dict
) -> None:
"""Update session expiration time and/or info."""
await self.session.execute(
await session.execute(
update(SessionModel)
.where(SessionModel.key == key)
.values(expires=expires, info=info)
)
async def cleanup_expired_sessions(self) -> None:
@with_session
async def cleanup_expired_sessions(self, session) -> None:
"""Remove expired sessions."""
current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await self.session.execute(stmt)
# Standalone functions that handle database connections internally
async def init_database() -> None:
"""Initialize database tables."""
async with connect() as db:
await db.init_db()
async def create_user_and_credential(user: User, credential: Credential) -> None:
"""Create a new user and their first credential in a single transaction."""
async with connect() as db:
await db.session.begin()
# Set visits to 1 for the new user since they're creating their first session
user.visits = 1
await db.create_user(user)
await db.create_credential(credential)
async def get_user_by_uuid(user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
async with connect() as db:
return await db.get_user_by_user_uuid(user_uuid)
async def create_credential_for_user(credential: Credential) -> None:
"""Store a credential for an existing user."""
async with connect() as db:
await db.create_credential(credential)
async def get_credential_by_id(credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
async with connect() as db:
return await db.get_credential_by_id(credential_id)
async def get_user_credentials(user_uuid: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
async with connect() as db:
return await db.get_credentials_by_user_uuid(user_uuid)
async def login_user(user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with connect() as db:
await db.login(user_uuid, credential)
async def delete_credential(uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID."""
async with connect() as db:
await db.delete_credential(uuid, user_uuid)
async def create_session(
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
) -> bytes:
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
async with connect() as db:
return await db.create_session(user_uuid, key, expires, info, credential_uuid)
async def get_session(key: bytes) -> Session | None:
"""Get session by 16-byte key."""
async with connect() as db:
return await db.get_session(key)
async def delete_session(key: bytes) -> None:
"""Delete a session by 16-byte key."""
async with connect() as db:
await db.delete_session(key)
async def update_session(key: bytes, expires: datetime, info: dict) -> None:
"""Update session expiration time and/or info."""
async with connect() as db:
await db.update_session(key, expires, info)
async def cleanup_expired_sessions() -> None:
"""Remove expired sessions."""
async with connect() as db:
await db.cleanup_expired_sessions()
await session.execute(stmt)