Finish DB cleanup/refactoring. Working now.
This commit is contained in:
parent
c5733eefd6
commit
b58b7d5350
@ -134,15 +134,26 @@ class DatabaseInterface(ABC):
|
||||
"""Create a new user and their first credential in a transaction."""
|
||||
|
||||
|
||||
# Global DB instance
|
||||
database_instance: DatabaseInterface | None = None
|
||||
class DatabaseManager:
|
||||
"""Manager for the global database instance."""
|
||||
|
||||
def __init__(self):
|
||||
self._instance: DatabaseInterface | None = None
|
||||
|
||||
@property
|
||||
def instance(self) -> DatabaseInterface:
|
||||
if self._instance is None:
|
||||
raise RuntimeError(
|
||||
"Database not initialized. Call e.g. db.sql.init() first."
|
||||
)
|
||||
return self._instance
|
||||
|
||||
@instance.setter
|
||||
def instance(self, instance: DatabaseInterface) -> None:
|
||||
self._instance = instance
|
||||
|
||||
|
||||
def database() -> DatabaseInterface:
|
||||
"""Get the global database instance."""
|
||||
if database_instance is None:
|
||||
raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.")
|
||||
return database_instance
|
||||
db = DatabaseManager()
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -150,6 +161,5 @@ __all__ = [
|
||||
"Credential",
|
||||
"Session",
|
||||
"DatabaseInterface",
|
||||
"database_instance",
|
||||
"database",
|
||||
"db",
|
||||
]
|
||||
|
@ -23,15 +23,14 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from . import Credential, DatabaseInterface, Session, User
|
||||
from . import Credential, DatabaseInterface, Session, User, db
|
||||
|
||||
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
||||
|
||||
|
||||
def init(*args, **kwargs):
|
||||
from .. import db
|
||||
|
||||
db.database_instance = DB()
|
||||
async def init(*args, **kwargs):
|
||||
db.instance = DB()
|
||||
await db.instance.init_db()
|
||||
|
||||
|
||||
# SQLAlchemy Models
|
||||
|
@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
|
||||
from fastapi.security import HTTPBearer
|
||||
|
||||
from .. import aaguid
|
||||
from ..db import database
|
||||
from ..db import db
|
||||
from ..util.tokens import session_key
|
||||
from . import session
|
||||
|
||||
@ -42,15 +42,15 @@ def register_api_routes(app: FastAPI):
|
||||
"""Get full user information for the authenticated user."""
|
||||
try:
|
||||
s = await session.get_session(auth, reset_allowed=True)
|
||||
u = await database().get_user_by_user_uuid(s.user_uuid)
|
||||
u = await db.instance.get_user_by_user_uuid(s.user_uuid)
|
||||
# Get all credentials for the user
|
||||
credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid)
|
||||
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
|
||||
|
||||
credentials = []
|
||||
user_aaguids = set()
|
||||
|
||||
for cred_id in credential_ids:
|
||||
c = await database().get_credential_by_id(cred_id)
|
||||
c = await db.instance.get_credential_by_id(cred_id)
|
||||
|
||||
# Convert AAGUID to string format
|
||||
aaguid_str = str(c.aaguid)
|
||||
@ -102,7 +102,7 @@ def register_api_routes(app: FastAPI):
|
||||
"""Log out the current user by clearing the session cookie and deleting from database."""
|
||||
if not auth:
|
||||
return {"status": "success", "message": "Already logged out"}
|
||||
await database().delete_session(session_key(auth))
|
||||
await db.instance.delete_session(session_key(auth))
|
||||
response.delete_cookie("auth")
|
||||
return {"status": "success", "message": "Logged out successfully"}
|
||||
|
||||
|
@ -1,14 +1,3 @@
|
||||
"""
|
||||
Minimal FastAPI WebAuthn server with WebSocket support for passkey registration and authentication.
|
||||
|
||||
This module provides a simple WebAuthn implementation that:
|
||||
- Uses WebSocket for real-time communication
|
||||
- Supports Resident Keys (discoverable credentials) for passwordless authentication
|
||||
- Maintains challenges locally per connection
|
||||
- Uses async SQLite database for persistent storage of users and credentials
|
||||
- Enables true passwordless authentication where users don't need to enter a user_name
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
@ -20,7 +9,6 @@ from fastapi.responses import (
|
||||
)
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from ..db import close_db, init_db
|
||||
from . import session, ws
|
||||
from .api import register_api_routes
|
||||
from .reset import register_reset_routes
|
||||
@ -30,9 +18,10 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await init_db()
|
||||
from ..db import sql
|
||||
|
||||
await sql.init()
|
||||
yield
|
||||
await close_db()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
from fastapi import Cookie, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from ..db import database
|
||||
from ..db import db
|
||||
from ..util import passphrase, tokens
|
||||
from . import session
|
||||
|
||||
@ -20,7 +20,7 @@ def register_reset_routes(app):
|
||||
|
||||
# Generate a human-readable token
|
||||
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
|
||||
await database().create_session(
|
||||
await db.instance.create_session(
|
||||
user_uuid=s.user_uuid,
|
||||
key=tokens.reset_key(token),
|
||||
expires=session.expires(),
|
||||
@ -56,7 +56,7 @@ def register_reset_routes(app):
|
||||
try:
|
||||
# Get session token to validate it exists and get user_id
|
||||
key = tokens.reset_key(reset_token)
|
||||
sess = await database().get_session(key)
|
||||
sess = await db.instance.get_session(key)
|
||||
if not sess:
|
||||
raise ValueError("Invalid or expired registration token")
|
||||
|
||||
|
@ -14,7 +14,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import Request, Response, WebSocket
|
||||
|
||||
from ..db import Session, database
|
||||
from ..db import Session, db
|
||||
from ..util import passphrase
|
||||
from ..util.tokens import create_token, reset_key, session_key
|
||||
|
||||
@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
|
||||
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
|
||||
"""Create a new session and return a session token."""
|
||||
token = create_token()
|
||||
await database().create_session(
|
||||
await db.instance.create_session(
|
||||
user_uuid=user_uuid,
|
||||
key=session_key(token),
|
||||
expires=datetime.now() + EXPIRES,
|
||||
@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
|
||||
else:
|
||||
key = session_key(token)
|
||||
|
||||
session = await database().get_session(key)
|
||||
session = await db.instance.get_session(key)
|
||||
if not session:
|
||||
raise ValueError("Invalid or expired session token")
|
||||
return session
|
||||
@ -65,7 +65,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
|
||||
async def refresh_session_token(token: str):
|
||||
"""Refresh a session extending its expiry."""
|
||||
# Get the current session
|
||||
s = await database().update_session(
|
||||
s = await db.instance.update_session(
|
||||
session_key(token), datetime.now() + EXPIRES, {}
|
||||
)
|
||||
|
||||
@ -88,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None:
|
||||
async def delete_credential(credential_uuid: UUID, auth: str):
|
||||
"""Delete a specific credential for the current user."""
|
||||
s = await get_session(auth)
|
||||
await database().delete_credential(credential_uuid, s.user_uuid)
|
||||
await db.instance.delete_credential(credential_uuid, s.user_uuid)
|
||||
|
@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
||||
|
||||
from passkey.fastapi import session
|
||||
|
||||
from ..db import User, database
|
||||
from ..db import User, db
|
||||
from ..sansio import Passkey
|
||||
from ..util.tokens import create_token, session_key
|
||||
from .session import create_session, infodict
|
||||
@ -65,13 +65,13 @@ async def websocket_register_new(
|
||||
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
|
||||
|
||||
# Store the user and credential in the database
|
||||
await database().create_user_and_credential(
|
||||
await db.instance.create_user_and_credential(
|
||||
User(user_uuid, user_name, created_at=datetime.now()),
|
||||
credential,
|
||||
)
|
||||
# Create a session token for the new user
|
||||
token = create_token()
|
||||
await database().create_session(
|
||||
await db.instance.create_session(
|
||||
user_uuid=user_uuid,
|
||||
key=session_key(token),
|
||||
expires=datetime.now() + session.EXPIRES,
|
||||
@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
|
||||
user_uuid = s.user_uuid
|
||||
|
||||
# Get user information to get the user_name
|
||||
user = await database().get_user_by_user_uuid(user_uuid)
|
||||
user = await db.instance.get_user_by_user_uuid(user_uuid)
|
||||
user_name = user.user_name
|
||||
challenge_ids = await database().get_credentials_by_user_uuid(user_uuid)
|
||||
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(
|
||||
ws, user_uuid, user_name, challenge_ids, origin
|
||||
)
|
||||
# Store the new credential in the database
|
||||
await database().create_credential(credential)
|
||||
await db.instance.create_credential(credential)
|
||||
|
||||
await ws.send_json(
|
||||
{
|
||||
@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
|
||||
# Wait for the client to use his authenticator to authenticate
|
||||
credential = passkey.auth_parse(await ws.receive_json())
|
||||
# Fetch from the database by credential ID
|
||||
stored_cred = await database().get_credential_by_id(credential.raw_id)
|
||||
stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
|
||||
# Verify the credential matches the stored data
|
||||
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
||||
# Update both credential and user's last_seen timestamp
|
||||
await database().login(stored_cred.user_uuid, stored_cred)
|
||||
await db.instance.login(stored_cred.user_uuid, stored_cred)
|
||||
|
||||
# Create a session token for the authenticated user
|
||||
assert stored_cred.uuid is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user