diff --git a/passkeyauth/api_handlers.py b/passkeyauth/api_handlers.py index fef29b0..25d3947 100644 --- a/passkeyauth/api_handlers.py +++ b/passkeyauth/api_handlers.py @@ -10,8 +10,8 @@ This module contains all the HTTP API endpoints for: from fastapi import Request, Response +from . import db from .aaguid_manager import get_aaguid_manager -from .db import connect from .jwt_manager import refresh_session_token, validate_session_token from .session_manager import ( clear_session_cookie, @@ -57,57 +57,54 @@ async def get_user_credentials(request: Request) -> dict: if token_data: current_credential_id = token_data.get("credential_id") - async with connect() as db: - # Get all credentials for the user - credential_ids = await db.get_credentials_by_user_id(user.user_id.bytes) + # Get all credentials for the user + credential_ids = await db.get_user_credentials(user.user_id) - credentials = [] - user_aaguids = set() + credentials = [] + user_aaguids = set() - for cred_id in credential_ids: - try: - stored_cred = await db.get_credential_by_id(cred_id) + for cred_id in credential_ids: + try: + stored_cred = await db.get_credential_by_id(cred_id) - # Convert AAGUID to string format - aaguid_str = str(stored_cred.aaguid) - user_aaguids.add(aaguid_str) + # Convert AAGUID to string format + aaguid_str = str(stored_cred.aaguid) + user_aaguids.add(aaguid_str) - # Check if this is the current session credential - is_current_session = ( - current_credential_id == stored_cred.credential_id - ) + # Check if this is the current session credential + is_current_session = current_credential_id == stored_cred.credential_id - credentials.append( - { - "credential_id": stored_cred.credential_id.hex(), - "aaguid": aaguid_str, - "created_at": stored_cred.created_at.isoformat(), - "last_used": stored_cred.last_used.isoformat() - if stored_cred.last_used - else None, - "last_verified": stored_cred.last_verified.isoformat() - if stored_cred.last_verified - else None, - "sign_count": stored_cred.sign_count, - "is_current_session": is_current_session, - } - ) - except ValueError: - # Skip invalid credentials - continue + credentials.append( + { + "credential_id": stored_cred.credential_id.hex(), + "aaguid": aaguid_str, + "created_at": stored_cred.created_at.isoformat(), + "last_used": stored_cred.last_used.isoformat() + if stored_cred.last_used + else None, + "last_verified": stored_cred.last_verified.isoformat() + if stored_cred.last_verified + else None, + "sign_count": stored_cred.sign_count, + "is_current_session": is_current_session, + } + ) + except ValueError: + # Skip invalid credentials + continue - # Get AAGUID information for only the AAGUIDs that the user has - aaguid_manager = get_aaguid_manager() - aaguid_info = aaguid_manager.get_relevant_aaguids(user_aaguids) + # Get AAGUID information for only the AAGUIDs that the user has + aaguid_manager = get_aaguid_manager() + aaguid_info = aaguid_manager.get_relevant_aaguids(user_aaguids) - # Sort credentials by creation date (earliest first, most recently created last) - credentials.sort(key=lambda cred: cred["created_at"]) + # Sort credentials by creation date (earliest first, most recently created last) + credentials.sort(key=lambda cred: cred["created_at"]) - return { - "status": "success", - "credentials": credentials, - "aaguid_info": aaguid_info, - } + return { + "status": "success", + "credentials": credentials, + "aaguid_info": aaguid_info, + } except Exception as e: return {"error": f"Failed to get credentials: {str(e)}"} @@ -212,36 +209,30 @@ async def delete_credential(request: Request) -> dict: except ValueError: return {"error": "Invalid credential_id format"} - async with connect() as db: - # First, verify the credential belongs to the current user - try: - stored_cred = await db.get_credential_by_id(credential_id_bytes) - if stored_cred.user_id != user.user_id: - return {"error": "Credential not found or access denied"} - except ValueError: - return {"error": "Credential not found"} + # First, verify the credential belongs to the current user + try: + stored_cred = await db.get_credential_by_id(credential_id_bytes) + if stored_cred.user_id != user.user_id: + return {"error": "Credential not found or access denied"} + except ValueError: + return {"error": "Credential not found"} - # Check if this is the current session credential - session_token = get_session_token_from_request(request) - if session_token: - token_data = validate_session_token(session_token) - if ( - token_data - and token_data.get("credential_id") == credential_id_bytes - ): - return {"error": "Cannot delete current session credential"} + # Check if this is the current session credential + session_token = get_session_token_from_request(request) + if session_token: + token_data = validate_session_token(session_token) + if token_data and token_data.get("credential_id") == credential_id_bytes: + return {"error": "Cannot delete current session credential"} - # Get user's remaining credentials count - remaining_credentials = await db.get_credentials_by_user_id( - user.user_id.bytes - ) - if len(remaining_credentials) <= 1: - return {"error": "Cannot delete last remaining credential"} + # Get user's remaining credentials count + remaining_credentials = await db.get_user_credentials(user.user_id) + if len(remaining_credentials) <= 1: + return {"error": "Cannot delete last remaining credential"} - # Delete the credential - await db.delete_credential(credential_id_bytes) + # Delete the credential + await db.delete_user_credential(credential_id_bytes) - return {"status": "success", "message": "Credential deleted successfully"} + return {"status": "success", "message": "Credential deleted successfully"} except Exception as e: return {"error": f"Failed to delete credential: {str(e)}"} diff --git a/passkeyauth/db.py b/passkeyauth/db.py index 3b05bf0..30c2338 100644 --- a/passkeyauth/db.py +++ b/passkeyauth/db.py @@ -1,7 +1,7 @@ """ Async database implementation for WebAuthn passkey authentication. -This module provides an async database layer using dataclasses and aiosqlite +This module provides an async database layer using SQLAlchemy async mode for managing users and credentials in a WebAuthn authentication system. """ @@ -10,66 +10,60 @@ from dataclasses import dataclass from datetime import datetime from uuid import UUID -import aiosqlite +from sqlalchemy import ( + DateTime, + ForeignKey, + Integer, + LargeBinary, + String, + delete, + select, + update, +) +from sqlalchemy.dialects.sqlite import BLOB +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from .passkey import StoredCredential -DB_PATH = "webauthn.db" +DB_PATH = "sqlite+aiosqlite:///webauthn.db" -# SQL Statements -SQL_CREATE_USERS = """ - CREATE TABLE IF NOT EXISTS users ( - user_id BINARY(16) PRIMARY KEY NOT NULL, - user_name TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - last_seen TIMESTAMP NULL + +# SQLAlchemy Models +class Base(DeclarativeBase): + pass + + +class UserModel(Base): + __tablename__ = "users" + + user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) + user_name: Mapped[str] = mapped_column(String, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) + last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + + # Relationship to credentials + credentials: Mapped[list["CredentialModel"]] = relationship( + "CredentialModel", back_populates="user", cascade="all, delete-orphan" ) -""" -SQL_CREATE_CREDENTIALS = """ - CREATE TABLE IF NOT EXISTS credentials ( - credential_id BINARY(64) PRIMARY KEY NOT NULL, - user_id BINARY(16) NOT NULL, - aaguid BINARY(16) NOT NULL, - public_key BLOB NOT NULL, - sign_count INTEGER NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - last_used TIMESTAMP NULL, - last_verified TIMESTAMP NULL, - FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE + +class CredentialModel(Base): + __tablename__ = "credentials" + + credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True) + user_id: Mapped[bytes] = mapped_column( + LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE") ) -""" + aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False) + public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False) + sign_count: Mapped[int] = mapped_column(Integer, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) + last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) -SQL_GET_USER_BY_USER_ID = """ - SELECT * FROM users WHERE user_id = ? -""" - -SQL_CREATE_USER = """ - INSERT INTO users (user_id, user_name, created_at, last_seen) VALUES (?, ?, ?, ?) -""" - -SQL_STORE_CREDENTIAL = """ - INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) -""" - -SQL_GET_CREDENTIAL_BY_ID = """ - SELECT * FROM credentials WHERE credential_id = ? -""" - -SQL_GET_USER_CREDENTIALS = """ - SELECT credential_id FROM credentials WHERE user_id = ? -""" - -SQL_UPDATE_CREDENTIAL = """ - UPDATE credentials - SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ? - WHERE credential_id = ? -""" - -SQL_DELETE_CREDENTIAL = """ - DELETE FROM credentials WHERE credential_id = ? -""" + # Relationship to user + user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials") @dataclass @@ -80,121 +74,181 @@ class User: last_seen: datetime | None = None +# Global engine and session factory +engine = create_async_engine(DB_PATH, echo=False) +async_session_factory = async_sessionmaker(engine, expire_on_commit=False) + + @asynccontextmanager async def connect(): - conn = await aiosqlite.connect(DB_PATH) - try: - yield DB(conn) - await conn.commit() - finally: - await conn.close() + """Context manager for database connections.""" + async with async_session_factory() as session: + yield DB(session) + await session.commit() class DB: - def __init__(self, conn: aiosqlite.Connection): - self.conn = conn + def __init__(self, session: AsyncSession): + self.session = session async def init_db(self) -> None: """Initialize database tables.""" - await self.conn.execute(SQL_CREATE_USERS) - await self.conn.execute(SQL_CREATE_CREDENTIALS) - await self.conn.commit() + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) - # Database operation functions that work with a connection - async def get_user_by_user_id(self, user_id: bytes) -> User: + async def get_user_by_user_id(self, user_id: UUID) -> User: """Get user record by WebAuthn user ID.""" - async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor: - row = await cursor.fetchone() - if row: - return User( - user_id=UUID(bytes=row[0]), - user_name=row[1], - created_at=_convert_datetime(row[2]), - last_seen=_convert_datetime(row[3]), - ) - raise ValueError("User not found") + stmt = select(UserModel).where(UserModel.user_id == user_id.bytes) + result = await self.session.execute(stmt) + user_model = result.scalar_one_or_none() + + if user_model: + return User( + user_id=UUID(bytes=user_model.user_id), + user_name=user_model.user_name, + created_at=user_model.created_at, + last_seen=user_model.last_seen, + ) + raise ValueError("User not found") async def create_user(self, user: User) -> None: - """Create a new user and return the User dataclass.""" - await self.conn.execute( - SQL_CREATE_USER, - ( - user.user_id.bytes, - user.user_name, - user.created_at or datetime.now(), - user.last_seen, - ), + """Create a new user.""" + user_model = UserModel( + user_id=user.user_id.bytes, + user_name=user.user_name, + created_at=user.created_at or datetime.now(), + last_seen=user.last_seen, ) + self.session.add(user_model) + await self.session.flush() async def create_credential(self, credential: StoredCredential) -> None: """Store a credential for a user.""" - await self.conn.execute( - SQL_STORE_CREDENTIAL, - ( - credential.credential_id, - credential.user_id.bytes, - credential.aaguid.bytes, - credential.public_key, - credential.sign_count, - credential.created_at, - credential.last_used, - credential.last_verified, - ), + credential_model = CredentialModel( + credential_id=credential.credential_id, + user_id=credential.user_id.bytes, + aaguid=credential.aaguid.bytes, + public_key=credential.public_key, + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, ) + self.session.add(credential_model) + await self.session.flush() async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential: """Get credential by credential ID.""" - async with self.conn.execute( - SQL_GET_CREDENTIAL_BY_ID, (credential_id,) - ) as cursor: - row = await cursor.fetchone() - if row: - return StoredCredential( - credential_id=row[0], - user_id=UUID(bytes=row[1]), - aaguid=UUID(bytes=row[2]), - public_key=row[3], - sign_count=row[4], - created_at=datetime.fromisoformat(row[5]), - last_used=_convert_datetime(row[6]), - last_verified=_convert_datetime(row[7]), - ) - raise ValueError("Credential not registered") + stmt = select(CredentialModel).where( + CredentialModel.credential_id == credential_id + ) + result = await self.session.execute(stmt) + credential_model = result.scalar_one_or_none() - async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]: + if credential_model: + return StoredCredential( + credential_id=credential_model.credential_id, + user_id=UUID(bytes=credential_model.user_id), + aaguid=UUID(bytes=credential_model.aaguid), + public_key=credential_model.public_key, + sign_count=credential_model.sign_count, + created_at=credential_model.created_at, + last_used=credential_model.last_used, + last_verified=credential_model.last_verified, + ) + raise ValueError("Credential not registered") + + async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]: """Get all credential IDs for a user.""" - async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor: - rows = await cursor.fetchall() - return [row[0] for row in rows] + stmt = select(CredentialModel.credential_id).where( + CredentialModel.user_id == user_id.bytes + ) + result = await self.session.execute(stmt) + return [row[0] for row in result.fetchall()] async def update_credential(self, credential: StoredCredential) -> None: """Update the sign count, created_at, last_used, and last_verified for a credential.""" - await self.conn.execute( - SQL_UPDATE_CREDENTIAL, - ( - credential.sign_count, - credential.created_at, - credential.last_used, - credential.last_verified, - credential.credential_id, - ), + stmt = ( + update(CredentialModel) + .where(CredentialModel.credential_id == credential.credential_id) + .values( + sign_count=credential.sign_count, + created_at=credential.created_at, + last_used=credential.last_used, + last_verified=credential.last_verified, + ) ) + await self.session.execute(stmt) - async def login(self, user_id: bytes, credential: StoredCredential) -> None: + async def login(self, user_id: UUID, credential: StoredCredential) -> None: """Update the last_seen timestamp for a user and the credential record used for logging in.""" - await self.conn.execute("BEGIN") - await self.update_credential(credential) - await self.conn.execute( - "UPDATE users SET last_seen = ? WHERE user_id = ?", - (credential.last_used, user_id), - ) + async with self.session.begin(): + # Update credential + await self.update_credential(credential) + + # Update user's last_seen + stmt = ( + update(UserModel) + .where(UserModel.user_id == user_id.bytes) + .values(last_seen=credential.last_used) + ) + await self.session.execute(stmt) async def delete_credential(self, credential_id: bytes) -> None: """Delete a credential by its ID.""" - await self.conn.execute(SQL_DELETE_CREDENTIAL, (credential_id,)) - await self.conn.commit() + stmt = delete(CredentialModel).where( + CredentialModel.credential_id == credential_id + ) + await self.session.execute(stmt) + await self.session.commit() -def _convert_datetime(val): - """Convert string from SQLite to datetime object (pass through None).""" - return val and datetime.fromisoformat(val) +# Standalone functions that handle database connections internally +async def init_database() -> None: + """Initialize database tables.""" + async with connect() as db: + await db.init_db() + + +async def create_user_and_credential(user: User, credential: StoredCredential) -> None: + """Create a new user and their first credential in a single transaction.""" + async with connect() as db: + await db.session.begin() + await db.create_user(user) + await db.create_credential(credential) + + +async def get_user_by_id(user_id: UUID) -> User: + """Get user record by WebAuthn user ID.""" + async with connect() as db: + return await db.get_user_by_user_id(user_id) + + +async def create_credential_for_user(credential: StoredCredential) -> None: + """Store a credential for an existing user.""" + async with connect() as db: + await db.create_credential(credential) + + +async def get_credential_by_id(credential_id: bytes) -> StoredCredential: + """Get credential by credential ID.""" + async with connect() as db: + return await db.get_credential_by_id(credential_id) + + +async def get_user_credentials(user_id: UUID) -> list[bytes]: + """Get all credential IDs for a user.""" + async with connect() as db: + return await db.get_credentials_by_user_id(user_id) + + +async def login_user(user_id: UUID, credential: StoredCredential) -> None: + """Update the last_seen timestamp for a user and the credential record used for logging in.""" + async with connect() as db: + await db.login(user_id, credential) + + +async def delete_user_credential(credential_id: bytes) -> None: + """Delete a credential by its ID.""" + async with connect() as db: + await db.delete_credential(credential_id) diff --git a/passkeyauth/main.py b/passkeyauth/main.py index 0468519..9625b2f 100644 --- a/passkeyauth/main.py +++ b/passkeyauth/main.py @@ -18,6 +18,7 @@ from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from . import db from .api_handlers import ( delete_credential, get_user_credentials, @@ -27,7 +28,7 @@ from .api_handlers import ( set_session, validate_token, ) -from .db import User, connect +from .db import User from .jwt_manager import create_session_token from .passkey import Passkey from .session_manager import get_user_from_cookie_string @@ -44,8 +45,7 @@ passkey = Passkey( @asynccontextmanager async def lifespan(app: FastAPI): - async with connect() as db: - await db.init_db() + await db.init_database() yield @@ -65,11 +65,11 @@ async def websocket_register_new(ws: WebSocket): # WebAuthn registration credential = await register_chat(ws, user_id, user_name) - # Store the user in the database - async with connect() as db: - await db.conn.execute("BEGIN") - await db.create_user(User(user_id, user_name, created_at=datetime.now())) - await db.create_credential(credential) + # Store the user and credential in the database + await db.create_user_and_credential( + User(user_id, user_name, created_at=datetime.now()), + credential, + ) # Create a session token for the new user session_token = create_session_token(user_id, credential.credential_id) @@ -101,16 +101,15 @@ async def websocket_register_add(ws: WebSocket): return # Get user information to get the user_name - async with connect() as db: - user = await db.get_user_by_user_id(user_id.bytes) - user_name = user.user_name + user = await db.get_user_by_id(user_id) + user_name = user.user_name + challenge_ids = await db.get_user_credentials(user_id) # WebAuthn registration - credential = await register_chat(ws, user_id, user_name) + credential = await register_chat(ws, user_id, user_name, challenge_ids) print(f"New credential for user {user_id}: {credential}") # Store the new credential in the database - async with connect() as db: - await db.create_credential(credential) + await db.create_credential_for_user(credential) await ws.send_json( { @@ -128,11 +127,17 @@ async def websocket_register_add(ws: WebSocket): await ws.send_json({"error": f"Server error: {str(e)}"}) -async def register_chat(ws: WebSocket, user_id: UUID, user_name: str): +async def register_chat( + ws: WebSocket, + user_id: UUID, + user_name: str, + credential_ids: list[bytes] | None = None, +): """Generate registration options and send them to the client.""" options, challenge = passkey.reg_generate_options( user_id=user_id, user_name=user_name, + credential_ids=credential_ids, ) await ws.send_json(options) response = await ws.receive_json() @@ -144,17 +149,16 @@ async def register_chat(ws: WebSocket, user_id: UUID, user_name: str): async def websocket_authenticate(ws: WebSocket): await ws.accept() try: - options, challenge = await passkey.auth_generate_options() + options, challenge = passkey.auth_generate_options() await ws.send_json(options) # Wait for the client to use his authenticator to authenticate credential = passkey.auth_parse(await ws.receive_json()) - async with connect() as db: - # Fetch from the database by credential ID - stored_cred = await db.get_credential_by_id(credential.raw_id) - # Verify the credential matches the stored data - await passkey.auth_verify(credential, challenge, stored_cred) - # Update both credential and user's last_seen timestamp - await db.login(stored_cred.user_id.bytes, stored_cred) + # Fetch from the database by credential ID + stored_cred = await db.get_credential_by_id(credential.raw_id) + # Verify the credential matches the stored data + passkey.auth_verify(credential, challenge, stored_cred) + # Update both credential and user's last_seen timestamp + await db.login_user(stored_cred.user_id, stored_cred) # Create a session token for the authenticated user session_token = create_session_token( diff --git a/passkeyauth/passkey.py b/passkeyauth/passkey.py index 04d2ce1..fdb56b4 100644 --- a/passkeyauth/passkey.py +++ b/passkeyauth/passkey.py @@ -155,7 +155,7 @@ class Passkey: ### Authentication Methods ### - async def auth_generate_options( + def auth_generate_options( self, *, user_verification_required=False, @@ -178,7 +178,7 @@ class Passkey: user_verification=( UserVerificationRequirement.REQUIRED if user_verification_required - else UserVerificationRequirement.PREFERRED + else UserVerificationRequirement.DISCOURAGED ), allow_credentials=_convert_credential_ids(credential_ids), **authopts, @@ -188,7 +188,7 @@ class Passkey: def auth_parse(self, response: dict | str) -> AuthenticationCredential: return parse_authentication_credential_json(response) - async def auth_verify( + def auth_verify( self, credential: AuthenticationCredential, expected_challenge: bytes, diff --git a/passkeyauth/session_manager.py b/passkeyauth/session_manager.py index 3cf8287..af75281 100644 --- a/passkeyauth/session_manager.py +++ b/passkeyauth/session_manager.py @@ -12,7 +12,7 @@ from uuid import UUID from fastapi import Request, Response -from .db import User, connect +from .db import User, get_user_by_id from .jwt_manager import validate_session_token COOKIE_NAME = "session_token" @@ -29,12 +29,11 @@ async def get_current_user(request: Request) -> Optional[User]: if not token_data: return None - async with connect() as db: - try: - user = await db.get_user_by_user_id(token_data["user_id"].bytes) - return user - except Exception: - return None + try: + user = await get_user_by_id(token_data["user_id"]) + return user + except Exception: + return None def set_session_cookie(response: Response, session_token: str) -> None: diff --git a/pyproject.toml b/pyproject.toml index ebbf5b1..ffe6947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "websockets>=12.0", "webauthn>=1.11.1", "base64url>=1.0.0", + "sqlalchemy[asyncio]>=2.0.0", "aiosqlite>=0.19.0", "uuid7-standard>=1.0.0", "pyjwt>=2.8.0", diff --git a/static/app.js b/static/app.js index 913fd19..f83a8fd 100644 --- a/static/app.js +++ b/static/app.js @@ -334,7 +334,7 @@ async function addNewCredential() { clearStatus('dashboardStatus') } catch (error) { - showStatus('dashboardStatus', `Failed to add new passkey: ${error.message}`, 'error') + showStatus('dashboardStatus', 'Registration cancelled', 'error') } }