Database cleanup, base class, separated from FastAPI app.

This commit is contained in:
Leo Vasanko 2025-08-05 07:55:31 -06:00
parent 00693c56fa
commit c5733eefd6
7 changed files with 325 additions and 239 deletions

View File

@ -5,7 +5,8 @@ This module provides dataclasses and database abstractions for managing
users, credentials, and sessions in a WebAuthn authentication system. users, credentials, and sessions in a WebAuthn authentication system.
""" """
from dataclasses import dataclass, field from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
@ -47,4 +48,108 @@ class Session:
credential_uuid: UUID | None = None credential_uuid: UUID | None = None
__all__ = ["User", "Credential", "Session"] class DatabaseInterface(ABC):
"""Abstract base class defining the database interface.
This class defines the public API that database implementations should provide.
Implementations may use decorators like @with_session that modify method signatures
at runtime, so this interface focuses on the logical operations rather than
exact parameter matching.
"""
@abstractmethod
async def init_db(self) -> None:
"""Initialize database tables."""
pass
# User operations
@abstractmethod
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
@abstractmethod
async def create_user(self, user: User) -> None:
"""Create a new user."""
# Credential operations
@abstractmethod
async def create_credential(self, credential: Credential) -> None:
"""Store a credential for a user."""
@abstractmethod
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
@abstractmethod
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
@abstractmethod
async def update_credential(self, credential: Credential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
@abstractmethod
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
"""Delete a specific credential for a user."""
# Session operations
@abstractmethod
async def create_session(
self,
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
) -> None:
"""Create a new session."""
@abstractmethod
async def get_session(self, key: bytes) -> Session | None:
"""Get session by key."""
@abstractmethod
async def delete_session(self, key: bytes) -> None:
"""Delete session by key."""
@abstractmethod
async def update_session(
self, key: bytes, expires: datetime, info: dict
) -> Session | None:
"""Update session expiry and info."""
@abstractmethod
async def cleanup(self) -> None:
"""Called periodically to clean up expired records."""
# Combined operations
@abstractmethod
async def login(self, user_uuid: UUID, credential: Credential) -> None:
"""Update user and credential timestamps after successful login."""
@abstractmethod
async def create_user_and_credential(
self, user: User, credential: Credential
) -> None:
"""Create a new user and their first credential in a transaction."""
# Global DB instance
database_instance: DatabaseInterface | None = None
def database() -> DatabaseInterface:
"""Get the global database instance."""
if database_instance is None:
raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.")
return database_instance
__all__ = [
"User",
"Credential",
"Session",
"DatabaseInterface",
"database_instance",
"database",
]

View File

@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system. for managing users and credentials in a WebAuthn authentication system.
""" """
from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from functools import wraps
from uuid import UUID from uuid import UUID
from sqlalchemy import ( from sqlalchemy import (
@ -23,24 +23,15 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from . import Credential, Session, User from . import Credential, DatabaseInterface, Session, User
DB_PATH = "sqlite+aiosqlite:///webauthn.db" DB_PATH = "sqlite+aiosqlite:///webauthn.db"
def with_session(func): def init(*args, **kwargs):
"""Decorator that provides a database session with transaction to the method.""" from .. import db
@wraps(func) db.database_instance = DB()
async def wrapper(self, *args, **kwargs):
async with self.async_session_factory() as session:
async with session.begin():
result = await func(self, session, *args, **kwargs)
await session.flush()
return result
await session.commit()
return wrapper
# SQLAlchemy Models # SQLAlchemy Models
@ -100,7 +91,7 @@ class SessionModel(Base):
user: Mapped["UserModel"] = relationship("UserModel") user: Mapped["UserModel"] = relationship("UserModel")
class DB: class DB(DatabaseInterface):
"""Database class that handles its own connections.""" """Database class that handles its own connections."""
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH):
@ -110,14 +101,22 @@ class DB:
self.engine, expire_on_commit=False self.engine, expire_on_commit=False
) )
@asynccontextmanager
async def session(self):
"""Async context manager that provides a database session with transaction."""
async with self.async_session_factory() as session:
async with session.begin():
yield session
await session.flush()
await session.commit()
async def init_db(self) -> None: async def init_db(self) -> None:
"""Initialize database tables.""" """Initialize database tables."""
async with self.engine.begin() as conn: async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
@with_session async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User: async with self.session() as session:
"""Get user record by WebAuthn user UUID."""
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
result = await session.execute(stmt) result = await session.execute(stmt)
user_model = result.scalar_one_or_none() user_model = result.scalar_one_or_none()
@ -132,9 +131,8 @@ class DB:
) )
raise ValueError("User not found") raise ValueError("User not found")
@with_session async def create_user(self, user: User) -> None:
async def create_user(self, session, user: User) -> None: async with self.session() as session:
"""Create a new user."""
user_model = UserModel( user_model = UserModel(
user_uuid=user.user_uuid.bytes, user_uuid=user.user_uuid.bytes,
user_name=user.user_name, user_name=user.user_name,
@ -144,9 +142,8 @@ class DB:
) )
session.add(user_model) session.add(user_model)
@with_session async def create_credential(self, credential: Credential) -> None:
async def create_credential(self, session, credential: Credential) -> None: async with self.session() as session:
"""Store a credential for a user."""
credential_model = CredentialModel( credential_model = CredentialModel(
uuid=credential.uuid.bytes, uuid=credential.uuid.bytes,
credential_id=credential.credential_id, credential_id=credential.credential_id,
@ -160,9 +157,8 @@ class DB:
) )
session.add(credential_model) session.add(credential_model)
@with_session async def get_credential_by_id(self, credential_id: bytes) -> Credential:
async def get_credential_by_id(self, session, credential_id: bytes) -> Credential: async with self.session() as session:
"""Get credential by credential ID."""
stmt = select(CredentialModel).where( stmt = select(CredentialModel).where(
CredentialModel.credential_id == credential_id CredentialModel.credential_id == credential_id
) )
@ -183,20 +179,16 @@ class DB:
last_verified=credential_model.last_verified, last_verified=credential_model.last_verified,
) )
@with_session async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
async def get_credentials_by_user_uuid( async with self.session() as session:
self, session, user_uuid: UUID
) -> list[bytes]:
"""Get all credential IDs for a user."""
stmt = select(CredentialModel.credential_id).where( stmt = select(CredentialModel.credential_id).where(
CredentialModel.user_uuid == user_uuid.bytes CredentialModel.user_uuid == user_uuid.bytes
) )
result = await session.execute(stmt) result = await session.execute(stmt)
return [row[0] for row in result.fetchall()] return [row[0] for row in result.fetchall()]
@with_session async def update_credential(self, credential: Credential) -> None:
async def update_credential(self, session, credential: Credential) -> None: async with self.session() as session:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
stmt = ( stmt = (
update(CredentialModel) update(CredentialModel)
.where(CredentialModel.credential_id == credential.credential_id) .where(CredentialModel.credential_id == credential.credential_id)
@ -209,9 +201,8 @@ class DB:
) )
await session.execute(stmt) await session.execute(stmt)
@with_session async def login(self, user_uuid: UUID, credential: Credential) -> None:
async def login(self, session, user_uuid: UUID, credential: Credential) -> None: async with self.session() as session:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
# Update credential # Update credential
stmt = ( stmt = (
update(CredentialModel) update(CredentialModel)
@ -233,11 +224,10 @@ class DB:
) )
await session.execute(stmt) await session.execute(stmt)
@with_session
async def create_user_and_credential( async def create_user_and_credential(
self, session, user: User, credential: Credential self, user: User, credential: Credential
) -> None: ) -> None:
"""Create a new user and their first credential in a single transaction.""" async with self.session() as session:
# Set visits to 1 for the new user since they're creating their first session # Set visits to 1 for the new user since they're creating their first session
user.visits = 1 user.visits = 1
@ -265,9 +255,8 @@ class DB:
) )
session.add(credential_model) session.add(credential_model)
@with_session async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None: async with self.session() as session:
"""Delete a credential by its ID."""
stmt = ( stmt = (
delete(CredentialModel) delete(CredentialModel)
.where(CredentialModel.uuid == uuid.bytes) .where(CredentialModel.uuid == uuid.bytes)
@ -275,17 +264,15 @@ class DB:
) )
await session.execute(stmt) await session.execute(stmt)
@with_session
async def create_session( async def create_session(
self, self,
session,
user_uuid: UUID, user_uuid: UUID,
key: bytes, key: bytes,
expires: datetime, expires: datetime,
info: dict, info: dict,
credential_uuid: UUID | None = None, credential_uuid: UUID | None = None,
) -> bytes: ) -> None:
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential.""" async with self.session() as session:
session_model = SessionModel( session_model = SessionModel(
key=key, key=key,
user_uuid=user_uuid.bytes, user_uuid=user_uuid.bytes,
@ -294,11 +281,9 @@ class DB:
info=info, info=info,
) )
session.add(session_model) session.add(session_model)
return key
@with_session async def get_session(self, key: bytes) -> Session | None:
async def get_session(self, session, key: bytes) -> Session | None: async with self.session() as session:
"""Get session by 16-byte key."""
stmt = select(SessionModel).where(SessionModel.key == key) stmt = select(SessionModel).where(SessionModel.key == key)
result = await session.execute(stmt) result = await session.execute(stmt)
session_model = result.scalar_one_or_none() session_model = result.scalar_one_or_none()
@ -315,25 +300,20 @@ class DB:
) )
return None return None
@with_session async def delete_session(self, key: bytes) -> None:
async def delete_session(self, session, key: bytes) -> None: async with self.session() as session:
"""Delete a session by 16-byte key."""
await session.execute(delete(SessionModel).where(SessionModel.key == key)) await session.execute(delete(SessionModel).where(SessionModel.key == key))
@with_session async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
async def update_session( async with self.session() as session:
self, session, key: bytes, expires: datetime, info: dict
) -> None:
"""Update session expiration time and/or info."""
await session.execute( await session.execute(
update(SessionModel) update(SessionModel)
.where(SessionModel.key == key) .where(SessionModel.key == key)
.values(expires=expires, info=info) .values(expires=expires, info=info)
) )
@with_session async def cleanup(self) -> None:
async def cleanup_expired_sessions(self, session) -> None: async with self.session() as session:
"""Remove expired sessions."""
current_time = datetime.now() current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time) stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await session.execute(stmt) await session.execute(stmt)

View File

@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBearer from fastapi.security import HTTPBearer
from .. import aaguid from .. import aaguid
from ..db import sql from ..db import database
from ..util.tokens import session_key from ..util.tokens import session_key
from . import session from . import session
@ -38,19 +38,19 @@ def register_api_routes(app: FastAPI):
return {"status": "error", "valid": False} return {"status": "error", "valid": False}
@app.post("/auth/user-info") @app.post("/auth/user-info")
async def api_user_info(request: Request, response: Response, auth=Cookie(None)): async def api_user_info(auth=Cookie(None)):
"""Get full user information for the authenticated user.""" """Get full user information for the authenticated user."""
try: try:
s = await session.get_session(auth, reset_allowed=True) s = await session.get_session(auth, reset_allowed=True)
u = await sql.get_user_by_uuid(s.user_uuid) u = await database().get_user_by_user_uuid(s.user_uuid)
# Get all credentials for the user # Get all credentials for the user
credential_ids = await sql.get_user_credentials(s.user_uuid) credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid)
credentials = [] credentials = []
user_aaguids = set() user_aaguids = set()
for cred_id in credential_ids: for cred_id in credential_ids:
c = await sql.get_credential_by_id(cred_id) c = await database().get_credential_by_id(cred_id)
# Convert AAGUID to string format # Convert AAGUID to string format
aaguid_str = str(c.aaguid) aaguid_str = str(c.aaguid)
@ -102,14 +102,12 @@ def register_api_routes(app: FastAPI):
"""Log out the current user by clearing the session cookie and deleting from database.""" """Log out the current user by clearing the session cookie and deleting from database."""
if not auth: if not auth:
return {"status": "success", "message": "Already logged out"} return {"status": "success", "message": "Already logged out"}
await sql.delete_session(session_key(auth)) await database().delete_session(session_key(auth))
response.delete_cookie("auth") response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"} return {"status": "success", "message": "Logged out successfully"}
@app.post("/auth/set-session") @app.post("/auth/set-session")
async def api_set_session( async def api_set_session(response: Response, auth=Depends(bearer_auth)):
request: Request, response: Response, auth=Depends(bearer_auth)
):
"""Set session cookie from Authorization header. Fetched after login by WebSocket.""" """Set session cookie from Authorization header. Fetched after login by WebSocket."""
try: try:
user = await session.get_session(auth.credentials) user = await session.get_session(auth.credentials)

View File

@ -20,7 +20,7 @@ from fastapi.responses import (
) )
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from ..db import sql from ..db import close_db, init_db
from . import session, ws from . import session, ws
from .api import register_api_routes from .api import register_api_routes
from .reset import register_reset_routes from .reset import register_reset_routes
@ -30,8 +30,9 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
await sql.init_database() await init_db()
yield yield
await close_db()
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)

View File

@ -3,7 +3,7 @@ import logging
from fastapi import Cookie, HTTPException, Request from fastapi import Cookie, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from ..db import sql from ..db import database
from ..util import passphrase, tokens from ..util import passphrase, tokens
from . import session from . import session
@ -20,7 +20,7 @@ def register_reset_routes(app):
# Generate a human-readable token # Generate a human-readable token
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke" token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
await sql.create_session( await database().create_session(
user_uuid=s.user_uuid, user_uuid=s.user_uuid,
key=tokens.reset_key(token), key=tokens.reset_key(token),
expires=session.expires(), expires=session.expires(),
@ -56,7 +56,7 @@ def register_reset_routes(app):
try: try:
# Get session token to validate it exists and get user_id # Get session token to validate it exists and get user_id
key = tokens.reset_key(reset_token) key = tokens.reset_key(reset_token)
sess = await sql.get_session(key) sess = await database().get_session(key)
if not sess: if not sess:
raise ValueError("Invalid or expired registration token") raise ValueError("Invalid or expired registration token")

View File

@ -14,7 +14,7 @@ from uuid import UUID
from fastapi import Request, Response, WebSocket from fastapi import Request, Response, WebSocket
from ..db import Session, sql from ..db import Session, database
from ..util import passphrase from ..util import passphrase
from ..util.tokens import create_token, reset_key, session_key from ..util.tokens import create_token, reset_key, session_key
@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str: async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
"""Create a new session and return a session token.""" """Create a new session and return a session token."""
token = create_token() token = create_token()
await sql.create_session( await database().create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + EXPIRES, expires=datetime.now() + EXPIRES,
@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
else: else:
key = session_key(token) key = session_key(token)
session = await sql.get_session(key) session = await database().get_session(key)
if not session: if not session:
raise ValueError("Invalid or expired session token") raise ValueError("Invalid or expired session token")
return session return session
@ -65,7 +65,9 @@ async def get_session(token: str, reset_allowed=False) -> Session:
async def refresh_session_token(token: str): async def refresh_session_token(token: str):
"""Refresh a session extending its expiry.""" """Refresh a session extending its expiry."""
# Get the current session # Get the current session
s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {}) s = await database().update_session(
session_key(token), datetime.now() + EXPIRES, {}
)
if not s: if not s:
raise ValueError("Session not found or expired") raise ValueError("Session not found or expired")
@ -86,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None:
async def delete_credential(credential_uuid: UUID, auth: str): async def delete_credential(credential_uuid: UUID, auth: str):
"""Delete a specific credential for the current user.""" """Delete a specific credential for the current user."""
s = await get_session(auth) s = await get_session(auth)
await sql.delete_credential(credential_uuid, s.user_uuid) await database().delete_credential(credential_uuid, s.user_uuid)

View File

@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from passkey.fastapi import session from passkey.fastapi import session
from ..db import User, sql from ..db import User, database
from ..sansio import Passkey from ..sansio import Passkey
from ..util.tokens import create_token, session_key from ..util.tokens import create_token, session_key
from .session import create_session, infodict from .session import create_session, infodict
@ -65,13 +65,13 @@ async def websocket_register_new(
credential = await register_chat(ws, user_uuid, user_name, origin=origin) credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database # Store the user and credential in the database
await sql.create_user_and_credential( await database().create_user_and_credential(
User(user_uuid, user_name, created_at=datetime.now()), User(user_uuid, user_name, created_at=datetime.now()),
credential, credential,
) )
# Create a session token for the new user # Create a session token for the new user
token = create_token() token = create_token()
await sql.create_session( await database().create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + session.EXPIRES, expires=datetime.now() + session.EXPIRES,
@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
user_uuid = s.user_uuid user_uuid = s.user_uuid
# Get user information to get the user_name # Get user information to get the user_name
user = await sql.get_user_by_uuid(user_uuid) user = await database().get_user_by_user_uuid(user_uuid)
user_name = user.user_name user_name = user.user_name
challenge_ids = await sql.get_user_credentials(user_uuid) challenge_ids = await database().get_credentials_by_user_uuid(user_uuid)
# WebAuthn registration # WebAuthn registration
credential = await register_chat( credential = await register_chat(
ws, user_uuid, user_name, challenge_ids, origin ws, user_uuid, user_name, challenge_ids, origin
) )
# Store the new credential in the database # Store the new credential in the database
await sql.create_credential_for_user(credential) await database().create_credential(credential)
await ws.send_json( await ws.send_json(
{ {
@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID # Fetch from the database by credential ID
stored_cred = await sql.get_credential_by_id(credential.raw_id) stored_cred = await database().get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data # Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin) passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp # Update both credential and user's last_seen timestamp
await sql.login_user(stored_cred.user_uuid, stored_cred) await database().login(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user # Create a session token for the authenticated user
assert stored_cred.uuid is not None assert stored_cred.uuid is not None