More refactoring. Prevent registering another key on the same authenticator for the same user.

This commit is contained in:
Leo Vasanko 2025-07-07 11:20:28 -06:00
parent eb56c000e8
commit 1c9044054a
7 changed files with 291 additions and 242 deletions

View File

@ -10,8 +10,8 @@ This module contains all the HTTP API endpoints for:
from fastapi import Request, Response from fastapi import Request, Response
from . import db
from .aaguid_manager import get_aaguid_manager from .aaguid_manager import get_aaguid_manager
from .db import connect
from .jwt_manager import refresh_session_token, validate_session_token from .jwt_manager import refresh_session_token, validate_session_token
from .session_manager import ( from .session_manager import (
clear_session_cookie, clear_session_cookie,
@ -57,9 +57,8 @@ async def get_user_credentials(request: Request) -> dict:
if token_data: if token_data:
current_credential_id = token_data.get("credential_id") current_credential_id = token_data.get("credential_id")
async with connect() as db:
# Get all credentials for the user # Get all credentials for the user
credential_ids = await db.get_credentials_by_user_id(user.user_id.bytes) credential_ids = await db.get_user_credentials(user.user_id)
credentials = [] credentials = []
user_aaguids = set() user_aaguids = set()
@ -73,9 +72,7 @@ async def get_user_credentials(request: Request) -> dict:
user_aaguids.add(aaguid_str) user_aaguids.add(aaguid_str)
# Check if this is the current session credential # Check if this is the current session credential
is_current_session = ( is_current_session = current_credential_id == stored_cred.credential_id
current_credential_id == stored_cred.credential_id
)
credentials.append( credentials.append(
{ {
@ -212,7 +209,6 @@ async def delete_credential(request: Request) -> dict:
except ValueError: except ValueError:
return {"error": "Invalid credential_id format"} return {"error": "Invalid credential_id format"}
async with connect() as db:
# First, verify the credential belongs to the current user # First, verify the credential belongs to the current user
try: try:
stored_cred = await db.get_credential_by_id(credential_id_bytes) stored_cred = await db.get_credential_by_id(credential_id_bytes)
@ -225,21 +221,16 @@ async def delete_credential(request: Request) -> dict:
session_token = get_session_token_from_request(request) session_token = get_session_token_from_request(request)
if session_token: if session_token:
token_data = validate_session_token(session_token) token_data = validate_session_token(session_token)
if ( if token_data and token_data.get("credential_id") == credential_id_bytes:
token_data
and token_data.get("credential_id") == credential_id_bytes
):
return {"error": "Cannot delete current session credential"} return {"error": "Cannot delete current session credential"}
# Get user's remaining credentials count # Get user's remaining credentials count
remaining_credentials = await db.get_credentials_by_user_id( remaining_credentials = await db.get_user_credentials(user.user_id)
user.user_id.bytes
)
if len(remaining_credentials) <= 1: if len(remaining_credentials) <= 1:
return {"error": "Cannot delete last remaining credential"} return {"error": "Cannot delete last remaining credential"}
# Delete the credential # Delete the credential
await db.delete_credential(credential_id_bytes) await db.delete_user_credential(credential_id_bytes)
return {"status": "success", "message": "Credential deleted successfully"} return {"status": "success", "message": "Credential deleted successfully"}

View File

@ -1,7 +1,7 @@
""" """
Async database implementation for WebAuthn passkey authentication. Async database implementation for WebAuthn passkey authentication.
This module provides an async database layer using dataclasses and aiosqlite This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system. for managing users and credentials in a WebAuthn authentication system.
""" """
@ -10,66 +10,60 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
import aiosqlite from sqlalchemy import (
DateTime,
ForeignKey,
Integer,
LargeBinary,
String,
delete,
select,
update,
)
from sqlalchemy.dialects.sqlite import BLOB
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from .passkey import StoredCredential from .passkey import StoredCredential
DB_PATH = "webauthn.db" DB_PATH = "sqlite+aiosqlite:///webauthn.db"
# SQL Statements
SQL_CREATE_USERS = """ # SQLAlchemy Models
CREATE TABLE IF NOT EXISTS users ( class Base(DeclarativeBase):
user_id BINARY(16) PRIMARY KEY NOT NULL, pass
user_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_seen TIMESTAMP NULL class UserModel(Base):
__tablename__ = "users"
user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_name: Mapped[str] = mapped_column(String, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# Relationship to credentials
credentials: Mapped[list["CredentialModel"]] = relationship(
"CredentialModel", back_populates="user", cascade="all, delete-orphan"
) )
"""
SQL_CREATE_CREDENTIALS = """
CREATE TABLE IF NOT EXISTS credentials ( class CredentialModel(Base):
credential_id BINARY(64) PRIMARY KEY NOT NULL, __tablename__ = "credentials"
user_id BINARY(16) NOT NULL,
aaguid BINARY(16) NOT NULL, credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
public_key BLOB NOT NULL, user_id: Mapped[bytes] = mapped_column(
sign_count INTEGER NOT NULL, LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used TIMESTAMP NULL,
last_verified TIMESTAMP NULL,
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
) )
""" aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
SQL_GET_USER_BY_USER_ID = """ # Relationship to user
SELECT * FROM users WHERE user_id = ? user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
"""
SQL_CREATE_USER = """
INSERT INTO users (user_id, user_name, created_at, last_seen) VALUES (?, ?, ?, ?)
"""
SQL_STORE_CREDENTIAL = """
INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"""
SQL_GET_CREDENTIAL_BY_ID = """
SELECT * FROM credentials WHERE credential_id = ?
"""
SQL_GET_USER_CREDENTIALS = """
SELECT credential_id FROM credentials WHERE user_id = ?
"""
SQL_UPDATE_CREDENTIAL = """
UPDATE credentials
SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ?
WHERE credential_id = ?
"""
SQL_DELETE_CREDENTIAL = """
DELETE FROM credentials WHERE credential_id = ?
"""
@dataclass @dataclass
@ -80,121 +74,181 @@ class User:
last_seen: datetime | None = None last_seen: datetime | None = None
# Global engine and session factory
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager @asynccontextmanager
async def connect(): async def connect():
conn = await aiosqlite.connect(DB_PATH) """Context manager for database connections."""
try: async with async_session_factory() as session:
yield DB(conn) yield DB(session)
await conn.commit() await session.commit()
finally:
await conn.close()
class DB: class DB:
def __init__(self, conn: aiosqlite.Connection): def __init__(self, session: AsyncSession):
self.conn = conn self.session = session
async def init_db(self) -> None: async def init_db(self) -> None:
"""Initialize database tables.""" """Initialize database tables."""
await self.conn.execute(SQL_CREATE_USERS) async with engine.begin() as conn:
await self.conn.execute(SQL_CREATE_CREDENTIALS) await conn.run_sync(Base.metadata.create_all)
await self.conn.commit()
# Database operation functions that work with a connection async def get_user_by_user_id(self, user_id: UUID) -> User:
async def get_user_by_user_id(self, user_id: bytes) -> User:
"""Get user record by WebAuthn user ID.""" """Get user record by WebAuthn user ID."""
async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor: stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
row = await cursor.fetchone() result = await self.session.execute(stmt)
if row: user_model = result.scalar_one_or_none()
if user_model:
return User( return User(
user_id=UUID(bytes=row[0]), user_id=UUID(bytes=user_model.user_id),
user_name=row[1], user_name=user_model.user_name,
created_at=_convert_datetime(row[2]), created_at=user_model.created_at,
last_seen=_convert_datetime(row[3]), last_seen=user_model.last_seen,
) )
raise ValueError("User not found") raise ValueError("User not found")
async def create_user(self, user: User) -> None: async def create_user(self, user: User) -> None:
"""Create a new user and return the User dataclass.""" """Create a new user."""
await self.conn.execute( user_model = UserModel(
SQL_CREATE_USER, user_id=user.user_id.bytes,
( user_name=user.user_name,
user.user_id.bytes, created_at=user.created_at or datetime.now(),
user.user_name, last_seen=user.last_seen,
user.created_at or datetime.now(),
user.last_seen,
),
) )
self.session.add(user_model)
await self.session.flush()
async def create_credential(self, credential: StoredCredential) -> None: async def create_credential(self, credential: StoredCredential) -> None:
"""Store a credential for a user.""" """Store a credential for a user."""
await self.conn.execute( credential_model = CredentialModel(
SQL_STORE_CREDENTIAL, credential_id=credential.credential_id,
( user_id=credential.user_id.bytes,
credential.credential_id, aaguid=credential.aaguid.bytes,
credential.user_id.bytes, public_key=credential.public_key,
credential.aaguid.bytes, sign_count=credential.sign_count,
credential.public_key, created_at=credential.created_at,
credential.sign_count, last_used=credential.last_used,
credential.created_at, last_verified=credential.last_verified,
credential.last_used,
credential.last_verified,
),
) )
self.session.add(credential_model)
await self.session.flush()
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential: async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
"""Get credential by credential ID.""" """Get credential by credential ID."""
async with self.conn.execute( stmt = select(CredentialModel).where(
SQL_GET_CREDENTIAL_BY_ID, (credential_id,) CredentialModel.credential_id == credential_id
) as cursor: )
row = await cursor.fetchone() result = await self.session.execute(stmt)
if row: credential_model = result.scalar_one_or_none()
if credential_model:
return StoredCredential( return StoredCredential(
credential_id=row[0], credential_id=credential_model.credential_id,
user_id=UUID(bytes=row[1]), user_id=UUID(bytes=credential_model.user_id),
aaguid=UUID(bytes=row[2]), aaguid=UUID(bytes=credential_model.aaguid),
public_key=row[3], public_key=credential_model.public_key,
sign_count=row[4], sign_count=credential_model.sign_count,
created_at=datetime.fromisoformat(row[5]), created_at=credential_model.created_at,
last_used=_convert_datetime(row[6]), last_used=credential_model.last_used,
last_verified=_convert_datetime(row[7]), last_verified=credential_model.last_verified,
) )
raise ValueError("Credential not registered") raise ValueError("Credential not registered")
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]: async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
"""Get all credential IDs for a user.""" """Get all credential IDs for a user."""
async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor: stmt = select(CredentialModel.credential_id).where(
rows = await cursor.fetchall() CredentialModel.user_id == user_id.bytes
return [row[0] for row in rows] )
result = await self.session.execute(stmt)
return [row[0] for row in result.fetchall()]
async def update_credential(self, credential: StoredCredential) -> None: async def update_credential(self, credential: StoredCredential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential.""" """Update the sign count, created_at, last_used, and last_verified for a credential."""
await self.conn.execute( stmt = (
SQL_UPDATE_CREDENTIAL, update(CredentialModel)
( .where(CredentialModel.credential_id == credential.credential_id)
credential.sign_count, .values(
credential.created_at, sign_count=credential.sign_count,
credential.last_used, created_at=credential.created_at,
credential.last_verified, last_used=credential.last_used,
credential.credential_id, last_verified=credential.last_verified,
),
) )
)
await self.session.execute(stmt)
async def login(self, user_id: bytes, credential: StoredCredential) -> None: async def login(self, user_id: UUID, credential: StoredCredential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in.""" """Update the last_seen timestamp for a user and the credential record used for logging in."""
await self.conn.execute("BEGIN") async with self.session.begin():
# Update credential
await self.update_credential(credential) await self.update_credential(credential)
await self.conn.execute(
"UPDATE users SET last_seen = ? WHERE user_id = ?", # Update user's last_seen
(credential.last_used, user_id), stmt = (
update(UserModel)
.where(UserModel.user_id == user_id.bytes)
.values(last_seen=credential.last_used)
) )
await self.session.execute(stmt)
async def delete_credential(self, credential_id: bytes) -> None: async def delete_credential(self, credential_id: bytes) -> None:
"""Delete a credential by its ID.""" """Delete a credential by its ID."""
await self.conn.execute(SQL_DELETE_CREDENTIAL, (credential_id,)) stmt = delete(CredentialModel).where(
await self.conn.commit() CredentialModel.credential_id == credential_id
)
await self.session.execute(stmt)
await self.session.commit()
def _convert_datetime(val): # Standalone functions that handle database connections internally
"""Convert string from SQLite to datetime object (pass through None).""" async def init_database() -> None:
return val and datetime.fromisoformat(val) """Initialize database tables."""
async with connect() as db:
await db.init_db()
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
"""Create a new user and their first credential in a single transaction."""
async with connect() as db:
await db.session.begin()
await db.create_user(user)
await db.create_credential(credential)
async def get_user_by_id(user_id: UUID) -> User:
"""Get user record by WebAuthn user ID."""
async with connect() as db:
return await db.get_user_by_user_id(user_id)
async def create_credential_for_user(credential: StoredCredential) -> None:
"""Store a credential for an existing user."""
async with connect() as db:
await db.create_credential(credential)
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
"""Get credential by credential ID."""
async with connect() as db:
return await db.get_credential_by_id(credential_id)
async def get_user_credentials(user_id: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
async with connect() as db:
return await db.get_credentials_by_user_id(user_id)
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with connect() as db:
await db.login(user_id, credential)
async def delete_user_credential(credential_id: bytes) -> None:
"""Delete a credential by its ID."""
async with connect() as db:
await db.delete_credential(credential_id)

View File

@ -18,6 +18,7 @@ from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from . import db
from .api_handlers import ( from .api_handlers import (
delete_credential, delete_credential,
get_user_credentials, get_user_credentials,
@ -27,7 +28,7 @@ from .api_handlers import (
set_session, set_session,
validate_token, validate_token,
) )
from .db import User, connect from .db import User
from .jwt_manager import create_session_token from .jwt_manager import create_session_token
from .passkey import Passkey from .passkey import Passkey
from .session_manager import get_user_from_cookie_string from .session_manager import get_user_from_cookie_string
@ -44,8 +45,7 @@ passkey = Passkey(
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
async with connect() as db: await db.init_database()
await db.init_db()
yield yield
@ -65,11 +65,11 @@ async def websocket_register_new(ws: WebSocket):
# WebAuthn registration # WebAuthn registration
credential = await register_chat(ws, user_id, user_name) credential = await register_chat(ws, user_id, user_name)
# Store the user in the database # Store the user and credential in the database
async with connect() as db: await db.create_user_and_credential(
await db.conn.execute("BEGIN") User(user_id, user_name, created_at=datetime.now()),
await db.create_user(User(user_id, user_name, created_at=datetime.now())) credential,
await db.create_credential(credential) )
# Create a session token for the new user # Create a session token for the new user
session_token = create_session_token(user_id, credential.credential_id) session_token = create_session_token(user_id, credential.credential_id)
@ -101,16 +101,15 @@ async def websocket_register_add(ws: WebSocket):
return return
# Get user information to get the user_name # Get user information to get the user_name
async with connect() as db: user = await db.get_user_by_id(user_id)
user = await db.get_user_by_user_id(user_id.bytes)
user_name = user.user_name user_name = user.user_name
challenge_ids = await db.get_user_credentials(user_id)
# WebAuthn registration # WebAuthn registration
credential = await register_chat(ws, user_id, user_name) credential = await register_chat(ws, user_id, user_name, challenge_ids)
print(f"New credential for user {user_id}: {credential}") print(f"New credential for user {user_id}: {credential}")
# Store the new credential in the database # Store the new credential in the database
async with connect() as db: await db.create_credential_for_user(credential)
await db.create_credential(credential)
await ws.send_json( await ws.send_json(
{ {
@ -128,11 +127,17 @@ async def websocket_register_add(ws: WebSocket):
await ws.send_json({"error": f"Server error: {str(e)}"}) await ws.send_json({"error": f"Server error: {str(e)}"})
async def register_chat(ws: WebSocket, user_id: UUID, user_name: str): async def register_chat(
ws: WebSocket,
user_id: UUID,
user_name: str,
credential_ids: list[bytes] | None = None,
):
"""Generate registration options and send them to the client.""" """Generate registration options and send them to the client."""
options, challenge = passkey.reg_generate_options( options, challenge = passkey.reg_generate_options(
user_id=user_id, user_id=user_id,
user_name=user_name, user_name=user_name,
credential_ids=credential_ids,
) )
await ws.send_json(options) await ws.send_json(options)
response = await ws.receive_json() response = await ws.receive_json()
@ -144,17 +149,16 @@ async def register_chat(ws: WebSocket, user_id: UUID, user_name: str):
async def websocket_authenticate(ws: WebSocket): async def websocket_authenticate(ws: WebSocket):
await ws.accept() await ws.accept()
try: try:
options, challenge = await passkey.auth_generate_options() options, challenge = passkey.auth_generate_options()
await ws.send_json(options) await ws.send_json(options)
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
async with connect() as db:
# Fetch from the database by credential ID # Fetch from the database by credential ID
stored_cred = await db.get_credential_by_id(credential.raw_id) stored_cred = await db.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data # Verify the credential matches the stored data
await passkey.auth_verify(credential, challenge, stored_cred) passkey.auth_verify(credential, challenge, stored_cred)
# Update both credential and user's last_seen timestamp # Update both credential and user's last_seen timestamp
await db.login(stored_cred.user_id.bytes, stored_cred) await db.login_user(stored_cred.user_id, stored_cred)
# Create a session token for the authenticated user # Create a session token for the authenticated user
session_token = create_session_token( session_token = create_session_token(

View File

@ -155,7 +155,7 @@ class Passkey:
### Authentication Methods ### ### Authentication Methods ###
async def auth_generate_options( def auth_generate_options(
self, self,
*, *,
user_verification_required=False, user_verification_required=False,
@ -178,7 +178,7 @@ class Passkey:
user_verification=( user_verification=(
UserVerificationRequirement.REQUIRED UserVerificationRequirement.REQUIRED
if user_verification_required if user_verification_required
else UserVerificationRequirement.PREFERRED else UserVerificationRequirement.DISCOURAGED
), ),
allow_credentials=_convert_credential_ids(credential_ids), allow_credentials=_convert_credential_ids(credential_ids),
**authopts, **authopts,
@ -188,7 +188,7 @@ class Passkey:
def auth_parse(self, response: dict | str) -> AuthenticationCredential: def auth_parse(self, response: dict | str) -> AuthenticationCredential:
return parse_authentication_credential_json(response) return parse_authentication_credential_json(response)
async def auth_verify( def auth_verify(
self, self,
credential: AuthenticationCredential, credential: AuthenticationCredential,
expected_challenge: bytes, expected_challenge: bytes,

View File

@ -12,7 +12,7 @@ from uuid import UUID
from fastapi import Request, Response from fastapi import Request, Response
from .db import User, connect from .db import User, get_user_by_id
from .jwt_manager import validate_session_token from .jwt_manager import validate_session_token
COOKIE_NAME = "session_token" COOKIE_NAME = "session_token"
@ -29,9 +29,8 @@ async def get_current_user(request: Request) -> Optional[User]:
if not token_data: if not token_data:
return None return None
async with connect() as db:
try: try:
user = await db.get_user_by_user_id(token_data["user_id"].bytes) user = await get_user_by_id(token_data["user_id"])
return user return user
except Exception: except Exception:
return None return None

View File

@ -15,6 +15,7 @@ dependencies = [
"websockets>=12.0", "websockets>=12.0",
"webauthn>=1.11.1", "webauthn>=1.11.1",
"base64url>=1.0.0", "base64url>=1.0.0",
"sqlalchemy[asyncio]>=2.0.0",
"aiosqlite>=0.19.0", "aiosqlite>=0.19.0",
"uuid7-standard>=1.0.0", "uuid7-standard>=1.0.0",
"pyjwt>=2.8.0", "pyjwt>=2.8.0",

View File

@ -334,7 +334,7 @@ async function addNewCredential() {
clearStatus('dashboardStatus') clearStatus('dashboardStatus')
} catch (error) { } catch (error) {
showStatus('dashboardStatus', `Failed to add new passkey: ${error.message}`, 'error') showStatus('dashboardStatus', 'Registration cancelled', 'error')
} }
} }