diff --git a/passkeyauth/db.py b/passkeyauth/db.py index 99e4a63..1a0b0aa 100644 --- a/passkeyauth/db.py +++ b/passkeyauth/db.py @@ -5,6 +5,7 @@ This module provides an async database layer using dataclasses and aiosqlite for managing users and credentials in a WebAuthn authentication system. """ +from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime from typing import Optional @@ -35,6 +36,7 @@ SQL_CREATE_CREDENTIALS = """ sign_count INTEGER NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, last_used TIMESTAMP NULL, + last_verified TIMESTAMP NULL, FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE ) """ @@ -48,12 +50,12 @@ SQL_CREATE_USER = """ """ SQL_STORE_CREDENTIAL = """ - INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count) - VALUES (?, ?, ?, ?, ?) + INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) """ SQL_GET_CREDENTIAL_BY_ID = """ - SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used + SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified FROM credentials WHERE credential_id = ? """ @@ -62,7 +64,7 @@ SQL_GET_USER_CREDENTIALS = """ SELECT c.credential_id FROM credentials c JOIN users u ON c.user_id = u.user_id - WHERE u.user_name = ? + WHERE u.user_id = ? """ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """ @@ -71,6 +73,12 @@ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """ WHERE credential_id = ? """ +SQL_UPDATE_CREDENTIAL = """ + UPDATE credentials + SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ? + WHERE credential_id = ? +""" + @dataclass class User: @@ -82,104 +90,108 @@ class User: last_seen: Optional[datetime] = None -class Database: - """Async database handler for WebAuthn operations.""" +@asynccontextmanager +async def connect(): + conn = await aiosqlite.connect(DB_PATH) + try: + yield DB(conn) + await conn.commit() + finally: + await conn.close() - def __init__(self, db_path: str = DB_PATH): - self.db_path = db_path - async def init_database(self): - """Initialize the SQLite database with required tables.""" - async with aiosqlite.connect(self.db_path) as conn: - await conn.execute(SQL_CREATE_USERS) - await conn.execute(SQL_CREATE_CREDENTIALS) - await conn.commit() +class DB: + def __init__(self, conn: aiosqlite.Connection): + self.conn = conn + async def init_db(self) -> None: + """Initialize database tables.""" + await self.conn.execute(SQL_CREATE_USERS) + await self.conn.execute(SQL_CREATE_CREDENTIALS) + await self.conn.commit() + + # Database operation functions that work with a connection async def get_user_by_user_id(self, user_id: bytes) -> User: """Get user record by WebAuthn user ID.""" - async with aiosqlite.connect(self.db_path) as conn: - async with conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor: - row = await cursor.fetchone() - if row: - return User( - user_id=row[0], - user_name=row[1], - created_at=row[2], - last_seen=row[3], - ) - raise ValueError("User not found") + async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor: + row = await cursor.fetchone() + if row: + return User( + user_id=UUID(bytes=row[0]), + user_name=row[1], + created_at=row[2], + last_seen=row[3], + ) + raise ValueError("User not found") async def create_user(self, user: User) -> User: """Create a new user and return the User dataclass.""" - async with aiosqlite.connect(self.db_path) as conn: - await conn.execute( - SQL_CREATE_USER, - (user.user_id, user.user_name, user.created_at, user.last_seen), - ) - await conn.commit() - return user + await self.conn.execute( + SQL_CREATE_USER, + (user.user_id.bytes, user.user_name, user.created_at, user.last_seen), + ) + return user async def store_credential(self, credential: StoredCredential) -> None: """Store a credential for a user.""" - async with aiosqlite.connect(self.db_path) as conn: - await conn.execute( - SQL_STORE_CREDENTIAL, - ( - credential.credential_id, - credential.user_id, - credential.aaguid.bytes, - credential.public_key, - credential.sign_count, - ), - ) - await conn.commit() + await self.conn.execute( + SQL_STORE_CREDENTIAL, + ( + credential.credential_id, + credential.user_id.bytes, + credential.aaguid.bytes, + credential.public_key, + credential.sign_count, + credential.created_at, + credential.last_used, + credential.last_verified, + ), + ) async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential: """Get credential by credential ID.""" - async with aiosqlite.connect(self.db_path) as conn: - async with conn.execute( - SQL_GET_CREDENTIAL_BY_ID, (credential_id,) - ) as cursor: - row = await cursor.fetchone() - if row: - return StoredCredential( - credential_id=row[0], - user_id=UUID(bytes=row[1]), - aaguid=UUID(bytes=row[2]), - public_key=row[3], - sign_count=row[4], - created_at=row[5], - last_used=row[6], - ) - raise ValueError("Credential not found") + async with self.conn.execute( + SQL_GET_CREDENTIAL_BY_ID, (credential_id,) + ) as cursor: + row = await cursor.fetchone() + if row: + return StoredCredential( + credential_id=row[0], + user_id=UUID(bytes=row[1]), + aaguid=UUID(bytes=row[2]), + public_key=row[3], + sign_count=row[4], + created_at=row[5], + last_used=row[6], + last_verified=row[7], + ) + raise ValueError("Credential not found") async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]: """Get all credential IDs for a user.""" - async with aiosqlite.connect(self.db_path) as conn: - async with conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor: - rows = await cursor.fetchall() - return [row[0] for row in rows] + async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor: + rows = await cursor.fetchall() + return [row[0] for row in rows] async def update_credential(self, credential: StoredCredential) -> None: - """Update the sign count for a credential.""" - async with aiosqlite.connect(self.db_path) as conn: - await conn.execute( - SQL_UPDATE_CREDENTIAL_SIGN_COUNT, - (credential.sign_count, credential.credential_id), - ) - await conn.commit() + """Update the sign count, created_at, last_used, and last_verified for a credential.""" + await self.conn.execute( + SQL_UPDATE_CREDENTIAL, + ( + credential.sign_count, + credential.created_at, + credential.last_used, + credential.last_verified, + credential.credential_id, + ), + ) async def login(self, user_id: bytes, credential: StoredCredential) -> None: - """Update the last_seen timestamp for a user.""" - async with aiosqlite.connect(self.db_path) as conn: - # Do these in a single transaction - self.store_credential(credential) - await conn.execute( - "UPDATE users SET last_seen = ? WHERE user_id = ?", - (last_seen, user_id), - ) - await conn.commit() - - -# Global database instance -db = Database() + """Update the last_seen timestamp for a user and the credential record used for logging in.""" + # Update credential + await self.update_credential(credential) + # Update user's last_seen timestamp + await self.conn.execute( + "UPDATE users SET last_seen = ? WHERE user_id = ?", + (credential.last_used, user_id), + ) diff --git a/passkeyauth/passkey.py b/passkeyauth/passkey.py index df55b1f..b0b4ecb 100644 --- a/passkeyauth/passkey.py +++ b/passkeyauth/passkey.py @@ -40,10 +40,12 @@ from webauthn.helpers.structs import ( class StoredCredential: """Credential data stored in the database.""" + # Fields set only at registration time credential_id: bytes user_id: UUID aaguid: UUID public_key: bytes + # Mutable fields that may be updated during authentication sign_count: int created_at: datetime last_used: datetime | None = None