Finish DB cleanup/refactoring. Working now.

This commit is contained in:
Leo Vasanko 2025-08-05 08:26:35 -06:00
parent c5733eefd6
commit b58b7d5350
7 changed files with 47 additions and 49 deletions

View File

@ -134,15 +134,26 @@ class DatabaseInterface(ABC):
"""Create a new user and their first credential in a transaction.""" """Create a new user and their first credential in a transaction."""
# Global DB instance class DatabaseManager:
database_instance: DatabaseInterface | None = None """Manager for the global database instance."""
def __init__(self):
self._instance: DatabaseInterface | None = None
@property
def instance(self) -> DatabaseInterface:
if self._instance is None:
raise RuntimeError(
"Database not initialized. Call e.g. db.sql.init() first."
)
return self._instance
@instance.setter
def instance(self, instance: DatabaseInterface) -> None:
self._instance = instance
def database() -> DatabaseInterface: db = DatabaseManager()
"""Get the global database instance."""
if database_instance is None:
raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.")
return database_instance
__all__ = [ __all__ = [
@ -150,6 +161,5 @@ __all__ = [
"Credential", "Credential",
"Session", "Session",
"DatabaseInterface", "DatabaseInterface",
"database_instance", "db",
"database",
] ]

View File

@ -23,15 +23,14 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from . import Credential, DatabaseInterface, Session, User from . import Credential, DatabaseInterface, Session, User, db
DB_PATH = "sqlite+aiosqlite:///webauthn.db" DB_PATH = "sqlite+aiosqlite:///webauthn.db"
def init(*args, **kwargs): async def init(*args, **kwargs):
from .. import db db.instance = DB()
await db.instance.init_db()
db.database_instance = DB()
# SQLAlchemy Models # SQLAlchemy Models

View File

@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBearer from fastapi.security import HTTPBearer
from .. import aaguid from .. import aaguid
from ..db import database from ..db import db
from ..util.tokens import session_key from ..util.tokens import session_key
from . import session from . import session
@ -42,15 +42,15 @@ def register_api_routes(app: FastAPI):
"""Get full user information for the authenticated user.""" """Get full user information for the authenticated user."""
try: try:
s = await session.get_session(auth, reset_allowed=True) s = await session.get_session(auth, reset_allowed=True)
u = await database().get_user_by_user_uuid(s.user_uuid) u = await db.instance.get_user_by_user_uuid(s.user_uuid)
# Get all credentials for the user # Get all credentials for the user
credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid) credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
credentials = [] credentials = []
user_aaguids = set() user_aaguids = set()
for cred_id in credential_ids: for cred_id in credential_ids:
c = await database().get_credential_by_id(cred_id) c = await db.instance.get_credential_by_id(cred_id)
# Convert AAGUID to string format # Convert AAGUID to string format
aaguid_str = str(c.aaguid) aaguid_str = str(c.aaguid)
@ -102,7 +102,7 @@ def register_api_routes(app: FastAPI):
"""Log out the current user by clearing the session cookie and deleting from database.""" """Log out the current user by clearing the session cookie and deleting from database."""
if not auth: if not auth:
return {"status": "success", "message": "Already logged out"} return {"status": "success", "message": "Already logged out"}
await database().delete_session(session_key(auth)) await db.instance.delete_session(session_key(auth))
response.delete_cookie("auth") response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"} return {"status": "success", "message": "Logged out successfully"}

View File

@ -1,14 +1,3 @@
"""
Minimal FastAPI WebAuthn server with WebSocket support for passkey registration and authentication.
This module provides a simple WebAuthn implementation that:
- Uses WebSocket for real-time communication
- Supports Resident Keys (discoverable credentials) for passwordless authentication
- Maintains challenges locally per connection
- Uses async SQLite database for persistent storage of users and credentials
- Enables true passwordless authentication where users don't need to enter a user_name
"""
import contextlib import contextlib
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -20,7 +9,6 @@ from fastapi.responses import (
) )
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from ..db import close_db, init_db
from . import session, ws from . import session, ws
from .api import register_api_routes from .api import register_api_routes
from .reset import register_reset_routes from .reset import register_reset_routes
@ -30,9 +18,10 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
await init_db() from ..db import sql
await sql.init()
yield yield
await close_db()
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)

View File

@ -3,7 +3,7 @@ import logging
from fastapi import Cookie, HTTPException, Request from fastapi import Cookie, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from ..db import database from ..db import db
from ..util import passphrase, tokens from ..util import passphrase, tokens
from . import session from . import session
@ -20,7 +20,7 @@ def register_reset_routes(app):
# Generate a human-readable token # Generate a human-readable token
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke" token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
await database().create_session( await db.instance.create_session(
user_uuid=s.user_uuid, user_uuid=s.user_uuid,
key=tokens.reset_key(token), key=tokens.reset_key(token),
expires=session.expires(), expires=session.expires(),
@ -56,7 +56,7 @@ def register_reset_routes(app):
try: try:
# Get session token to validate it exists and get user_id # Get session token to validate it exists and get user_id
key = tokens.reset_key(reset_token) key = tokens.reset_key(reset_token)
sess = await database().get_session(key) sess = await db.instance.get_session(key)
if not sess: if not sess:
raise ValueError("Invalid or expired registration token") raise ValueError("Invalid or expired registration token")

View File

@ -14,7 +14,7 @@ from uuid import UUID
from fastapi import Request, Response, WebSocket from fastapi import Request, Response, WebSocket
from ..db import Session, database from ..db import Session, db
from ..util import passphrase from ..util import passphrase
from ..util.tokens import create_token, reset_key, session_key from ..util.tokens import create_token, reset_key, session_key
@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str: async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
"""Create a new session and return a session token.""" """Create a new session and return a session token."""
token = create_token() token = create_token()
await database().create_session( await db.instance.create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + EXPIRES, expires=datetime.now() + EXPIRES,
@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
else: else:
key = session_key(token) key = session_key(token)
session = await database().get_session(key) session = await db.instance.get_session(key)
if not session: if not session:
raise ValueError("Invalid or expired session token") raise ValueError("Invalid or expired session token")
return session return session
@ -65,7 +65,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
async def refresh_session_token(token: str): async def refresh_session_token(token: str):
"""Refresh a session extending its expiry.""" """Refresh a session extending its expiry."""
# Get the current session # Get the current session
s = await database().update_session( s = await db.instance.update_session(
session_key(token), datetime.now() + EXPIRES, {} session_key(token), datetime.now() + EXPIRES, {}
) )
@ -88,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None:
async def delete_credential(credential_uuid: UUID, auth: str): async def delete_credential(credential_uuid: UUID, auth: str):
"""Delete a specific credential for the current user.""" """Delete a specific credential for the current user."""
s = await get_session(auth) s = await get_session(auth)
await database().delete_credential(credential_uuid, s.user_uuid) await db.instance.delete_credential(credential_uuid, s.user_uuid)

View File

@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from passkey.fastapi import session from passkey.fastapi import session
from ..db import User, database from ..db import User, db
from ..sansio import Passkey from ..sansio import Passkey
from ..util.tokens import create_token, session_key from ..util.tokens import create_token, session_key
from .session import create_session, infodict from .session import create_session, infodict
@ -65,13 +65,13 @@ async def websocket_register_new(
credential = await register_chat(ws, user_uuid, user_name, origin=origin) credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database # Store the user and credential in the database
await database().create_user_and_credential( await db.instance.create_user_and_credential(
User(user_uuid, user_name, created_at=datetime.now()), User(user_uuid, user_name, created_at=datetime.now()),
credential, credential,
) )
# Create a session token for the new user # Create a session token for the new user
token = create_token() token = create_token()
await database().create_session( await db.instance.create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + session.EXPIRES, expires=datetime.now() + session.EXPIRES,
@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
user_uuid = s.user_uuid user_uuid = s.user_uuid
# Get user information to get the user_name # Get user information to get the user_name
user = await database().get_user_by_user_uuid(user_uuid) user = await db.instance.get_user_by_user_uuid(user_uuid)
user_name = user.user_name user_name = user.user_name
challenge_ids = await database().get_credentials_by_user_uuid(user_uuid) challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
# WebAuthn registration # WebAuthn registration
credential = await register_chat( credential = await register_chat(
ws, user_uuid, user_name, challenge_ids, origin ws, user_uuid, user_name, challenge_ids, origin
) )
# Store the new credential in the database # Store the new credential in the database
await database().create_credential(credential) await db.instance.create_credential(credential)
await ws.send_json( await ws.send_json(
{ {
@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID # Fetch from the database by credential ID
stored_cred = await database().get_credential_by_id(credential.raw_id) stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data # Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin) passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp # Update both credential and user's last_seen timestamp
await database().login(stored_cred.user_uuid, stored_cred) await db.instance.login(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user # Create a session token for the authenticated user
assert stored_cred.uuid is not None assert stored_cred.uuid is not None