Finish DB cleanup/refactoring. Working now.

This commit is contained in:
Leo Vasanko 2025-08-05 08:26:35 -06:00
parent c5733eefd6
commit b58b7d5350
7 changed files with 47 additions and 49 deletions

View File

@ -134,15 +134,26 @@ class DatabaseInterface(ABC):
"""Create a new user and their first credential in a transaction."""
# Global DB instance
database_instance: DatabaseInterface | None = None
class DatabaseManager:
"""Manager for the global database instance."""
def __init__(self):
self._instance: DatabaseInterface | None = None
@property
def instance(self) -> DatabaseInterface:
if self._instance is None:
raise RuntimeError(
"Database not initialized. Call e.g. db.sql.init() first."
)
return self._instance
@instance.setter
def instance(self, instance: DatabaseInterface) -> None:
self._instance = instance
def database() -> DatabaseInterface:
"""Get the global database instance."""
if database_instance is None:
raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.")
return database_instance
db = DatabaseManager()
__all__ = [
@ -150,6 +161,5 @@ __all__ = [
"Credential",
"Session",
"DatabaseInterface",
"database_instance",
"database",
"db",
]

View File

@ -23,15 +23,14 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from . import Credential, DatabaseInterface, Session, User
from . import Credential, DatabaseInterface, Session, User, db
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
def init(*args, **kwargs):
from .. import db
db.database_instance = DB()
async def init(*args, **kwargs):
db.instance = DB()
await db.instance.init_db()
# SQLAlchemy Models

View File

@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBearer
from .. import aaguid
from ..db import database
from ..db import db
from ..util.tokens import session_key
from . import session
@ -42,15 +42,15 @@ def register_api_routes(app: FastAPI):
"""Get full user information for the authenticated user."""
try:
s = await session.get_session(auth, reset_allowed=True)
u = await database().get_user_by_user_uuid(s.user_uuid)
u = await db.instance.get_user_by_user_uuid(s.user_uuid)
# Get all credentials for the user
credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid)
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
credentials = []
user_aaguids = set()
for cred_id in credential_ids:
c = await database().get_credential_by_id(cred_id)
c = await db.instance.get_credential_by_id(cred_id)
# Convert AAGUID to string format
aaguid_str = str(c.aaguid)
@ -102,7 +102,7 @@ def register_api_routes(app: FastAPI):
"""Log out the current user by clearing the session cookie and deleting from database."""
if not auth:
return {"status": "success", "message": "Already logged out"}
await database().delete_session(session_key(auth))
await db.instance.delete_session(session_key(auth))
response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"}

View File

@ -1,14 +1,3 @@
"""
Minimal FastAPI WebAuthn server with WebSocket support for passkey registration and authentication.
This module provides a simple WebAuthn implementation that:
- Uses WebSocket for real-time communication
- Supports Resident Keys (discoverable credentials) for passwordless authentication
- Maintains challenges locally per connection
- Uses async SQLite database for persistent storage of users and credentials
- Enables true passwordless authentication where users don't need to enter a user_name
"""
import contextlib
import logging
from contextlib import asynccontextmanager
@ -20,7 +9,6 @@ from fastapi.responses import (
)
from fastapi.staticfiles import StaticFiles
from ..db import close_db, init_db
from . import session, ws
from .api import register_api_routes
from .reset import register_reset_routes
@ -30,9 +18,10 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@asynccontextmanager
async def lifespan(app: FastAPI):
await init_db()
from ..db import sql
await sql.init()
yield
await close_db()
app = FastAPI(lifespan=lifespan)

View File

@ -3,7 +3,7 @@ import logging
from fastapi import Cookie, HTTPException, Request
from fastapi.responses import RedirectResponse
from ..db import database
from ..db import db
from ..util import passphrase, tokens
from . import session
@ -20,7 +20,7 @@ def register_reset_routes(app):
# Generate a human-readable token
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
await database().create_session(
await db.instance.create_session(
user_uuid=s.user_uuid,
key=tokens.reset_key(token),
expires=session.expires(),
@ -56,7 +56,7 @@ def register_reset_routes(app):
try:
# Get session token to validate it exists and get user_id
key = tokens.reset_key(reset_token)
sess = await database().get_session(key)
sess = await db.instance.get_session(key)
if not sess:
raise ValueError("Invalid or expired registration token")

View File

@ -14,7 +14,7 @@ from uuid import UUID
from fastapi import Request, Response, WebSocket
from ..db import Session, database
from ..db import Session, db
from ..util import passphrase
from ..util.tokens import create_token, reset_key, session_key
@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
"""Create a new session and return a session token."""
token = create_token()
await database().create_session(
await db.instance.create_session(
user_uuid=user_uuid,
key=session_key(token),
expires=datetime.now() + EXPIRES,
@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
else:
key = session_key(token)
session = await database().get_session(key)
session = await db.instance.get_session(key)
if not session:
raise ValueError("Invalid or expired session token")
return session
@ -65,7 +65,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
async def refresh_session_token(token: str):
"""Refresh a session extending its expiry."""
# Get the current session
s = await database().update_session(
s = await db.instance.update_session(
session_key(token), datetime.now() + EXPIRES, {}
)
@ -88,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None:
async def delete_credential(credential_uuid: UUID, auth: str):
"""Delete a specific credential for the current user."""
s = await get_session(auth)
await database().delete_credential(credential_uuid, s.user_uuid)
await db.instance.delete_credential(credential_uuid, s.user_uuid)

View File

@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from passkey.fastapi import session
from ..db import User, database
from ..db import User, db
from ..sansio import Passkey
from ..util.tokens import create_token, session_key
from .session import create_session, infodict
@ -65,13 +65,13 @@ async def websocket_register_new(
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database
await database().create_user_and_credential(
await db.instance.create_user_and_credential(
User(user_uuid, user_name, created_at=datetime.now()),
credential,
)
# Create a session token for the new user
token = create_token()
await database().create_session(
await db.instance.create_session(
user_uuid=user_uuid,
key=session_key(token),
expires=datetime.now() + session.EXPIRES,
@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
user_uuid = s.user_uuid
# Get user information to get the user_name
user = await database().get_user_by_user_uuid(user_uuid)
user = await db.instance.get_user_by_user_uuid(user_uuid)
user_name = user.user_name
challenge_ids = await database().get_credentials_by_user_uuid(user_uuid)
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
# WebAuthn registration
credential = await register_chat(
ws, user_uuid, user_name, challenge_ids, origin
)
# Store the new credential in the database
await database().create_credential(credential)
await db.instance.create_credential(credential)
await ws.send_json(
{
@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
# Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID
stored_cred = await database().get_credential_by_id(credential.raw_id)
stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp
await database().login(stored_cred.user_uuid, stored_cred)
await db.instance.login(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user
assert stored_cred.uuid is not None