Database refactoring
This commit is contained in:
parent
66384da8ce
commit
4f8b5f837c
@ -5,6 +5,7 @@ This module provides an async database layer using dataclasses and aiosqlite
|
|||||||
for managing users and credentials in a WebAuthn authentication system.
|
for managing users and credentials in a WebAuthn authentication system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -35,6 +36,7 @@ SQL_CREATE_CREDENTIALS = """
|
|||||||
sign_count INTEGER NOT NULL,
|
sign_count INTEGER NOT NULL,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
last_used TIMESTAMP NULL,
|
last_used TIMESTAMP NULL,
|
||||||
|
last_verified TIMESTAMP NULL,
|
||||||
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
|
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
@ -48,12 +50,12 @@ SQL_CREATE_USER = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
SQL_STORE_CREDENTIAL = """
|
SQL_STORE_CREDENTIAL = """
|
||||||
INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count)
|
INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SQL_GET_CREDENTIAL_BY_ID = """
|
SQL_GET_CREDENTIAL_BY_ID = """
|
||||||
SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used
|
SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified
|
||||||
FROM credentials
|
FROM credentials
|
||||||
WHERE credential_id = ?
|
WHERE credential_id = ?
|
||||||
"""
|
"""
|
||||||
@ -62,7 +64,7 @@ SQL_GET_USER_CREDENTIALS = """
|
|||||||
SELECT c.credential_id
|
SELECT c.credential_id
|
||||||
FROM credentials c
|
FROM credentials c
|
||||||
JOIN users u ON c.user_id = u.user_id
|
JOIN users u ON c.user_id = u.user_id
|
||||||
WHERE u.user_name = ?
|
WHERE u.user_id = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """
|
SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """
|
||||||
@ -71,6 +73,12 @@ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """
|
|||||||
WHERE credential_id = ?
|
WHERE credential_id = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SQL_UPDATE_CREDENTIAL = """
|
||||||
|
UPDATE credentials
|
||||||
|
SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ?
|
||||||
|
WHERE credential_id = ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class User:
|
class User:
|
||||||
@ -82,27 +90,34 @@ class User:
|
|||||||
last_seen: Optional[datetime] = None
|
last_seen: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
@asynccontextmanager
|
||||||
"""Async database handler for WebAuthn operations."""
|
async def connect():
|
||||||
|
conn = await aiosqlite.connect(DB_PATH)
|
||||||
def __init__(self, db_path: str = DB_PATH):
|
try:
|
||||||
self.db_path = db_path
|
yield DB(conn)
|
||||||
|
|
||||||
async def init_database(self):
|
|
||||||
"""Initialize the SQLite database with required tables."""
|
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
|
||||||
await conn.execute(SQL_CREATE_USERS)
|
|
||||||
await conn.execute(SQL_CREATE_CREDENTIALS)
|
|
||||||
await conn.commit()
|
await conn.commit()
|
||||||
|
finally:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
class DB:
|
||||||
|
def __init__(self, conn: aiosqlite.Connection):
|
||||||
|
self.conn = conn
|
||||||
|
|
||||||
|
async def init_db(self) -> None:
|
||||||
|
"""Initialize database tables."""
|
||||||
|
await self.conn.execute(SQL_CREATE_USERS)
|
||||||
|
await self.conn.execute(SQL_CREATE_CREDENTIALS)
|
||||||
|
await self.conn.commit()
|
||||||
|
|
||||||
|
# Database operation functions that work with a connection
|
||||||
async def get_user_by_user_id(self, user_id: bytes) -> User:
|
async def get_user_by_user_id(self, user_id: bytes) -> User:
|
||||||
"""Get user record by WebAuthn user ID."""
|
"""Get user record by WebAuthn user ID."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor:
|
||||||
async with conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor:
|
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return User(
|
return User(
|
||||||
user_id=row[0],
|
user_id=UUID(bytes=row[0]),
|
||||||
user_name=row[1],
|
user_name=row[1],
|
||||||
created_at=row[2],
|
created_at=row[2],
|
||||||
last_seen=row[3],
|
last_seen=row[3],
|
||||||
@ -111,33 +126,31 @@ class Database:
|
|||||||
|
|
||||||
async def create_user(self, user: User) -> User:
|
async def create_user(self, user: User) -> User:
|
||||||
"""Create a new user and return the User dataclass."""
|
"""Create a new user and return the User dataclass."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
await self.conn.execute(
|
||||||
await conn.execute(
|
|
||||||
SQL_CREATE_USER,
|
SQL_CREATE_USER,
|
||||||
(user.user_id, user.user_name, user.created_at, user.last_seen),
|
(user.user_id.bytes, user.user_name, user.created_at, user.last_seen),
|
||||||
)
|
)
|
||||||
await conn.commit()
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def store_credential(self, credential: StoredCredential) -> None:
|
async def store_credential(self, credential: StoredCredential) -> None:
|
||||||
"""Store a credential for a user."""
|
"""Store a credential for a user."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
await self.conn.execute(
|
||||||
await conn.execute(
|
|
||||||
SQL_STORE_CREDENTIAL,
|
SQL_STORE_CREDENTIAL,
|
||||||
(
|
(
|
||||||
credential.credential_id,
|
credential.credential_id,
|
||||||
credential.user_id,
|
credential.user_id.bytes,
|
||||||
credential.aaguid.bytes,
|
credential.aaguid.bytes,
|
||||||
credential.public_key,
|
credential.public_key,
|
||||||
credential.sign_count,
|
credential.sign_count,
|
||||||
|
credential.created_at,
|
||||||
|
credential.last_used,
|
||||||
|
credential.last_verified,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
await conn.commit()
|
|
||||||
|
|
||||||
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
|
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
|
||||||
"""Get credential by credential ID."""
|
"""Get credential by credential ID."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
async with self.conn.execute(
|
||||||
async with conn.execute(
|
|
||||||
SQL_GET_CREDENTIAL_BY_ID, (credential_id,)
|
SQL_GET_CREDENTIAL_BY_ID, (credential_id,)
|
||||||
) as cursor:
|
) as cursor:
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
@ -150,36 +163,35 @@ class Database:
|
|||||||
sign_count=row[4],
|
sign_count=row[4],
|
||||||
created_at=row[5],
|
created_at=row[5],
|
||||||
last_used=row[6],
|
last_used=row[6],
|
||||||
|
last_verified=row[7],
|
||||||
)
|
)
|
||||||
raise ValueError("Credential not found")
|
raise ValueError("Credential not found")
|
||||||
|
|
||||||
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
|
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
|
||||||
"""Get all credential IDs for a user."""
|
"""Get all credential IDs for a user."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor:
|
||||||
async with conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor:
|
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
return [row[0] for row in rows]
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
async def update_credential(self, credential: StoredCredential) -> None:
|
async def update_credential(self, credential: StoredCredential) -> None:
|
||||||
"""Update the sign count for a credential."""
|
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
await self.conn.execute(
|
||||||
await conn.execute(
|
SQL_UPDATE_CREDENTIAL,
|
||||||
SQL_UPDATE_CREDENTIAL_SIGN_COUNT,
|
(
|
||||||
(credential.sign_count, credential.credential_id),
|
credential.sign_count,
|
||||||
|
credential.created_at,
|
||||||
|
credential.last_used,
|
||||||
|
credential.last_verified,
|
||||||
|
credential.credential_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
await conn.commit()
|
|
||||||
|
|
||||||
async def login(self, user_id: bytes, credential: StoredCredential) -> None:
|
async def login(self, user_id: bytes, credential: StoredCredential) -> None:
|
||||||
"""Update the last_seen timestamp for a user."""
|
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||||
async with aiosqlite.connect(self.db_path) as conn:
|
# Update credential
|
||||||
# Do these in a single transaction
|
await self.update_credential(credential)
|
||||||
self.store_credential(credential)
|
# Update user's last_seen timestamp
|
||||||
await conn.execute(
|
await self.conn.execute(
|
||||||
"UPDATE users SET last_seen = ? WHERE user_id = ?",
|
"UPDATE users SET last_seen = ? WHERE user_id = ?",
|
||||||
(last_seen, user_id),
|
(credential.last_used, user_id),
|
||||||
)
|
)
|
||||||
await conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# Global database instance
|
|
||||||
db = Database()
|
|
||||||
|
@ -40,10 +40,12 @@ from webauthn.helpers.structs import (
|
|||||||
class StoredCredential:
|
class StoredCredential:
|
||||||
"""Credential data stored in the database."""
|
"""Credential data stored in the database."""
|
||||||
|
|
||||||
|
# Fields set only at registration time
|
||||||
credential_id: bytes
|
credential_id: bytes
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
aaguid: UUID
|
aaguid: UUID
|
||||||
public_key: bytes
|
public_key: bytes
|
||||||
|
# Mutable fields that may be updated during authentication
|
||||||
sign_count: int
|
sign_count: int
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
last_used: datetime | None = None
|
last_used: datetime | None = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user