Database cleanup, base class, separated from FastAPI app.

This commit is contained in:
Leo Vasanko
2025-08-05 07:55:31 -06:00
parent 00693c56fa
commit c5733eefd6
7 changed files with 325 additions and 239 deletions

View File

@@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBearer
from .. import aaguid
from ..db import sql
from ..db import database
from ..util.tokens import session_key
from . import session
@@ -38,19 +38,19 @@ def register_api_routes(app: FastAPI):
return {"status": "error", "valid": False}
@app.post("/auth/user-info")
async def api_user_info(request: Request, response: Response, auth=Cookie(None)):
async def api_user_info(auth=Cookie(None)):
"""Get full user information for the authenticated user."""
try:
s = await session.get_session(auth, reset_allowed=True)
u = await sql.get_user_by_uuid(s.user_uuid)
u = await database().get_user_by_user_uuid(s.user_uuid)
# Get all credentials for the user
credential_ids = await sql.get_user_credentials(s.user_uuid)
credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid)
credentials = []
user_aaguids = set()
for cred_id in credential_ids:
c = await sql.get_credential_by_id(cred_id)
c = await database().get_credential_by_id(cred_id)
# Convert AAGUID to string format
aaguid_str = str(c.aaguid)
@@ -102,14 +102,12 @@ def register_api_routes(app: FastAPI):
"""Log out the current user by clearing the session cookie and deleting from database."""
if not auth:
return {"status": "success", "message": "Already logged out"}
await sql.delete_session(session_key(auth))
await database().delete_session(session_key(auth))
response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"}
@app.post("/auth/set-session")
async def api_set_session(
request: Request, response: Response, auth=Depends(bearer_auth)
):
async def api_set_session(response: Response, auth=Depends(bearer_auth)):
"""Set session cookie from Authorization header. Fetched after login by WebSocket."""
try:
user = await session.get_session(auth.credentials)

View File

@@ -20,7 +20,7 @@ from fastapi.responses import (
)
from fastapi.staticfiles import StaticFiles
from ..db import sql
from ..db import close_db, init_db
from . import session, ws
from .api import register_api_routes
from .reset import register_reset_routes
@@ -30,8 +30,9 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@asynccontextmanager
async def lifespan(app: FastAPI):
await sql.init_database()
await init_db()
yield
await close_db()
app = FastAPI(lifespan=lifespan)

View File

@@ -3,7 +3,7 @@ import logging
from fastapi import Cookie, HTTPException, Request
from fastapi.responses import RedirectResponse
from ..db import sql
from ..db import database
from ..util import passphrase, tokens
from . import session
@@ -20,7 +20,7 @@ def register_reset_routes(app):
# Generate a human-readable token
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
await sql.create_session(
await database().create_session(
user_uuid=s.user_uuid,
key=tokens.reset_key(token),
expires=session.expires(),
@@ -56,7 +56,7 @@ def register_reset_routes(app):
try:
# Get session token to validate it exists and get user_id
key = tokens.reset_key(reset_token)
sess = await sql.get_session(key)
sess = await database().get_session(key)
if not sess:
raise ValueError("Invalid or expired registration token")

View File

@@ -14,7 +14,7 @@ from uuid import UUID
from fastapi import Request, Response, WebSocket
from ..db import Session, sql
from ..db import Session, database
from ..util import passphrase
from ..util.tokens import create_token, reset_key, session_key
@@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
"""Create a new session and return a session token."""
token = create_token()
await sql.create_session(
await database().create_session(
user_uuid=user_uuid,
key=session_key(token),
expires=datetime.now() + EXPIRES,
@@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
else:
key = session_key(token)
session = await sql.get_session(key)
session = await database().get_session(key)
if not session:
raise ValueError("Invalid or expired session token")
return session
@@ -65,7 +65,9 @@ async def get_session(token: str, reset_allowed=False) -> Session:
async def refresh_session_token(token: str):
"""Refresh a session extending its expiry."""
# Get the current session
s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {})
s = await database().update_session(
session_key(token), datetime.now() + EXPIRES, {}
)
if not s:
raise ValueError("Session not found or expired")
@@ -86,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None:
async def delete_credential(credential_uuid: UUID, auth: str):
"""Delete a specific credential for the current user."""
s = await get_session(auth)
await sql.delete_credential(credential_uuid, s.user_uuid)
await database().delete_credential(credential_uuid, s.user_uuid)

View File

@@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from passkey.fastapi import session
from ..db import User, sql
from ..db import User, database
from ..sansio import Passkey
from ..util.tokens import create_token, session_key
from .session import create_session, infodict
@@ -65,13 +65,13 @@ async def websocket_register_new(
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database
await sql.create_user_and_credential(
await database().create_user_and_credential(
User(user_uuid, user_name, created_at=datetime.now()),
credential,
)
# Create a session token for the new user
token = create_token()
await sql.create_session(
await database().create_session(
user_uuid=user_uuid,
key=session_key(token),
expires=datetime.now() + session.EXPIRES,
@@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
user_uuid = s.user_uuid
# Get user information to get the user_name
user = await sql.get_user_by_uuid(user_uuid)
user = await database().get_user_by_user_uuid(user_uuid)
user_name = user.user_name
challenge_ids = await sql.get_user_credentials(user_uuid)
challenge_ids = await database().get_credentials_by_user_uuid(user_uuid)
# WebAuthn registration
credential = await register_chat(
ws, user_uuid, user_name, challenge_ids, origin
)
# Store the new credential in the database
await sql.create_credential_for_user(credential)
await database().create_credential(credential)
await ws.send_json(
{
@@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
# Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID
stored_cred = await sql.get_credential_by_id(credential.raw_id)
stored_cred = await database().get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp
await sql.login_user(stored_cred.user_uuid, stored_cred)
await database().login(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user
assert stored_cred.uuid is not None