More refactoring. Prevent registering another key on the same authenticator for the same user.

This commit is contained in:
Leo Vasanko 2025-07-07 11:20:28 -06:00
parent eb56c000e8
commit 1c9044054a
7 changed files with 291 additions and 242 deletions

View File

@ -10,8 +10,8 @@ This module contains all the HTTP API endpoints for:
from fastapi import Request, Response
from . import db
from .aaguid_manager import get_aaguid_manager
from .db import connect
from .jwt_manager import refresh_session_token, validate_session_token
from .session_manager import (
clear_session_cookie,
@ -57,57 +57,54 @@ async def get_user_credentials(request: Request) -> dict:
if token_data:
current_credential_id = token_data.get("credential_id")
async with connect() as db:
# Get all credentials for the user
credential_ids = await db.get_credentials_by_user_id(user.user_id.bytes)
# Get all credentials for the user
credential_ids = await db.get_user_credentials(user.user_id)
credentials = []
user_aaguids = set()
credentials = []
user_aaguids = set()
for cred_id in credential_ids:
try:
stored_cred = await db.get_credential_by_id(cred_id)
for cred_id in credential_ids:
try:
stored_cred = await db.get_credential_by_id(cred_id)
# Convert AAGUID to string format
aaguid_str = str(stored_cred.aaguid)
user_aaguids.add(aaguid_str)
# Convert AAGUID to string format
aaguid_str = str(stored_cred.aaguid)
user_aaguids.add(aaguid_str)
# Check if this is the current session credential
is_current_session = (
current_credential_id == stored_cred.credential_id
)
# Check if this is the current session credential
is_current_session = current_credential_id == stored_cred.credential_id
credentials.append(
{
"credential_id": stored_cred.credential_id.hex(),
"aaguid": aaguid_str,
"created_at": stored_cred.created_at.isoformat(),
"last_used": stored_cred.last_used.isoformat()
if stored_cred.last_used
else None,
"last_verified": stored_cred.last_verified.isoformat()
if stored_cred.last_verified
else None,
"sign_count": stored_cred.sign_count,
"is_current_session": is_current_session,
}
)
except ValueError:
# Skip invalid credentials
continue
credentials.append(
{
"credential_id": stored_cred.credential_id.hex(),
"aaguid": aaguid_str,
"created_at": stored_cred.created_at.isoformat(),
"last_used": stored_cred.last_used.isoformat()
if stored_cred.last_used
else None,
"last_verified": stored_cred.last_verified.isoformat()
if stored_cred.last_verified
else None,
"sign_count": stored_cred.sign_count,
"is_current_session": is_current_session,
}
)
except ValueError:
# Skip invalid credentials
continue
# Get AAGUID information for only the AAGUIDs that the user has
aaguid_manager = get_aaguid_manager()
aaguid_info = aaguid_manager.get_relevant_aaguids(user_aaguids)
# Get AAGUID information for only the AAGUIDs that the user has
aaguid_manager = get_aaguid_manager()
aaguid_info = aaguid_manager.get_relevant_aaguids(user_aaguids)
# Sort credentials by creation date (earliest first, most recently created last)
credentials.sort(key=lambda cred: cred["created_at"])
# Sort credentials by creation date (earliest first, most recently created last)
credentials.sort(key=lambda cred: cred["created_at"])
return {
"status": "success",
"credentials": credentials,
"aaguid_info": aaguid_info,
}
return {
"status": "success",
"credentials": credentials,
"aaguid_info": aaguid_info,
}
except Exception as e:
return {"error": f"Failed to get credentials: {str(e)}"}
@ -212,36 +209,30 @@ async def delete_credential(request: Request) -> dict:
except ValueError:
return {"error": "Invalid credential_id format"}
async with connect() as db:
# First, verify the credential belongs to the current user
try:
stored_cred = await db.get_credential_by_id(credential_id_bytes)
if stored_cred.user_id != user.user_id:
return {"error": "Credential not found or access denied"}
except ValueError:
return {"error": "Credential not found"}
# First, verify the credential belongs to the current user
try:
stored_cred = await db.get_credential_by_id(credential_id_bytes)
if stored_cred.user_id != user.user_id:
return {"error": "Credential not found or access denied"}
except ValueError:
return {"error": "Credential not found"}
# Check if this is the current session credential
session_token = get_session_token_from_request(request)
if session_token:
token_data = validate_session_token(session_token)
if (
token_data
and token_data.get("credential_id") == credential_id_bytes
):
return {"error": "Cannot delete current session credential"}
# Check if this is the current session credential
session_token = get_session_token_from_request(request)
if session_token:
token_data = validate_session_token(session_token)
if token_data and token_data.get("credential_id") == credential_id_bytes:
return {"error": "Cannot delete current session credential"}
# Get user's remaining credentials count
remaining_credentials = await db.get_credentials_by_user_id(
user.user_id.bytes
)
if len(remaining_credentials) <= 1:
return {"error": "Cannot delete last remaining credential"}
# Get user's remaining credentials count
remaining_credentials = await db.get_user_credentials(user.user_id)
if len(remaining_credentials) <= 1:
return {"error": "Cannot delete last remaining credential"}
# Delete the credential
await db.delete_credential(credential_id_bytes)
# Delete the credential
await db.delete_user_credential(credential_id_bytes)
return {"status": "success", "message": "Credential deleted successfully"}
return {"status": "success", "message": "Credential deleted successfully"}
except Exception as e:
return {"error": f"Failed to delete credential: {str(e)}"}

View File

@ -1,7 +1,7 @@
"""
Async database implementation for WebAuthn passkey authentication.
This module provides an async database layer using dataclasses and aiosqlite
This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system.
"""
@ -10,66 +10,60 @@ from dataclasses import dataclass
from datetime import datetime
from uuid import UUID
import aiosqlite
from sqlalchemy import (
DateTime,
ForeignKey,
Integer,
LargeBinary,
String,
delete,
select,
update,
)
from sqlalchemy.dialects.sqlite import BLOB
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from .passkey import StoredCredential
DB_PATH = "webauthn.db"
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
# SQL Statements
SQL_CREATE_USERS = """
CREATE TABLE IF NOT EXISTS users (
user_id BINARY(16) PRIMARY KEY NOT NULL,
user_name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_seen TIMESTAMP NULL
# SQLAlchemy Models
class Base(DeclarativeBase):
pass
class UserModel(Base):
__tablename__ = "users"
user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_name: Mapped[str] = mapped_column(String, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# Relationship to credentials
credentials: Mapped[list["CredentialModel"]] = relationship(
"CredentialModel", back_populates="user", cascade="all, delete-orphan"
)
"""
SQL_CREATE_CREDENTIALS = """
CREATE TABLE IF NOT EXISTS credentials (
credential_id BINARY(64) PRIMARY KEY NOT NULL,
user_id BINARY(16) NOT NULL,
aaguid BINARY(16) NOT NULL,
public_key BLOB NOT NULL,
sign_count INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used TIMESTAMP NULL,
last_verified TIMESTAMP NULL,
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
class CredentialModel(Base):
__tablename__ = "credentials"
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
user_id: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
)
"""
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
SQL_GET_USER_BY_USER_ID = """
SELECT * FROM users WHERE user_id = ?
"""
SQL_CREATE_USER = """
INSERT INTO users (user_id, user_name, created_at, last_seen) VALUES (?, ?, ?, ?)
"""
SQL_STORE_CREDENTIAL = """
INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"""
SQL_GET_CREDENTIAL_BY_ID = """
SELECT * FROM credentials WHERE credential_id = ?
"""
SQL_GET_USER_CREDENTIALS = """
SELECT credential_id FROM credentials WHERE user_id = ?
"""
SQL_UPDATE_CREDENTIAL = """
UPDATE credentials
SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ?
WHERE credential_id = ?
"""
SQL_DELETE_CREDENTIAL = """
DELETE FROM credentials WHERE credential_id = ?
"""
# Relationship to user
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
@dataclass
@ -80,121 +74,181 @@ class User:
last_seen: datetime | None = None
# Global engine and session factory
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def connect():
conn = await aiosqlite.connect(DB_PATH)
try:
yield DB(conn)
await conn.commit()
finally:
await conn.close()
"""Context manager for database connections."""
async with async_session_factory() as session:
yield DB(session)
await session.commit()
class DB:
def __init__(self, conn: aiosqlite.Connection):
self.conn = conn
def __init__(self, session: AsyncSession):
self.session = session
async def init_db(self) -> None:
"""Initialize database tables."""
await self.conn.execute(SQL_CREATE_USERS)
await self.conn.execute(SQL_CREATE_CREDENTIALS)
await self.conn.commit()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Database operation functions that work with a connection
async def get_user_by_user_id(self, user_id: bytes) -> User:
async def get_user_by_user_id(self, user_id: UUID) -> User:
"""Get user record by WebAuthn user ID."""
async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor:
row = await cursor.fetchone()
if row:
return User(
user_id=UUID(bytes=row[0]),
user_name=row[1],
created_at=_convert_datetime(row[2]),
last_seen=_convert_datetime(row[3]),
)
raise ValueError("User not found")
stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
result = await self.session.execute(stmt)
user_model = result.scalar_one_or_none()
if user_model:
return User(
user_id=UUID(bytes=user_model.user_id),
user_name=user_model.user_name,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
)
raise ValueError("User not found")
async def create_user(self, user: User) -> None:
"""Create a new user and return the User dataclass."""
await self.conn.execute(
SQL_CREATE_USER,
(
user.user_id.bytes,
user.user_name,
user.created_at or datetime.now(),
user.last_seen,
),
"""Create a new user."""
user_model = UserModel(
user_id=user.user_id.bytes,
user_name=user.user_name,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
)
self.session.add(user_model)
await self.session.flush()
async def create_credential(self, credential: StoredCredential) -> None:
"""Store a credential for a user."""
await self.conn.execute(
SQL_STORE_CREDENTIAL,
(
credential.credential_id,
credential.user_id.bytes,
credential.aaguid.bytes,
credential.public_key,
credential.sign_count,
credential.created_at,
credential.last_used,
credential.last_verified,
),
credential_model = CredentialModel(
credential_id=credential.credential_id,
user_id=credential.user_id.bytes,
aaguid=credential.aaguid.bytes,
public_key=credential.public_key,
sign_count=credential.sign_count,
created_at=credential.created_at,
last_used=credential.last_used,
last_verified=credential.last_verified,
)
self.session.add(credential_model)
await self.session.flush()
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
"""Get credential by credential ID."""
async with self.conn.execute(
SQL_GET_CREDENTIAL_BY_ID, (credential_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return StoredCredential(
credential_id=row[0],
user_id=UUID(bytes=row[1]),
aaguid=UUID(bytes=row[2]),
public_key=row[3],
sign_count=row[4],
created_at=datetime.fromisoformat(row[5]),
last_used=_convert_datetime(row[6]),
last_verified=_convert_datetime(row[7]),
)
raise ValueError("Credential not registered")
stmt = select(CredentialModel).where(
CredentialModel.credential_id == credential_id
)
result = await self.session.execute(stmt)
credential_model = result.scalar_one_or_none()
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
if credential_model:
return StoredCredential(
credential_id=credential_model.credential_id,
user_id=UUID(bytes=credential_model.user_id),
aaguid=UUID(bytes=credential_model.aaguid),
public_key=credential_model.public_key,
sign_count=credential_model.sign_count,
created_at=credential_model.created_at,
last_used=credential_model.last_used,
last_verified=credential_model.last_verified,
)
raise ValueError("Credential not registered")
async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor:
rows = await cursor.fetchall()
return [row[0] for row in rows]
stmt = select(CredentialModel.credential_id).where(
CredentialModel.user_id == user_id.bytes
)
result = await self.session.execute(stmt)
return [row[0] for row in result.fetchall()]
async def update_credential(self, credential: StoredCredential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
await self.conn.execute(
SQL_UPDATE_CREDENTIAL,
(
credential.sign_count,
credential.created_at,
credential.last_used,
credential.last_verified,
credential.credential_id,
),
stmt = (
update(CredentialModel)
.where(CredentialModel.credential_id == credential.credential_id)
.values(
sign_count=credential.sign_count,
created_at=credential.created_at,
last_used=credential.last_used,
last_verified=credential.last_verified,
)
)
await self.session.execute(stmt)
async def login(self, user_id: bytes, credential: StoredCredential) -> None:
async def login(self, user_id: UUID, credential: StoredCredential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
await self.conn.execute("BEGIN")
await self.update_credential(credential)
await self.conn.execute(
"UPDATE users SET last_seen = ? WHERE user_id = ?",
(credential.last_used, user_id),
)
async with self.session.begin():
# Update credential
await self.update_credential(credential)
# Update user's last_seen
stmt = (
update(UserModel)
.where(UserModel.user_id == user_id.bytes)
.values(last_seen=credential.last_used)
)
await self.session.execute(stmt)
async def delete_credential(self, credential_id: bytes) -> None:
"""Delete a credential by its ID."""
await self.conn.execute(SQL_DELETE_CREDENTIAL, (credential_id,))
await self.conn.commit()
stmt = delete(CredentialModel).where(
CredentialModel.credential_id == credential_id
)
await self.session.execute(stmt)
await self.session.commit()
def _convert_datetime(val):
"""Convert string from SQLite to datetime object (pass through None)."""
return val and datetime.fromisoformat(val)
# Standalone functions that handle database connections internally
async def init_database() -> None:
"""Initialize database tables."""
async with connect() as db:
await db.init_db()
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
"""Create a new user and their first credential in a single transaction."""
async with connect() as db:
await db.session.begin()
await db.create_user(user)
await db.create_credential(credential)
async def get_user_by_id(user_id: UUID) -> User:
"""Get user record by WebAuthn user ID."""
async with connect() as db:
return await db.get_user_by_user_id(user_id)
async def create_credential_for_user(credential: StoredCredential) -> None:
"""Store a credential for an existing user."""
async with connect() as db:
await db.create_credential(credential)
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
"""Get credential by credential ID."""
async with connect() as db:
return await db.get_credential_by_id(credential_id)
async def get_user_credentials(user_id: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
async with connect() as db:
return await db.get_credentials_by_user_id(user_id)
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with connect() as db:
await db.login(user_id, credential)
async def delete_user_credential(credential_id: bytes) -> None:
"""Delete a credential by its ID."""
async with connect() as db:
await db.delete_credential(credential_id)

View File

@ -18,6 +18,7 @@ from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from . import db
from .api_handlers import (
delete_credential,
get_user_credentials,
@ -27,7 +28,7 @@ from .api_handlers import (
set_session,
validate_token,
)
from .db import User, connect
from .db import User
from .jwt_manager import create_session_token
from .passkey import Passkey
from .session_manager import get_user_from_cookie_string
@ -44,8 +45,7 @@ passkey = Passkey(
@asynccontextmanager
async def lifespan(app: FastAPI):
async with connect() as db:
await db.init_db()
await db.init_database()
yield
@ -65,11 +65,11 @@ async def websocket_register_new(ws: WebSocket):
# WebAuthn registration
credential = await register_chat(ws, user_id, user_name)
# Store the user in the database
async with connect() as db:
await db.conn.execute("BEGIN")
await db.create_user(User(user_id, user_name, created_at=datetime.now()))
await db.create_credential(credential)
# Store the user and credential in the database
await db.create_user_and_credential(
User(user_id, user_name, created_at=datetime.now()),
credential,
)
# Create a session token for the new user
session_token = create_session_token(user_id, credential.credential_id)
@ -101,16 +101,15 @@ async def websocket_register_add(ws: WebSocket):
return
# Get user information to get the user_name
async with connect() as db:
user = await db.get_user_by_user_id(user_id.bytes)
user_name = user.user_name
user = await db.get_user_by_id(user_id)
user_name = user.user_name
challenge_ids = await db.get_user_credentials(user_id)
# WebAuthn registration
credential = await register_chat(ws, user_id, user_name)
credential = await register_chat(ws, user_id, user_name, challenge_ids)
print(f"New credential for user {user_id}: {credential}")
# Store the new credential in the database
async with connect() as db:
await db.create_credential(credential)
await db.create_credential_for_user(credential)
await ws.send_json(
{
@ -128,11 +127,17 @@ async def websocket_register_add(ws: WebSocket):
await ws.send_json({"error": f"Server error: {str(e)}"})
async def register_chat(ws: WebSocket, user_id: UUID, user_name: str):
async def register_chat(
ws: WebSocket,
user_id: UUID,
user_name: str,
credential_ids: list[bytes] | None = None,
):
"""Generate registration options and send them to the client."""
options, challenge = passkey.reg_generate_options(
user_id=user_id,
user_name=user_name,
credential_ids=credential_ids,
)
await ws.send_json(options)
response = await ws.receive_json()
@ -144,17 +149,16 @@ async def register_chat(ws: WebSocket, user_id: UUID, user_name: str):
async def websocket_authenticate(ws: WebSocket):
await ws.accept()
try:
options, challenge = await passkey.auth_generate_options()
options, challenge = passkey.auth_generate_options()
await ws.send_json(options)
# Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json())
async with connect() as db:
# Fetch from the database by credential ID
stored_cred = await db.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data
await passkey.auth_verify(credential, challenge, stored_cred)
# Update both credential and user's last_seen timestamp
await db.login(stored_cred.user_id.bytes, stored_cred)
# Fetch from the database by credential ID
stored_cred = await db.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred)
# Update both credential and user's last_seen timestamp
await db.login_user(stored_cred.user_id, stored_cred)
# Create a session token for the authenticated user
session_token = create_session_token(

View File

@ -155,7 +155,7 @@ class Passkey:
### Authentication Methods ###
async def auth_generate_options(
def auth_generate_options(
self,
*,
user_verification_required=False,
@ -178,7 +178,7 @@ class Passkey:
user_verification=(
UserVerificationRequirement.REQUIRED
if user_verification_required
else UserVerificationRequirement.PREFERRED
else UserVerificationRequirement.DISCOURAGED
),
allow_credentials=_convert_credential_ids(credential_ids),
**authopts,
@ -188,7 +188,7 @@ class Passkey:
def auth_parse(self, response: dict | str) -> AuthenticationCredential:
return parse_authentication_credential_json(response)
async def auth_verify(
def auth_verify(
self,
credential: AuthenticationCredential,
expected_challenge: bytes,

View File

@ -12,7 +12,7 @@ from uuid import UUID
from fastapi import Request, Response
from .db import User, connect
from .db import User, get_user_by_id
from .jwt_manager import validate_session_token
COOKIE_NAME = "session_token"
@ -29,12 +29,11 @@ async def get_current_user(request: Request) -> Optional[User]:
if not token_data:
return None
async with connect() as db:
try:
user = await db.get_user_by_user_id(token_data["user_id"].bytes)
return user
except Exception:
return None
try:
user = await get_user_by_id(token_data["user_id"])
return user
except Exception:
return None
def set_session_cookie(response: Response, session_token: str) -> None:

View File

@ -15,6 +15,7 @@ dependencies = [
"websockets>=12.0",
"webauthn>=1.11.1",
"base64url>=1.0.0",
"sqlalchemy[asyncio]>=2.0.0",
"aiosqlite>=0.19.0",
"uuid7-standard>=1.0.0",
"pyjwt>=2.8.0",

View File

@ -334,7 +334,7 @@ async function addNewCredential() {
clearStatus('dashboardStatus')
} catch (error) {
showStatus('dashboardStatus', `Failed to add new passkey: ${error.message}`, 'error')
showStatus('dashboardStatus', 'Registration cancelled', 'error')
}
}