Compare commits

..

No commits in common. "b58b7d5350ee5196604e8969542a1959e88207f0" and "00693c56fa31999d4d8b397dd758a3a1fc167d47" have entirely different histories.

7 changed files with 251 additions and 335 deletions

View File

@ -5,8 +5,7 @@ This module provides dataclasses and database abstractions for managing
users, credentials, and sessions in a WebAuthn authentication system. users, credentials, and sessions in a WebAuthn authentication system.
""" """
from abc import ABC, abstractmethod from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
@ -48,118 +47,4 @@ class Session:
credential_uuid: UUID | None = None credential_uuid: UUID | None = None
class DatabaseInterface(ABC): __all__ = ["User", "Credential", "Session"]
"""Abstract base class defining the database interface.
This class defines the public API that database implementations should provide.
Implementations may use decorators like @with_session that modify method signatures
at runtime, so this interface focuses on the logical operations rather than
exact parameter matching.
"""
@abstractmethod
async def init_db(self) -> None:
"""Initialize database tables."""
pass
# User operations
@abstractmethod
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
@abstractmethod
async def create_user(self, user: User) -> None:
"""Create a new user."""
# Credential operations
@abstractmethod
async def create_credential(self, credential: Credential) -> None:
"""Store a credential for a user."""
@abstractmethod
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
@abstractmethod
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
@abstractmethod
async def update_credential(self, credential: Credential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
@abstractmethod
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
"""Delete a specific credential for a user."""
# Session operations
@abstractmethod
async def create_session(
self,
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
) -> None:
"""Create a new session."""
@abstractmethod
async def get_session(self, key: bytes) -> Session | None:
"""Get session by key."""
@abstractmethod
async def delete_session(self, key: bytes) -> None:
"""Delete session by key."""
@abstractmethod
async def update_session(
self, key: bytes, expires: datetime, info: dict
) -> Session | None:
"""Update session expiry and info."""
@abstractmethod
async def cleanup(self) -> None:
"""Called periodically to clean up expired records."""
# Combined operations
@abstractmethod
async def login(self, user_uuid: UUID, credential: Credential) -> None:
"""Update user and credential timestamps after successful login."""
@abstractmethod
async def create_user_and_credential(
self, user: User, credential: Credential
) -> None:
"""Create a new user and their first credential in a transaction."""
class DatabaseManager:
"""Manager for the global database instance."""
def __init__(self):
self._instance: DatabaseInterface | None = None
@property
def instance(self) -> DatabaseInterface:
if self._instance is None:
raise RuntimeError(
"Database not initialized. Call e.g. db.sql.init() first."
)
return self._instance
@instance.setter
def instance(self, instance: DatabaseInterface) -> None:
self._instance = instance
db = DatabaseManager()
__all__ = [
"User",
"Credential",
"Session",
"DatabaseInterface",
"db",
]

View File

@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system. for managing users and credentials in a WebAuthn authentication system.
""" """
from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from functools import wraps
from uuid import UUID from uuid import UUID
from sqlalchemy import ( from sqlalchemy import (
@ -23,14 +23,24 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from . import Credential, DatabaseInterface, Session, User, db from . import Credential, Session, User
DB_PATH = "sqlite+aiosqlite:///webauthn.db" DB_PATH = "sqlite+aiosqlite:///webauthn.db"
async def init(*args, **kwargs): def with_session(func):
db.instance = DB() """Decorator that provides a database session with transaction to the method."""
await db.instance.init_db()
@wraps(func)
async def wrapper(self, *args, **kwargs):
async with self.async_session_factory() as session:
async with session.begin():
result = await func(self, session, *args, **kwargs)
await session.flush()
return result
await session.commit()
return wrapper
# SQLAlchemy Models # SQLAlchemy Models
@ -90,7 +100,7 @@ class SessionModel(Base):
user: Mapped["UserModel"] = relationship("UserModel") user: Mapped["UserModel"] = relationship("UserModel")
class DB(DatabaseInterface): class DB:
"""Database class that handles its own connections.""" """Database class that handles its own connections."""
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH):
@ -100,22 +110,14 @@ class DB(DatabaseInterface):
self.engine, expire_on_commit=False self.engine, expire_on_commit=False
) )
@asynccontextmanager
async def session(self):
"""Async context manager that provides a database session with transaction."""
async with self.async_session_factory() as session:
async with session.begin():
yield session
await session.flush()
await session.commit()
async def init_db(self) -> None: async def init_db(self) -> None:
"""Initialize database tables.""" """Initialize database tables."""
async with self.engine.begin() as conn: async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User: @with_session
async with self.session() as session: async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
result = await session.execute(stmt) result = await session.execute(stmt)
user_model = result.scalar_one_or_none() user_model = result.scalar_one_or_none()
@ -130,8 +132,9 @@ class DB(DatabaseInterface):
) )
raise ValueError("User not found") raise ValueError("User not found")
async def create_user(self, user: User) -> None: @with_session
async with self.session() as session: async def create_user(self, session, user: User) -> None:
"""Create a new user."""
user_model = UserModel( user_model = UserModel(
user_uuid=user.user_uuid.bytes, user_uuid=user.user_uuid.bytes,
user_name=user.user_name, user_name=user.user_name,
@ -141,8 +144,9 @@ class DB(DatabaseInterface):
) )
session.add(user_model) session.add(user_model)
async def create_credential(self, credential: Credential) -> None: @with_session
async with self.session() as session: async def create_credential(self, session, credential: Credential) -> None:
"""Store a credential for a user."""
credential_model = CredentialModel( credential_model = CredentialModel(
uuid=credential.uuid.bytes, uuid=credential.uuid.bytes,
credential_id=credential.credential_id, credential_id=credential.credential_id,
@ -156,8 +160,9 @@ class DB(DatabaseInterface):
) )
session.add(credential_model) session.add(credential_model)
async def get_credential_by_id(self, credential_id: bytes) -> Credential: @with_session
async with self.session() as session: async def get_credential_by_id(self, session, credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
stmt = select(CredentialModel).where( stmt = select(CredentialModel).where(
CredentialModel.credential_id == credential_id CredentialModel.credential_id == credential_id
) )
@ -178,16 +183,20 @@ class DB(DatabaseInterface):
last_verified=credential_model.last_verified, last_verified=credential_model.last_verified,
) )
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]: @with_session
async with self.session() as session: async def get_credentials_by_user_uuid(
self, session, user_uuid: UUID
) -> list[bytes]:
"""Get all credential IDs for a user."""
stmt = select(CredentialModel.credential_id).where( stmt = select(CredentialModel.credential_id).where(
CredentialModel.user_uuid == user_uuid.bytes CredentialModel.user_uuid == user_uuid.bytes
) )
result = await session.execute(stmt) result = await session.execute(stmt)
return [row[0] for row in result.fetchall()] return [row[0] for row in result.fetchall()]
async def update_credential(self, credential: Credential) -> None: @with_session
async with self.session() as session: async def update_credential(self, session, credential: Credential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
stmt = ( stmt = (
update(CredentialModel) update(CredentialModel)
.where(CredentialModel.credential_id == credential.credential_id) .where(CredentialModel.credential_id == credential.credential_id)
@ -200,8 +209,9 @@ class DB(DatabaseInterface):
) )
await session.execute(stmt) await session.execute(stmt)
async def login(self, user_uuid: UUID, credential: Credential) -> None: @with_session
async with self.session() as session: async def login(self, session, user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
# Update credential # Update credential
stmt = ( stmt = (
update(CredentialModel) update(CredentialModel)
@ -223,10 +233,11 @@ class DB(DatabaseInterface):
) )
await session.execute(stmt) await session.execute(stmt)
@with_session
async def create_user_and_credential( async def create_user_and_credential(
self, user: User, credential: Credential self, session, user: User, credential: Credential
) -> None: ) -> None:
async with self.session() as session: """Create a new user and their first credential in a single transaction."""
# Set visits to 1 for the new user since they're creating their first session # Set visits to 1 for the new user since they're creating their first session
user.visits = 1 user.visits = 1
@ -254,8 +265,9 @@ class DB(DatabaseInterface):
) )
session.add(credential_model) session.add(credential_model)
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None: @with_session
async with self.session() as session: async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID."""
stmt = ( stmt = (
delete(CredentialModel) delete(CredentialModel)
.where(CredentialModel.uuid == uuid.bytes) .where(CredentialModel.uuid == uuid.bytes)
@ -263,15 +275,17 @@ class DB(DatabaseInterface):
) )
await session.execute(stmt) await session.execute(stmt)
@with_session
async def create_session( async def create_session(
self, self,
session,
user_uuid: UUID, user_uuid: UUID,
key: bytes, key: bytes,
expires: datetime, expires: datetime,
info: dict, info: dict,
credential_uuid: UUID | None = None, credential_uuid: UUID | None = None,
) -> None: ) -> bytes:
async with self.session() as session: """Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
session_model = SessionModel( session_model = SessionModel(
key=key, key=key,
user_uuid=user_uuid.bytes, user_uuid=user_uuid.bytes,
@ -280,9 +294,11 @@ class DB(DatabaseInterface):
info=info, info=info,
) )
session.add(session_model) session.add(session_model)
return key
async def get_session(self, key: bytes) -> Session | None: @with_session
async with self.session() as session: async def get_session(self, session, key: bytes) -> Session | None:
"""Get session by 16-byte key."""
stmt = select(SessionModel).where(SessionModel.key == key) stmt = select(SessionModel).where(SessionModel.key == key)
result = await session.execute(stmt) result = await session.execute(stmt)
session_model = result.scalar_one_or_none() session_model = result.scalar_one_or_none()
@ -299,20 +315,25 @@ class DB(DatabaseInterface):
) )
return None return None
async def delete_session(self, key: bytes) -> None: @with_session
async with self.session() as session: async def delete_session(self, session, key: bytes) -> None:
"""Delete a session by 16-byte key."""
await session.execute(delete(SessionModel).where(SessionModel.key == key)) await session.execute(delete(SessionModel).where(SessionModel.key == key))
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None: @with_session
async with self.session() as session: async def update_session(
self, session, key: bytes, expires: datetime, info: dict
) -> None:
"""Update session expiration time and/or info."""
await session.execute( await session.execute(
update(SessionModel) update(SessionModel)
.where(SessionModel.key == key) .where(SessionModel.key == key)
.values(expires=expires, info=info) .values(expires=expires, info=info)
) )
async def cleanup(self) -> None: @with_session
async with self.session() as session: async def cleanup_expired_sessions(self, session) -> None:
"""Remove expired sessions."""
current_time = datetime.now() current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time) stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await session.execute(stmt) await session.execute(stmt)

View File

@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBearer from fastapi.security import HTTPBearer
from .. import aaguid from .. import aaguid
from ..db import db from ..db import sql
from ..util.tokens import session_key from ..util.tokens import session_key
from . import session from . import session
@ -38,19 +38,19 @@ def register_api_routes(app: FastAPI):
return {"status": "error", "valid": False} return {"status": "error", "valid": False}
@app.post("/auth/user-info") @app.post("/auth/user-info")
async def api_user_info(auth=Cookie(None)): async def api_user_info(request: Request, response: Response, auth=Cookie(None)):
"""Get full user information for the authenticated user.""" """Get full user information for the authenticated user."""
try: try:
s = await session.get_session(auth, reset_allowed=True) s = await session.get_session(auth, reset_allowed=True)
u = await db.instance.get_user_by_user_uuid(s.user_uuid) u = await sql.get_user_by_uuid(s.user_uuid)
# Get all credentials for the user # Get all credentials for the user
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid) credential_ids = await sql.get_user_credentials(s.user_uuid)
credentials = [] credentials = []
user_aaguids = set() user_aaguids = set()
for cred_id in credential_ids: for cred_id in credential_ids:
c = await db.instance.get_credential_by_id(cred_id) c = await sql.get_credential_by_id(cred_id)
# Convert AAGUID to string format # Convert AAGUID to string format
aaguid_str = str(c.aaguid) aaguid_str = str(c.aaguid)
@ -102,12 +102,14 @@ def register_api_routes(app: FastAPI):
"""Log out the current user by clearing the session cookie and deleting from database.""" """Log out the current user by clearing the session cookie and deleting from database."""
if not auth: if not auth:
return {"status": "success", "message": "Already logged out"} return {"status": "success", "message": "Already logged out"}
await db.instance.delete_session(session_key(auth)) await sql.delete_session(session_key(auth))
response.delete_cookie("auth") response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"} return {"status": "success", "message": "Logged out successfully"}
@app.post("/auth/set-session") @app.post("/auth/set-session")
async def api_set_session(response: Response, auth=Depends(bearer_auth)): async def api_set_session(
request: Request, response: Response, auth=Depends(bearer_auth)
):
"""Set session cookie from Authorization header. Fetched after login by WebSocket.""" """Set session cookie from Authorization header. Fetched after login by WebSocket."""
try: try:
user = await session.get_session(auth.credentials) user = await session.get_session(auth.credentials)

View File

@ -1,3 +1,14 @@
"""
Minimal FastAPI WebAuthn server with WebSocket support for passkey registration and authentication.
This module provides a simple WebAuthn implementation that:
- Uses WebSocket for real-time communication
- Supports Resident Keys (discoverable credentials) for passwordless authentication
- Maintains challenges locally per connection
- Uses async SQLite database for persistent storage of users and credentials
- Enables true passwordless authentication where users don't need to enter a user_name
"""
import contextlib import contextlib
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -9,6 +20,7 @@ from fastapi.responses import (
) )
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from ..db import sql
from . import session, ws from . import session, ws
from .api import register_api_routes from .api import register_api_routes
from .reset import register_reset_routes from .reset import register_reset_routes
@ -18,9 +30,7 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
from ..db import sql await sql.init_database()
await sql.init()
yield yield

View File

@ -3,7 +3,7 @@ import logging
from fastapi import Cookie, HTTPException, Request from fastapi import Cookie, HTTPException, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from ..db import db from ..db import sql
from ..util import passphrase, tokens from ..util import passphrase, tokens
from . import session from . import session
@ -20,7 +20,7 @@ def register_reset_routes(app):
# Generate a human-readable token # Generate a human-readable token
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke" token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
await db.instance.create_session( await sql.create_session(
user_uuid=s.user_uuid, user_uuid=s.user_uuid,
key=tokens.reset_key(token), key=tokens.reset_key(token),
expires=session.expires(), expires=session.expires(),
@ -56,7 +56,7 @@ def register_reset_routes(app):
try: try:
# Get session token to validate it exists and get user_id # Get session token to validate it exists and get user_id
key = tokens.reset_key(reset_token) key = tokens.reset_key(reset_token)
sess = await db.instance.get_session(key) sess = await sql.get_session(key)
if not sess: if not sess:
raise ValueError("Invalid or expired registration token") raise ValueError("Invalid or expired registration token")

View File

@ -14,7 +14,7 @@ from uuid import UUID
from fastapi import Request, Response, WebSocket from fastapi import Request, Response, WebSocket
from ..db import Session, db from ..db import Session, sql
from ..util import passphrase from ..util import passphrase
from ..util.tokens import create_token, reset_key, session_key from ..util.tokens import create_token, reset_key, session_key
@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str: async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
"""Create a new session and return a session token.""" """Create a new session and return a session token."""
token = create_token() token = create_token()
await db.instance.create_session( await sql.create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + EXPIRES, expires=datetime.now() + EXPIRES,
@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
else: else:
key = session_key(token) key = session_key(token)
session = await db.instance.get_session(key) session = await sql.get_session(key)
if not session: if not session:
raise ValueError("Invalid or expired session token") raise ValueError("Invalid or expired session token")
return session return session
@ -65,9 +65,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
async def refresh_session_token(token: str): async def refresh_session_token(token: str):
"""Refresh a session extending its expiry.""" """Refresh a session extending its expiry."""
# Get the current session # Get the current session
s = await db.instance.update_session( s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {})
session_key(token), datetime.now() + EXPIRES, {}
)
if not s: if not s:
raise ValueError("Session not found or expired") raise ValueError("Session not found or expired")
@ -88,4 +86,4 @@ def set_session_cookie(response: Response, token: str) -> None:
async def delete_credential(credential_uuid: UUID, auth: str): async def delete_credential(credential_uuid: UUID, auth: str):
"""Delete a specific credential for the current user.""" """Delete a specific credential for the current user."""
s = await get_session(auth) s = await get_session(auth)
await db.instance.delete_credential(credential_uuid, s.user_uuid) await sql.delete_credential(credential_uuid, s.user_uuid)

View File

@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from passkey.fastapi import session from passkey.fastapi import session
from ..db import User, db from ..db import User, sql
from ..sansio import Passkey from ..sansio import Passkey
from ..util.tokens import create_token, session_key from ..util.tokens import create_token, session_key
from .session import create_session, infodict from .session import create_session, infodict
@ -65,13 +65,13 @@ async def websocket_register_new(
credential = await register_chat(ws, user_uuid, user_name, origin=origin) credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database # Store the user and credential in the database
await db.instance.create_user_and_credential( await sql.create_user_and_credential(
User(user_uuid, user_name, created_at=datetime.now()), User(user_uuid, user_name, created_at=datetime.now()),
credential, credential,
) )
# Create a session token for the new user # Create a session token for the new user
token = create_token() token = create_token()
await db.instance.create_session( await sql.create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + session.EXPIRES, expires=datetime.now() + session.EXPIRES,
@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
user_uuid = s.user_uuid user_uuid = s.user_uuid
# Get user information to get the user_name # Get user information to get the user_name
user = await db.instance.get_user_by_user_uuid(user_uuid) user = await sql.get_user_by_uuid(user_uuid)
user_name = user.user_name user_name = user.user_name
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid) challenge_ids = await sql.get_user_credentials(user_uuid)
# WebAuthn registration # WebAuthn registration
credential = await register_chat( credential = await register_chat(
ws, user_uuid, user_name, challenge_ids, origin ws, user_uuid, user_name, challenge_ids, origin
) )
# Store the new credential in the database # Store the new credential in the database
await db.instance.create_credential(credential) await sql.create_credential_for_user(credential)
await ws.send_json( await ws.send_json(
{ {
@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID # Fetch from the database by credential ID
stored_cred = await db.instance.get_credential_by_id(credential.raw_id) stored_cred = await sql.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data # Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin) passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp # Update both credential and user's last_seen timestamp
await db.instance.login(stored_cred.user_uuid, stored_cred) await sql.login_user(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user # Create a session token for the authenticated user
assert stored_cred.uuid is not None assert stored_cred.uuid is not None