Database refactoring

This commit is contained in:
Leo Vasanko 2025-07-06 11:41:49 -06:00
parent 66384da8ce
commit 4f8b5f837c
2 changed files with 98 additions and 84 deletions

View File

@ -5,6 +5,7 @@ This module provides an async database layer using dataclasses and aiosqlite
for managing users and credentials in a WebAuthn authentication system. for managing users and credentials in a WebAuthn authentication system.
""" """
from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
@ -35,6 +36,7 @@ SQL_CREATE_CREDENTIALS = """
sign_count INTEGER NOT NULL, sign_count INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used TIMESTAMP NULL, last_used TIMESTAMP NULL,
last_verified TIMESTAMP NULL,
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
) )
""" """
@ -48,12 +50,12 @@ SQL_CREATE_USER = """
""" """
SQL_STORE_CREDENTIAL = """ SQL_STORE_CREDENTIAL = """
INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count) INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""" """
SQL_GET_CREDENTIAL_BY_ID = """ SQL_GET_CREDENTIAL_BY_ID = """
SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified
FROM credentials FROM credentials
WHERE credential_id = ? WHERE credential_id = ?
""" """
@ -62,7 +64,7 @@ SQL_GET_USER_CREDENTIALS = """
SELECT c.credential_id SELECT c.credential_id
FROM credentials c FROM credentials c
JOIN users u ON c.user_id = u.user_id JOIN users u ON c.user_id = u.user_id
WHERE u.user_name = ? WHERE u.user_id = ?
""" """
SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """
@ -71,6 +73,12 @@ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """
WHERE credential_id = ? WHERE credential_id = ?
""" """
SQL_UPDATE_CREDENTIAL = """
UPDATE credentials
SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ?
WHERE credential_id = ?
"""
@dataclass @dataclass
class User: class User:
@ -82,27 +90,34 @@ class User:
last_seen: Optional[datetime] = None last_seen: Optional[datetime] = None
class Database: @asynccontextmanager
"""Async database handler for WebAuthn operations.""" async def connect():
conn = await aiosqlite.connect(DB_PATH)
def __init__(self, db_path: str = DB_PATH): try:
self.db_path = db_path yield DB(conn)
async def init_database(self):
"""Initialize the SQLite database with required tables."""
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute(SQL_CREATE_USERS)
await conn.execute(SQL_CREATE_CREDENTIALS)
await conn.commit() await conn.commit()
finally:
await conn.close()
class DB:
def __init__(self, conn: aiosqlite.Connection):
self.conn = conn
async def init_db(self) -> None:
"""Initialize database tables."""
await self.conn.execute(SQL_CREATE_USERS)
await self.conn.execute(SQL_CREATE_CREDENTIALS)
await self.conn.commit()
# Database operation functions that work with a connection
async def get_user_by_user_id(self, user_id: bytes) -> User: async def get_user_by_user_id(self, user_id: bytes) -> User:
"""Get user record by WebAuthn user ID.""" """Get user record by WebAuthn user ID."""
async with aiosqlite.connect(self.db_path) as conn: async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor:
async with conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
if row: if row:
return User( return User(
user_id=row[0], user_id=UUID(bytes=row[0]),
user_name=row[1], user_name=row[1],
created_at=row[2], created_at=row[2],
last_seen=row[3], last_seen=row[3],
@ -111,33 +126,31 @@ class Database:
async def create_user(self, user: User) -> User: async def create_user(self, user: User) -> User:
"""Create a new user and return the User dataclass.""" """Create a new user and return the User dataclass."""
async with aiosqlite.connect(self.db_path) as conn: await self.conn.execute(
await conn.execute(
SQL_CREATE_USER, SQL_CREATE_USER,
(user.user_id, user.user_name, user.created_at, user.last_seen), (user.user_id.bytes, user.user_name, user.created_at, user.last_seen),
) )
await conn.commit()
return user return user
async def store_credential(self, credential: StoredCredential) -> None: async def store_credential(self, credential: StoredCredential) -> None:
"""Store a credential for a user.""" """Store a credential for a user."""
async with aiosqlite.connect(self.db_path) as conn: await self.conn.execute(
await conn.execute(
SQL_STORE_CREDENTIAL, SQL_STORE_CREDENTIAL,
( (
credential.credential_id, credential.credential_id,
credential.user_id, credential.user_id.bytes,
credential.aaguid.bytes, credential.aaguid.bytes,
credential.public_key, credential.public_key,
credential.sign_count, credential.sign_count,
credential.created_at,
credential.last_used,
credential.last_verified,
), ),
) )
await conn.commit()
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential: async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
"""Get credential by credential ID.""" """Get credential by credential ID."""
async with aiosqlite.connect(self.db_path) as conn: async with self.conn.execute(
async with conn.execute(
SQL_GET_CREDENTIAL_BY_ID, (credential_id,) SQL_GET_CREDENTIAL_BY_ID, (credential_id,)
) as cursor: ) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
@ -150,36 +163,35 @@ class Database:
sign_count=row[4], sign_count=row[4],
created_at=row[5], created_at=row[5],
last_used=row[6], last_used=row[6],
last_verified=row[7],
) )
raise ValueError("Credential not found") raise ValueError("Credential not found")
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]: async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
"""Get all credential IDs for a user.""" """Get all credential IDs for a user."""
async with aiosqlite.connect(self.db_path) as conn: async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor:
async with conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor:
rows = await cursor.fetchall() rows = await cursor.fetchall()
return [row[0] for row in rows] return [row[0] for row in rows]
async def update_credential(self, credential: StoredCredential) -> None: async def update_credential(self, credential: StoredCredential) -> None:
"""Update the sign count for a credential.""" """Update the sign count, created_at, last_used, and last_verified for a credential."""
async with aiosqlite.connect(self.db_path) as conn: await self.conn.execute(
await conn.execute( SQL_UPDATE_CREDENTIAL,
SQL_UPDATE_CREDENTIAL_SIGN_COUNT, (
(credential.sign_count, credential.credential_id), credential.sign_count,
credential.created_at,
credential.last_used,
credential.last_verified,
credential.credential_id,
),
) )
await conn.commit()
async def login(self, user_id: bytes, credential: StoredCredential) -> None: async def login(self, user_id: bytes, credential: StoredCredential) -> None:
"""Update the last_seen timestamp for a user.""" """Update the last_seen timestamp for a user and the credential record used for logging in."""
async with aiosqlite.connect(self.db_path) as conn: # Update credential
# Do these in a single transaction await self.update_credential(credential)
self.store_credential(credential) # Update user's last_seen timestamp
await conn.execute( await self.conn.execute(
"UPDATE users SET last_seen = ? WHERE user_id = ?", "UPDATE users SET last_seen = ? WHERE user_id = ?",
(last_seen, user_id), (credential.last_used, user_id),
) )
await conn.commit()
# Global database instance
db = Database()

View File

@ -40,10 +40,12 @@ from webauthn.helpers.structs import (
class StoredCredential: class StoredCredential:
"""Credential data stored in the database.""" """Credential data stored in the database."""
# Fields set only at registration time
credential_id: bytes credential_id: bytes
user_id: UUID user_id: UUID
aaguid: UUID aaguid: UUID
public_key: bytes public_key: bytes
# Mutable fields that may be updated during authentication
sign_count: int sign_count: int
created_at: datetime created_at: datetime
last_used: datetime | None = None last_used: datetime | None = None