Database cleanup, base class, separated from FastAPI app.
This commit is contained in:
@@ -5,7 +5,8 @@ This module provides dataclasses and database abstractions for managing
|
||||
users, credentials, and sessions in a WebAuthn authentication system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
@@ -47,4 +48,108 @@ class Session:
|
||||
credential_uuid: UUID | None = None
|
||||
|
||||
|
||||
__all__ = ["User", "Credential", "Session"]
|
||||
class DatabaseInterface(ABC):
|
||||
"""Abstract base class defining the database interface.
|
||||
|
||||
This class defines the public API that database implementations should provide.
|
||||
Implementations may use decorators like @with_session that modify method signatures
|
||||
at runtime, so this interface focuses on the logical operations rather than
|
||||
exact parameter matching.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def init_db(self) -> None:
|
||||
"""Initialize database tables."""
|
||||
pass
|
||||
|
||||
# User operations
|
||||
@abstractmethod
|
||||
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
||||
"""Get user record by WebAuthn user UUID."""
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, user: User) -> None:
|
||||
"""Create a new user."""
|
||||
|
||||
# Credential operations
|
||||
@abstractmethod
|
||||
async def create_credential(self, credential: Credential) -> None:
|
||||
"""Store a credential for a user."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
|
||||
"""Get credential by credential ID."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
|
||||
"""Get all credential IDs for a user."""
|
||||
|
||||
@abstractmethod
|
||||
async def update_credential(self, credential: Credential) -> None:
|
||||
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
|
||||
"""Delete a specific credential for a user."""
|
||||
|
||||
# Session operations
|
||||
@abstractmethod
|
||||
async def create_session(
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
key: bytes,
|
||||
expires: datetime,
|
||||
info: dict,
|
||||
credential_uuid: UUID | None = None,
|
||||
) -> None:
|
||||
"""Create a new session."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_session(self, key: bytes) -> Session | None:
|
||||
"""Get session by key."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_session(self, key: bytes) -> None:
|
||||
"""Delete session by key."""
|
||||
|
||||
@abstractmethod
|
||||
async def update_session(
|
||||
self, key: bytes, expires: datetime, info: dict
|
||||
) -> Session | None:
|
||||
"""Update session expiry and info."""
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Called periodically to clean up expired records."""
|
||||
|
||||
# Combined operations
|
||||
@abstractmethod
|
||||
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
||||
"""Update user and credential timestamps after successful login."""
|
||||
|
||||
@abstractmethod
|
||||
async def create_user_and_credential(
|
||||
self, user: User, credential: Credential
|
||||
) -> None:
|
||||
"""Create a new user and their first credential in a transaction."""
|
||||
|
||||
|
||||
# Global DB instance
|
||||
database_instance: DatabaseInterface | None = None
|
||||
|
||||
|
||||
def database() -> DatabaseInterface:
|
||||
"""Get the global database instance."""
|
||||
if database_instance is None:
|
||||
raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.")
|
||||
return database_instance
|
||||
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"Credential",
|
||||
"Session",
|
||||
"DatabaseInterface",
|
||||
"database_instance",
|
||||
"database",
|
||||
]
|
||||
|
||||
@@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
|
||||
for managing users and credentials in a WebAuthn authentication system.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import (
|
||||
@@ -23,24 +23,15 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from . import Credential, Session, User
|
||||
from . import Credential, DatabaseInterface, Session, User
|
||||
|
||||
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
||||
|
||||
|
||||
def with_session(func):
|
||||
"""Decorator that provides a database session with transaction to the method."""
|
||||
def init(*args, **kwargs):
|
||||
from .. import db
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
async with self.async_session_factory() as session:
|
||||
async with session.begin():
|
||||
result = await func(self, session, *args, **kwargs)
|
||||
await session.flush()
|
||||
return result
|
||||
await session.commit()
|
||||
|
||||
return wrapper
|
||||
db.database_instance = DB()
|
||||
|
||||
|
||||
# SQLAlchemy Models
|
||||
@@ -100,7 +91,7 @@ class SessionModel(Base):
|
||||
user: Mapped["UserModel"] = relationship("UserModel")
|
||||
|
||||
|
||||
class DB:
|
||||
class DB(DatabaseInterface):
|
||||
"""Database class that handles its own connections."""
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
@@ -110,230 +101,219 @@ class DB:
|
||||
self.engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self):
|
||||
"""Async context manager that provides a database session with transaction."""
|
||||
async with self.async_session_factory() as session:
|
||||
async with session.begin():
|
||||
yield session
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
|
||||
async def init_db(self) -> None:
|
||||
"""Initialize database tables."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
@with_session
|
||||
async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User:
|
||||
"""Get user record by WebAuthn user UUID."""
|
||||
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
||||
result = await session.execute(stmt)
|
||||
user_model = result.scalar_one_or_none()
|
||||
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
||||
async with self.session() as session:
|
||||
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
||||
result = await session.execute(stmt)
|
||||
user_model = result.scalar_one_or_none()
|
||||
|
||||
if user_model:
|
||||
return User(
|
||||
user_uuid=UUID(bytes=user_model.user_uuid),
|
||||
user_name=user_model.user_name,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
visits=user_model.visits,
|
||||
if user_model:
|
||||
return User(
|
||||
user_uuid=UUID(bytes=user_model.user_uuid),
|
||||
user_name=user_model.user_name,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
visits=user_model.visits,
|
||||
)
|
||||
raise ValueError("User not found")
|
||||
|
||||
async def create_user(self, user: User) -> None:
|
||||
async with self.session() as session:
|
||||
user_model = UserModel(
|
||||
user_uuid=user.user_uuid.bytes,
|
||||
user_name=user.user_name,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
visits=user.visits,
|
||||
)
|
||||
raise ValueError("User not found")
|
||||
session.add(user_model)
|
||||
|
||||
@with_session
|
||||
async def create_user(self, session, user: User) -> None:
|
||||
"""Create a new user."""
|
||||
user_model = UserModel(
|
||||
user_uuid=user.user_uuid.bytes,
|
||||
user_name=user.user_name,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
visits=user.visits,
|
||||
)
|
||||
session.add(user_model)
|
||||
|
||||
@with_session
|
||||
async def create_credential(self, session, credential: Credential) -> None:
|
||||
"""Store a credential for a user."""
|
||||
credential_model = CredentialModel(
|
||||
uuid=credential.uuid.bytes,
|
||||
credential_id=credential.credential_id,
|
||||
user_uuid=credential.user_uuid.bytes,
|
||||
aaguid=credential.aaguid.bytes,
|
||||
public_key=credential.public_key,
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
session.add(credential_model)
|
||||
|
||||
@with_session
|
||||
async def get_credential_by_id(self, session, credential_id: bytes) -> Credential:
|
||||
"""Get credential by credential ID."""
|
||||
stmt = select(CredentialModel).where(
|
||||
CredentialModel.credential_id == credential_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
credential_model = result.scalar_one_or_none()
|
||||
|
||||
if not credential_model:
|
||||
raise ValueError("Credential not registered")
|
||||
return Credential(
|
||||
uuid=UUID(bytes=credential_model.uuid),
|
||||
credential_id=credential_model.credential_id,
|
||||
user_uuid=UUID(bytes=credential_model.user_uuid),
|
||||
aaguid=UUID(bytes=credential_model.aaguid),
|
||||
public_key=credential_model.public_key,
|
||||
sign_count=credential_model.sign_count,
|
||||
created_at=credential_model.created_at,
|
||||
last_used=credential_model.last_used,
|
||||
last_verified=credential_model.last_verified,
|
||||
)
|
||||
|
||||
@with_session
|
||||
async def get_credentials_by_user_uuid(
|
||||
self, session, user_uuid: UUID
|
||||
) -> list[bytes]:
|
||||
"""Get all credential IDs for a user."""
|
||||
stmt = select(CredentialModel.credential_id).where(
|
||||
CredentialModel.user_uuid == user_uuid.bytes
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
@with_session
|
||||
async def update_credential(self, session, credential: Credential) -> None:
|
||||
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||
stmt = (
|
||||
update(CredentialModel)
|
||||
.where(CredentialModel.credential_id == credential.credential_id)
|
||||
.values(
|
||||
async def create_credential(self, credential: Credential) -> None:
|
||||
async with self.session() as session:
|
||||
credential_model = CredentialModel(
|
||||
uuid=credential.uuid.bytes,
|
||||
credential_id=credential.credential_id,
|
||||
user_uuid=credential.user_uuid.bytes,
|
||||
aaguid=credential.aaguid.bytes,
|
||||
public_key=credential.public_key,
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
session.add(credential_model)
|
||||
|
||||
@with_session
|
||||
async def login(self, session, user_uuid: UUID, credential: Credential) -> None:
|
||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||
# Update credential
|
||||
stmt = (
|
||||
update(CredentialModel)
|
||||
.where(CredentialModel.credential_id == credential.credential_id)
|
||||
.values(
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
|
||||
async with self.session() as session:
|
||||
stmt = select(CredentialModel).where(
|
||||
CredentialModel.credential_id == credential_id
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
credential_model = result.scalar_one_or_none()
|
||||
|
||||
# Update user's last_seen and increment visits
|
||||
stmt = (
|
||||
update(UserModel)
|
||||
.where(UserModel.user_uuid == user_uuid.bytes)
|
||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
if not credential_model:
|
||||
raise ValueError("Credential not registered")
|
||||
return Credential(
|
||||
uuid=UUID(bytes=credential_model.uuid),
|
||||
credential_id=credential_model.credential_id,
|
||||
user_uuid=UUID(bytes=credential_model.user_uuid),
|
||||
aaguid=UUID(bytes=credential_model.aaguid),
|
||||
public_key=credential_model.public_key,
|
||||
sign_count=credential_model.sign_count,
|
||||
created_at=credential_model.created_at,
|
||||
last_used=credential_model.last_used,
|
||||
last_verified=credential_model.last_verified,
|
||||
)
|
||||
|
||||
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
|
||||
async with self.session() as session:
|
||||
stmt = select(CredentialModel.credential_id).where(
|
||||
CredentialModel.user_uuid == user_uuid.bytes
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
async def update_credential(self, credential: Credential) -> None:
|
||||
async with self.session() as session:
|
||||
stmt = (
|
||||
update(CredentialModel)
|
||||
.where(CredentialModel.credential_id == credential.credential_id)
|
||||
.values(
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
||||
async with self.session() as session:
|
||||
# Update credential
|
||||
stmt = (
|
||||
update(CredentialModel)
|
||||
.where(CredentialModel.credential_id == credential.credential_id)
|
||||
.values(
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
# Update user's last_seen and increment visits
|
||||
stmt = (
|
||||
update(UserModel)
|
||||
.where(UserModel.user_uuid == user_uuid.bytes)
|
||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
@with_session
|
||||
async def create_user_and_credential(
|
||||
self, session, user: User, credential: Credential
|
||||
self, user: User, credential: Credential
|
||||
) -> None:
|
||||
"""Create a new user and their first credential in a single transaction."""
|
||||
# Set visits to 1 for the new user since they're creating their first session
|
||||
user.visits = 1
|
||||
async with self.session() as session:
|
||||
# Set visits to 1 for the new user since they're creating their first session
|
||||
user.visits = 1
|
||||
|
||||
# Create user
|
||||
user_model = UserModel(
|
||||
user_uuid=user.user_uuid.bytes,
|
||||
user_name=user.user_name,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
visits=user.visits,
|
||||
)
|
||||
session.add(user_model)
|
||||
# Create user
|
||||
user_model = UserModel(
|
||||
user_uuid=user.user_uuid.bytes,
|
||||
user_name=user.user_name,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
visits=user.visits,
|
||||
)
|
||||
session.add(user_model)
|
||||
|
||||
# Create credential
|
||||
credential_model = CredentialModel(
|
||||
uuid=credential.uuid.bytes,
|
||||
credential_id=credential.credential_id,
|
||||
user_uuid=credential.user_uuid.bytes,
|
||||
aaguid=credential.aaguid.bytes,
|
||||
public_key=credential.public_key,
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
session.add(credential_model)
|
||||
# Create credential
|
||||
credential_model = CredentialModel(
|
||||
uuid=credential.uuid.bytes,
|
||||
credential_id=credential.credential_id,
|
||||
user_uuid=credential.user_uuid.bytes,
|
||||
aaguid=credential.aaguid.bytes,
|
||||
public_key=credential.public_key,
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
session.add(credential_model)
|
||||
|
||||
@with_session
|
||||
async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None:
|
||||
"""Delete a credential by its ID."""
|
||||
stmt = (
|
||||
delete(CredentialModel)
|
||||
.where(CredentialModel.uuid == uuid.bytes)
|
||||
.where(CredentialModel.user_uuid == user_uuid.bytes)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
|
||||
async with self.session() as session:
|
||||
stmt = (
|
||||
delete(CredentialModel)
|
||||
.where(CredentialModel.uuid == uuid.bytes)
|
||||
.where(CredentialModel.user_uuid == user_uuid.bytes)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
@with_session
|
||||
async def create_session(
|
||||
self,
|
||||
session,
|
||||
user_uuid: UUID,
|
||||
key: bytes,
|
||||
expires: datetime,
|
||||
info: dict,
|
||||
credential_uuid: UUID | None = None,
|
||||
) -> bytes:
|
||||
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
|
||||
session_model = SessionModel(
|
||||
key=key,
|
||||
user_uuid=user_uuid.bytes,
|
||||
credential_uuid=credential_uuid.bytes if credential_uuid else None,
|
||||
expires=expires,
|
||||
info=info,
|
||||
)
|
||||
session.add(session_model)
|
||||
return key
|
||||
|
||||
@with_session
|
||||
async def get_session(self, session, key: bytes) -> Session | None:
|
||||
"""Get session by 16-byte key."""
|
||||
stmt = select(SessionModel).where(SessionModel.key == key)
|
||||
result = await session.execute(stmt)
|
||||
session_model = result.scalar_one_or_none()
|
||||
|
||||
if session_model:
|
||||
return Session(
|
||||
key=session_model.key,
|
||||
user_uuid=UUID(bytes=session_model.user_uuid),
|
||||
credential_uuid=UUID(bytes=session_model.credential_uuid)
|
||||
if session_model.credential_uuid
|
||||
else None,
|
||||
expires=session_model.expires,
|
||||
info=session_model.info or {},
|
||||
)
|
||||
return None
|
||||
|
||||
@with_session
|
||||
async def delete_session(self, session, key: bytes) -> None:
|
||||
"""Delete a session by 16-byte key."""
|
||||
await session.execute(delete(SessionModel).where(SessionModel.key == key))
|
||||
|
||||
@with_session
|
||||
async def update_session(
|
||||
self, session, key: bytes, expires: datetime, info: dict
|
||||
) -> None:
|
||||
"""Update session expiration time and/or info."""
|
||||
await session.execute(
|
||||
update(SessionModel)
|
||||
.where(SessionModel.key == key)
|
||||
.values(expires=expires, info=info)
|
||||
)
|
||||
async with self.session() as session:
|
||||
session_model = SessionModel(
|
||||
key=key,
|
||||
user_uuid=user_uuid.bytes,
|
||||
credential_uuid=credential_uuid.bytes if credential_uuid else None,
|
||||
expires=expires,
|
||||
info=info,
|
||||
)
|
||||
session.add(session_model)
|
||||
|
||||
@with_session
|
||||
async def cleanup_expired_sessions(self, session) -> None:
|
||||
"""Remove expired sessions."""
|
||||
current_time = datetime.now()
|
||||
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
||||
await session.execute(stmt)
|
||||
async def get_session(self, key: bytes) -> Session | None:
|
||||
async with self.session() as session:
|
||||
stmt = select(SessionModel).where(SessionModel.key == key)
|
||||
result = await session.execute(stmt)
|
||||
session_model = result.scalar_one_or_none()
|
||||
|
||||
if session_model:
|
||||
return Session(
|
||||
key=session_model.key,
|
||||
user_uuid=UUID(bytes=session_model.user_uuid),
|
||||
credential_uuid=UUID(bytes=session_model.credential_uuid)
|
||||
if session_model.credential_uuid
|
||||
else None,
|
||||
expires=session_model.expires,
|
||||
info=session_model.info or {},
|
||||
)
|
||||
return None
|
||||
|
||||
async def delete_session(self, key: bytes) -> None:
|
||||
async with self.session() as session:
|
||||
await session.execute(delete(SessionModel).where(SessionModel.key == key))
|
||||
|
||||
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
|
||||
async with self.session() as session:
|
||||
await session.execute(
|
||||
update(SessionModel)
|
||||
.where(SessionModel.key == key)
|
||||
.values(expires=expires, info=info)
|
||||
)
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
async with self.session() as session:
|
||||
current_time = datetime.now()
|
||||
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
||||
await session.execute(stmt)
|
||||
|
||||
Reference in New Issue
Block a user