389 lines
14 KiB
Python
389 lines
14 KiB
Python
"""
|
|
Async database implementation for WebAuthn passkey authentication.
|
|
|
|
This module provides an async database layer using SQLAlchemy async mode
|
|
for managing users and credentials in a WebAuthn authentication system.
|
|
"""
|
|
|
|
import secrets
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import (
|
|
DateTime,
|
|
ForeignKey,
|
|
Integer,
|
|
LargeBinary,
|
|
String,
|
|
delete,
|
|
select,
|
|
update,
|
|
)
|
|
from sqlalchemy.dialects.sqlite import BLOB
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
|
|
|
from .passkey import StoredCredential
|
|
|
|
# SQLAlchemy database URL: a SQLite file named "webauthn.db" accessed through
# the aiosqlite async driver. Passed to create_async_engine() below.
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
|
|
|
|
|
# SQLAlchemy Models
|
|
class Base(DeclarativeBase):
    """Declarative base whose metadata collects every ORM table in this module."""
|
|
|
|
|
|
class UserModel(Base):
    """ORM mapping for the ``users`` table (one row per registered user)."""

    __tablename__ = "users"

    # Binary primary key declared as 16 bytes; the DB layer stores UUID.bytes here.
    user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
    # Required display/login name. No unique constraint is declared on it.
    user_name: Mapped[str] = mapped_column(String, nullable=False)
    # Filled with datetime.now() (naive, local time) when no value is supplied.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    # Stays NULL until a login/session update writes a timestamp.
    last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    # Visit counter; starts at 0 and is incremented by the login updates.
    visits: Mapped[int] = mapped_column(Integer, default=0, nullable=False)

    # Credentials owned by this user. The cascade makes the ORM delete them
    # together with the user, and delete any that are removed from the list.
    credentials: Mapped[list["CredentialModel"]] = relationship(
        "CredentialModel",
        back_populates="user",
        cascade="all, delete-orphan",
    )
|
|
|
|
|
|
class CredentialModel(Base):
    """ORM mapping for the ``credentials`` table (one row per stored passkey)."""

    __tablename__ = "credentials"

    # Raw credential ID bytes; primary key, declared with length 64.
    credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
    # Owning user. ON DELETE CASCADE is declared on the foreign key.
    user_id: Mapped[bytes] = mapped_column(
        LargeBinary(16),
        ForeignKey("users.user_id", ondelete="CASCADE"),
    )
    # 16-byte authenticator AAGUID; the DB layer stores UUID.bytes here.
    aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
    # Public key bytes, stored using the SQLite BLOB type.
    public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
    # Signature counter as last written by the DB layer.
    sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
    # Filled with datetime.now() (naive, local time) when no value is supplied.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    # Both timestamps are nullable and only set when the DB layer writes them.
    last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

    # Back-reference to the owning user (pairs with UserModel.credentials).
    user: Mapped["UserModel"] = relationship(
        "UserModel",
        back_populates="credentials",
    )
|
|
|
|
|
|
class ResetTokenModel(Base):
    """ORM mapping for the ``reset_tokens`` table (one row per issued token)."""

    __tablename__ = "reset_tokens"

    # The token string itself is the primary key (declared length 64).
    token: Mapped[str] = mapped_column(String(64), primary_key=True)
    # User the token was issued for. ON DELETE CASCADE is declared on the key.
    user_id: Mapped[bytes] = mapped_column(
        LargeBinary(16),
        ForeignKey("users.user_id", ondelete="CASCADE"),
    )
    # Issue time; the expiry cleanup compares against this column.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)

    # One-directional relationship to the user (no back_populates).
    user: Mapped["UserModel"] = relationship("UserModel")
|
|
|
|
|
|
@dataclass
|
|
class User:
|
|
user_id: UUID
|
|
user_name: str
|
|
created_at: datetime | None = None
|
|
last_seen: datetime | None = None
|
|
visits: int = 0
|
|
|
|
|
|
@dataclass
class ResetToken:
    """Plain value object describing a reset token, detached from the ORM session."""

    # The token string (primary key of the reset_tokens table).
    token: str
    # User the token belongs to.
    user_id: UUID
    # When the token row was created.
    created_at: datetime
|
|
|
|
|
|
# Module-wide async engine for DB_PATH (SQL echo disabled) and the session
# factory bound to it. expire_on_commit=False keeps loaded attributes readable
# after a commit instead of expiring them.
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
|
|
|
|
|
|
@asynccontextmanager
async def connect():
    """Yield a ``DB`` bound to a fresh session and commit when the block ends.

    The commit only runs when the caller's ``async with`` body finishes
    without raising; on an exception it is skipped and the error propagates
    out through the session's own context manager.
    """
    async with async_session_factory() as db_session:
        database = DB(db_session)
        yield database
        # Reached only on normal completion of the caller's block.
        await db_session.commit()
|
|
|
|
|
|
class DB:
    """Data-access object running user, credential and reset-token queries.

    Every method operates on the ``AsyncSession`` given at construction time.
    Except where noted, methods do not commit; the owner of the session
    (normally the ``connect()`` context manager) is responsible for that.
    """

    def __init__(self, session: AsyncSession):
        """Wrap *session*; all queries of this object run on it."""
        self.session = session

    @staticmethod
    def _user_from_model(user_model: UserModel) -> User:
        """Convert an ORM user row into the detached ``User`` dataclass."""
        return User(
            user_id=UUID(bytes=user_model.user_id),
            user_name=user_model.user_name,
            created_at=user_model.created_at,
            last_seen=user_model.last_seen,
            visits=user_model.visits,
        )

    @staticmethod
    def _credential_from_model(credential_model: CredentialModel) -> StoredCredential:
        """Convert an ORM credential row into a ``StoredCredential``."""
        return StoredCredential(
            credential_id=credential_model.credential_id,
            user_id=UUID(bytes=credential_model.user_id),
            aaguid=UUID(bytes=credential_model.aaguid),
            public_key=credential_model.public_key,
            sign_count=credential_model.sign_count,
            created_at=credential_model.created_at,
            last_used=credential_model.last_used,
            last_verified=credential_model.last_verified,
        )

    async def init_db(self) -> None:
        """Initialize database tables.

        Runs ``Base.metadata.create_all`` on the module-level engine in its
        own connection/transaction; ``self.session`` is not used.
        """
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def get_user_by_user_id(self, user_id: UUID) -> User:
        """Get user record by WebAuthn user ID.

        Raises:
            ValueError: if no user row has this ID.
        """
        stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
        result = await self.session.execute(stmt)
        user_model = result.scalar_one_or_none()
        if user_model is None:
            raise ValueError("User not found")
        return self._user_from_model(user_model)

    async def create_user(self, user: User) -> None:
        """Create a new user.

        The row is added and flushed (sent to the database) but not committed.
        A falsy ``user.created_at`` is replaced by ``datetime.now()``.
        """
        user_model = UserModel(
            user_id=user.user_id.bytes,
            user_name=user.user_name,
            created_at=user.created_at or datetime.now(),
            last_seen=user.last_seen,
            visits=user.visits,
        )
        self.session.add(user_model)
        await self.session.flush()

    async def create_credential(self, credential: StoredCredential) -> None:
        """Store a credential for a user.

        The row is added and flushed but not committed.
        """
        credential_model = CredentialModel(
            credential_id=credential.credential_id,
            user_id=credential.user_id.bytes,
            aaguid=credential.aaguid.bytes,
            public_key=credential.public_key,
            sign_count=credential.sign_count,
            created_at=credential.created_at,
            last_used=credential.last_used,
            last_verified=credential.last_verified,
        )
        self.session.add(credential_model)
        await self.session.flush()

    async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
        """Get credential by credential ID.

        Raises:
            ValueError: if no credential row has this ID.
        """
        stmt = select(CredentialModel).where(
            CredentialModel.credential_id == credential_id
        )
        result = await self.session.execute(stmt)
        credential_model = result.scalar_one_or_none()
        if credential_model is None:
            raise ValueError("Credential not registered")
        return self._credential_from_model(credential_model)

    async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
        """Get all credential IDs for a user (empty list when there are none)."""
        stmt = select(CredentialModel.credential_id).where(
            CredentialModel.user_id == user_id.bytes
        )
        result = await self.session.execute(stmt)
        # scalars() yields the single selected column of each row directly.
        return list(result.scalars().all())

    async def update_credential(self, credential: StoredCredential) -> None:
        """Update the sign count, created_at, last_used, and last_verified for a credential.

        The row is matched by ``credential.credential_id``. The number of
        affected rows is not checked, so an unknown ID is a silent no-op.
        """
        stmt = (
            update(CredentialModel)
            .where(CredentialModel.credential_id == credential.credential_id)
            .values(
                sign_count=credential.sign_count,
                created_at=credential.created_at,
                last_used=credential.last_used,
                last_verified=credential.last_verified,
            )
        )
        await self.session.execute(stmt)

    async def _apply_login_updates(
        self, user_id: UUID, credential: StoredCredential
    ) -> None:
        """Write the credential update and bump the user's last_seen/visits."""
        await self.update_credential(credential)
        # visits is incremented in SQL (visits = visits + 1), not read-modify-write.
        stmt = (
            update(UserModel)
            .where(UserModel.user_id == user_id.bytes)
            .values(last_seen=credential.last_used, visits=UserModel.visits + 1)
        )
        await self.session.execute(stmt)

    async def _record_login(self, user_id: UUID, credential: StoredCredential) -> None:
        """Run the login updates as one unit on the current session.

        ``Session.begin()`` raises when a transaction is already in progress,
        which is the case as soon as any earlier statement autobegan one on
        this session. In that situation the updates simply join the active
        transaction and are committed by whoever owns it; otherwise a new
        transaction is opened and committed here.
        """
        if self.session.in_transaction():
            await self._apply_login_updates(user_id, credential)
            return
        async with self.session.begin():
            await self._apply_login_updates(user_id, credential)

    async def login(self, user_id: UUID, credential: StoredCredential) -> None:
        """Update the last_seen timestamp for a user and the credential record used for logging in.

        Also increments the user's visit counter. ``last_seen`` is set to
        ``credential.last_used``.
        """
        await self._record_login(user_id, credential)

    async def create_new_session(
        self, user_id: UUID, credential: StoredCredential
    ) -> None:
        """Create a new session for a user by incrementing visits and updating last_seen.

        Performs exactly the same database updates as ``login``.
        """
        await self._record_login(user_id, credential)

    async def delete_credential(self, credential_id: bytes) -> None:
        """Delete a credential by its ID.

        Unlike the other mutating methods this one commits the session
        immediately, which also commits any other pending work on it.
        """
        stmt = delete(CredentialModel).where(
            CredentialModel.credential_id == credential_id
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def create_reset_token(self, user_id: UUID, token: str | None = None) -> str:
        """Create a new reset token for a user.

        Args:
            user_id: user the token is issued for.
            token: token string to store; when ``None`` a random URL-safe
                token is generated with ``secrets.token_urlsafe(32)``.

        Returns:
            The stored token string.
        """
        if token is None:
            token = secrets.token_urlsafe(32)

        reset_token_model = ResetTokenModel(
            token=token,
            user_id=user_id.bytes,
            created_at=datetime.now(),
        )
        self.session.add(reset_token_model)
        await self.session.flush()
        return token

    async def get_reset_token(self, token: str) -> ResetToken | None:
        """Get reset token by token string, or ``None`` if it does not exist.

        No expiry check is performed here; the token's age is only evaluated
        by ``cleanup_expired_tokens``.
        """
        stmt = select(ResetTokenModel).where(ResetTokenModel.token == token)
        result = await self.session.execute(stmt)
        token_model = result.scalar_one_or_none()
        if token_model is None:
            return None
        return ResetToken(
            token=token_model.token,
            user_id=UUID(bytes=token_model.user_id),
            created_at=token_model.created_at,
        )

    async def delete_reset_token(self, token: str) -> None:
        """Delete a reset token (used after successful credential addition)."""
        stmt = delete(ResetTokenModel).where(ResetTokenModel.token == token)
        await self.session.execute(stmt)

    async def cleanup_expired_tokens(
        self, max_age: timedelta = timedelta(hours=24)
    ) -> None:
        """Remove expired reset tokens (by default those older than 24 hours).

        Args:
            max_age: tokens whose ``created_at`` lies further in the past
                than this are deleted.
        """
        # Naive local time, matching how created_at is written by this module.
        expiry_time = datetime.now() - max_age
        stmt = delete(ResetTokenModel).where(ResetTokenModel.created_at < expiry_time)
        await self.session.execute(stmt)

    async def get_user_by_username(self, user_name: str) -> User | None:
        """Get user by username, or ``None`` if no user has that name.

        ``user_name`` carries no unique constraint in ``UserModel``, so
        ``scalar_one_or_none()`` raises SQLAlchemy's ``MultipleResultsFound``
        if several users share the name.
        """
        stmt = select(UserModel).where(UserModel.user_name == user_name)
        result = await self.session.execute(stmt)
        user_model = result.scalar_one_or_none()
        if user_model is None:
            return None
        return self._user_from_model(user_model)
|
|
|
|
|
|
# Standalone functions that handle database connections internally
|
|
async def init_database() -> None:
    """Create the database tables via ``DB.init_db`` inside a ``connect()`` block."""
    async with connect() as database:
        await database.init_db()
|
|
|
|
|
|
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
    """Insert a new user together with their first credential in one transaction.

    Note that the passed ``user`` object is modified: its ``visits`` is set
    to 1 before it is stored.
    """
    async with connect() as database:
        # Open the transaction explicitly; connect() commits it on exit.
        await database.session.begin()
        # Registering counts as the user's first session, hence one visit.
        user.visits = 1
        await database.create_user(user)
        await database.create_credential(credential)
|
|
|
|
|
|
async def get_user_by_id(user_id: UUID) -> User:
    """Fetch a user by WebAuthn user ID using a short-lived session.

    Raises ``ValueError`` (from ``DB.get_user_by_user_id``) when no such
    user exists.
    """
    async with connect() as database:
        found_user = await database.get_user_by_user_id(user_id)
    return found_user
|
|
|
|
|
|
async def create_credential_for_user(credential: StoredCredential) -> None:
    """Persist an additional credential for an already existing user."""
    async with connect() as database:
        await database.create_credential(credential)
|
|
|
|
|
|
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
    """Fetch a stored credential by its credential ID.

    Raises ``ValueError`` (from ``DB.get_credential_by_id``) when the
    credential is not registered.
    """
    async with connect() as database:
        stored = await database.get_credential_by_id(credential_id)
    return stored
|
|
|
|
|
|
async def get_user_credentials(user_id: UUID) -> list[bytes]:
    """Return the credential IDs registered for the given user."""
    async with connect() as database:
        credential_ids = await database.get_credentials_by_user_id(user_id)
    return credential_ids
|
|
|
|
|
|
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
    """Record a login: update the credential used and the user's last_seen timestamp."""
    async with connect() as database:
        await database.login(user_id, credential)
|
|
|
|
|
|
async def delete_user_credential(credential_id: bytes) -> None:
    """Remove the credential with the given ID from the database."""
    async with connect() as database:
        await database.delete_credential(credential_id)
|
|
|
|
|
|
async def create_new_session(user_id: UUID, credential: StoredCredential) -> None:
    """Start a new session for a user: increment visits and refresh last_seen."""
    async with connect() as database:
        await database.create_new_session(user_id, credential)
|
|
|
|
|
|
async def create_reset_token(user_id: UUID, token: str | None = None) -> str:
    """Store a reset token for a user and return the token string.

    When ``token`` is ``None`` the DB layer generates a random one.
    """
    async with connect() as database:
        issued = await database.create_reset_token(user_id, token)
    return issued
|
|
|
|
|
|
async def get_reset_token(token: str) -> ResetToken | None:
    """Look up a reset token by its string; ``None`` when it does not exist."""
    async with connect() as database:
        found = await database.get_reset_token(token)
    return found
|
|
|
|
|
|
async def delete_reset_token(token: str) -> None:
    """Remove a reset token (done once a credential was added successfully)."""
    async with connect() as database:
        await database.delete_reset_token(token)
|
|
|
|
|
|
async def cleanup_expired_tokens() -> None:
    """Delete reset tokens that have expired (older than 24 hours)."""
    async with connect() as database:
        await database.cleanup_expired_tokens()
|
|
|
|
|
|
async def get_user_by_username(user_name: str) -> User | None:
    """Look up a user by username; ``None`` when no user has that name."""
    async with connect() as database:
        match = await database.get_user_by_username(user_name)
    return match
|