Add host-based authentication, UTC timestamps, session management, and secure cookies; fix styling issues; refactor to remove module; update database schema for sessions and reset tokens.

This commit is contained in:
Leo Vasanko
2025-10-03 18:31:54 -06:00
parent 963ab06664
commit 591ea626bf
29 changed files with 1489 additions and 611 deletions

View File

@@ -63,9 +63,27 @@ class Credential:
class Session:
key: bytes
user_uuid: UUID
expires: datetime
info: dict
credential_uuid: UUID | None = None
credential_uuid: UUID
host: str
ip: str
user_agent: str
renewed: datetime
def metadata(self) -> dict:
"""Return session metadata for backwards compatibility."""
return {
"ip": self.ip,
"user_agent": self.user_agent,
"renewed": self.renewed.isoformat(),
}
@dataclass
class ResetToken:
key: bytes
user_uuid: UUID
expiry: datetime
token_type: str
@dataclass
@@ -146,9 +164,11 @@ class DatabaseInterface(ABC):
self,
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
credential_uuid: UUID,
host: str,
ip: str,
user_agent: str,
renewed: datetime,
) -> None:
"""Create a new session."""
@@ -162,14 +182,50 @@ class DatabaseInterface(ABC):
@abstractmethod
async def update_session(
self, key: bytes, expires: datetime, info: dict
self,
key: bytes,
*,
ip: str,
user_agent: str,
renewed: datetime,
) -> Session | None:
"""Update session expiry and info."""
"""Update session metadata and touch renewed timestamp."""
@abstractmethod
async def set_session_host(self, key: bytes, host: str) -> None:
"""Bind a session to a specific host if not already set."""
@abstractmethod
async def list_sessions_for_user(self, user_uuid: UUID) -> list[Session]:
"""Return all sessions for a user (including other hosts)."""
@abstractmethod
async def cleanup(self) -> None:
"""Called periodically to clean up expired records."""
@abstractmethod
async def delete_sessions_for_user(self, user_uuid: UUID) -> None:
"""Delete all sessions belonging to the provided user."""
# Reset token operations
@abstractmethod
async def create_reset_token(
self,
user_uuid: UUID,
key: bytes,
expiry: datetime,
token_type: str,
) -> None:
"""Create a reset token for a user."""
@abstractmethod
async def get_reset_token(self, key: bytes) -> ResetToken | None:
"""Retrieve a reset token by key."""
@abstractmethod
async def delete_reset_token(self, key: bytes) -> None:
"""Delete a reset token by key."""
# Organization operations
@abstractmethod
async def create_organization(self, org: Org) -> None:
@@ -315,36 +371,41 @@ class DatabaseInterface(ABC):
"""Create a new user and their first credential in a transaction."""
@abstractmethod
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
async def get_session_context(
self, session_key: bytes, host: str | None = None
) -> SessionContext | None:
"""Get complete session context including user, organization, role, and permissions."""
# Combined atomic operations
@abstractmethod
async def create_credential_session(
self,
user_uuid: UUID,
credential: Credential,
reset_key: bytes | None,
session_key: bytes,
session_expires: datetime,
session_info: dict,
display_name: str | None = None,
) -> None:
"""Atomically add a credential and create a session.
# Combined atomic operations
@abstractmethod
async def create_credential_session(
self,
user_uuid: UUID,
credential: Credential,
reset_key: bytes | None,
session_key: bytes,
*,
display_name: str | None = None,
host: str | None = None,
ip: str | None = None,
user_agent: str | None = None,
) -> None:
"""Atomically add a credential and create a session.
Steps (single transaction):
1. Insert credential
2. Optionally delete old session (e.g. reset token) if provided
3. Optionally update user's display name
4. Insert new session referencing the credential
5. Update user's last_seen and increment visits (treat as a login)
"""
Steps (single transaction):
1. Insert credential
2. Optionally delete old reset token if provided
3. Optionally update user's display name
4. Insert new session referencing the credential
5. Update user's last_seen and increment visits (treat as a login)
"""
__all__ = [
"User",
"Credential",
"Session",
"ResetToken",
"SessionContext",
"Org",
"Role",

View File

@@ -6,7 +6,7 @@ for managing users and credentials in a WebAuthn authentication system.
"""
from contextlib import asynccontextmanager
from datetime import datetime
from datetime import datetime, timezone
from uuid import UUID
from sqlalchemy import (
@@ -19,18 +19,21 @@ from sqlalchemy import (
event,
insert,
select,
text,
update,
)
from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.dialects.sqlite import BLOB
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from ..config import SESSION_LIFETIME
from ..globals import db
from . import (
Credential,
DatabaseInterface,
Org,
Permission,
ResetToken,
Role,
Session,
SessionContext,
@@ -40,6 +43,14 @@ from . import (
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
def _normalize_dt(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
async def init(*args, **kwargs):
db.instance = DB()
await db.instance.init_db()
@@ -98,8 +109,12 @@ class UserModel(Base):
role_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("roles.uuid", ondelete="CASCADE"), nullable=False
)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
last_seen: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
def as_dataclass(self) -> User:
@@ -107,8 +122,8 @@ class UserModel(Base):
uuid=UUID(bytes=self.uuid),
display_name=self.display_name,
role_uuid=UUID(bytes=self.role_uuid),
created_at=self.created_at,
last_seen=self.last_seen,
created_at=_normalize_dt(self.created_at) or self.created_at,
last_seen=_normalize_dt(self.last_seen) or self.last_seen,
visits=self.visits,
)
@@ -118,7 +133,7 @@ class UserModel(Base):
uuid=user.uuid.bytes,
display_name=user.display_name,
role_uuid=user.role_uuid.bytes,
created_at=user.created_at or datetime.now(),
created_at=user.created_at or datetime.now(timezone.utc),
last_seen=user.last_seen,
visits=user.visits,
)
@@ -137,9 +152,29 @@ class CredentialModel(Base):
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
# Columns declared timezone-aware going forward; legacy rows may still be naive in storage
last_used: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_verified: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
def as_dataclass(self): # type: ignore[override]
return Credential(
uuid=UUID(bytes=self.uuid),
credential_id=self.credential_id,
user_uuid=UUID(bytes=self.user_uuid),
aaguid=UUID(bytes=self.aaguid),
public_key=self.public_key,
sign_count=self.sign_count,
created_at=_normalize_dt(self.created_at) or self.created_at,
last_used=_normalize_dt(self.last_used) or self.last_used,
last_verified=_normalize_dt(self.last_verified) or self.last_verified,
)
class SessionModel(Base):
@@ -147,23 +182,31 @@ class SessionModel(Base):
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE"), nullable=False
)
credential_uuid: Mapped[bytes | None] = mapped_column(
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
credential_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16),
ForeignKey("credentials.uuid", ondelete="CASCADE"),
nullable=False,
)
host: Mapped[str] = mapped_column(String, nullable=False)
ip: Mapped[str] = mapped_column(String(64), nullable=False)
user_agent: Mapped[str] = mapped_column(String(512), nullable=False)
renewed: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
nullable=False,
)
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
info: Mapped[dict] = mapped_column(JSON, default=dict)
def as_dataclass(self):
return Session(
key=self.key,
user_uuid=UUID(bytes=self.user_uuid),
credential_uuid=(
UUID(bytes=self.credential_uuid) if self.credential_uuid else None
),
expires=self.expires,
info=self.info,
credential_uuid=UUID(bytes=self.credential_uuid),
host=self.host,
ip=self.ip,
user_agent=self.user_agent,
renewed=_normalize_dt(self.renewed) or self.renewed,
)
@staticmethod
@@ -171,9 +214,30 @@ class SessionModel(Base):
return SessionModel(
key=session.key,
user_uuid=session.user_uuid.bytes,
credential_uuid=session.credential_uuid and session.credential_uuid.bytes,
expires=session.expires,
info=session.info,
credential_uuid=session.credential_uuid.bytes,
host=session.host,
ip=session.ip,
user_agent=session.user_agent,
renewed=session.renewed,
)
class ResetTokenModel(Base):
__tablename__ = "reset_tokens"
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE"), nullable=False
)
token_type: Mapped[str] = mapped_column(String, nullable=False)
expiry: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
def as_dataclass(self) -> ResetToken:
return ResetToken(
key=self.key,
user_uuid=UUID(bytes=self.user_uuid),
token_type=self.token_type,
expiry=_normalize_dt(self.expiry) or self.expiry,
)
@@ -257,6 +321,58 @@ class DB(DatabaseInterface):
"""Initialize database tables."""
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
result = await conn.execute(text("PRAGMA table_info('sessions')"))
columns = {row[1] for row in result}
expected = {
"key",
"user_uuid",
"credential_uuid",
"host",
"ip",
"user_agent",
"renewed",
}
needs_recreate = False
if columns and columns != expected:
await conn.execute(text("DROP TABLE sessions"))
needs_recreate = True
result = await conn.execute(text("PRAGMA table_info('reset_tokens')"))
if not list(result):
needs_recreate = True
if needs_recreate:
await conn.run_sync(Base.metadata.create_all)
# Run one-time migration to add UTC tzinfo to any naive datetimes
await self._migrate_naive_datetimes()
async def _migrate_naive_datetimes(self) -> None:
"""Attach UTC tzinfo to any legacy naive datetime rows.
SQLite stores datetimes as text; older rows may have been inserted naive.
We treat naive timestamps as already UTC and rewrite them in ISO8601 with Z.
"""
# Helper SQL fragment for detecting naive (no timezone offset) for ISO strings
# We only update rows whose textual representation lacks a 'Z' or '+' sign.
async with self.session() as session:
# Users
for model, fields in [
(UserModel, ["created_at", "last_seen"]),
(CredentialModel, ["created_at", "last_used", "last_verified"]),
(SessionModel, ["renewed"]),
(ResetTokenModel, ["expiry"]),
]:
stmt = select(model)
result = await session.execute(stmt)
rows = result.scalars().all()
dirty = False
for row in rows:
for fname in fields:
value = getattr(row, fname, None)
if isinstance(value, datetime) and value.tzinfo is None:
setattr(row, fname, value.replace(tzinfo=timezone.utc))
dirty = True
if dirty:
# SQLAlchemy autoflush/commit in context manager will persist
pass
async def get_user_by_uuid(self, user_uuid: UUID) -> User:
async with self.session() as session:
@@ -409,9 +525,11 @@ class DB(DatabaseInterface):
credential: Credential,
reset_key: bytes | None,
session_key: bytes,
session_expires: datetime,
session_info: dict,
*,
display_name: str | None = None,
host: str | None = None,
ip: str | None = None,
user_agent: str | None = None,
) -> None:
"""Atomic credential + (optional old session delete) + (optional rename) + new session."""
async with self.session() as session:
@@ -434,10 +552,10 @@ class DB(DatabaseInterface):
last_verified=credential.last_verified,
)
)
# Delete old session if provided
# Delete old reset token if provided
if reset_key:
await session.execute(
delete(SessionModel).where(SessionModel.key == reset_key)
delete(ResetTokenModel).where(ResetTokenModel.key == reset_key)
)
# Optional rename
if display_name:
@@ -452,8 +570,9 @@ class DB(DatabaseInterface):
key=session_key,
user_uuid=user_uuid.bytes,
credential_uuid=credential.uuid.bytes,
expires=session_expires,
info=session_info,
host=host,
ip=ip,
user_agent=user_agent,
)
)
# Login side-effects: update user analytics (last_seen + visits increment)
@@ -476,17 +595,21 @@ class DB(DatabaseInterface):
self,
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
credential_uuid: UUID,
host: str,
ip: str,
user_agent: str,
renewed: datetime,
) -> None:
async with self.session() as session:
session_model = SessionModel(
key=key,
user_uuid=user_uuid.bytes,
credential_uuid=credential_uuid.bytes if credential_uuid else None,
expires=expires,
info=info,
credential_uuid=credential_uuid.bytes,
host=host,
ip=ip,
user_agent=user_agent,
renewed=renewed,
)
session.add(session_model)
@@ -497,29 +620,88 @@ class DB(DatabaseInterface):
session_model = result.scalar_one_or_none()
if session_model:
return Session(
key=session_model.key,
user_uuid=UUID(bytes=session_model.user_uuid),
credential_uuid=UUID(bytes=session_model.credential_uuid)
if session_model.credential_uuid
else None,
expires=session_model.expires,
info=session_model.info or {},
)
return session_model.as_dataclass()
return None
async def delete_session(self, key: bytes) -> None:
async with self.session() as session:
await session.execute(delete(SessionModel).where(SessionModel.key == key))
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
async def delete_sessions_for_user(self, user_uuid: UUID) -> None:
async with self.session() as session:
await session.execute(
update(SessionModel)
.where(SessionModel.key == key)
.values(expires=expires, info=info)
delete(SessionModel).where(SessionModel.user_uuid == user_uuid.bytes)
)
async def create_reset_token(
self,
user_uuid: UUID,
key: bytes,
expiry: datetime,
token_type: str,
) -> None:
async with self.session() as session:
model = ResetTokenModel(
key=key,
user_uuid=user_uuid.bytes,
token_type=token_type,
expiry=expiry,
)
session.add(model)
async def get_reset_token(self, key: bytes) -> ResetToken | None:
async with self.session() as session:
stmt = select(ResetTokenModel).where(ResetTokenModel.key == key)
result = await session.execute(stmt)
model = result.scalar_one_or_none()
return model.as_dataclass() if model else None
async def delete_reset_token(self, key: bytes) -> None:
async with self.session() as session:
await session.execute(
delete(ResetTokenModel).where(ResetTokenModel.key == key)
)
async def update_session(
self,
key: bytes,
*,
ip: str,
user_agent: str,
renewed: datetime,
) -> Session | None:
async with self.session() as session:
model = await session.get(SessionModel, key)
if not model:
return None
model.ip = ip
model.user_agent = user_agent
model.renewed = renewed
await session.flush()
return model.as_dataclass()
async def set_session_host(self, key: bytes, host: str) -> None:
async with self.session() as session:
model = await session.get(SessionModel, key)
if model and model.host is None:
model.host = host
await session.flush()
async def list_sessions_for_user(self, user_uuid: UUID) -> list[Session]:
async with self.session() as session:
stmt = (
select(SessionModel)
.where(SessionModel.user_uuid == user_uuid.bytes)
.order_by(SessionModel.renewed.desc())
)
result = await session.execute(stmt)
session_models = [
model
for model in result.scalars().all()
if model.key.startswith(b"sess")
]
return [model.as_dataclass() for model in session_models]
# Organization operations
async def create_organization(self, org: Org) -> None:
async with self.session() as session:
@@ -1115,11 +1297,18 @@ class DB(DatabaseInterface):
async def cleanup(self) -> None:
async with self.session() as session:
current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await session.execute(stmt)
current_time = datetime.now(timezone.utc)
session_threshold = current_time - SESSION_LIFETIME
await session.execute(
delete(SessionModel).where(SessionModel.renewed < session_threshold)
)
await session.execute(
delete(ResetTokenModel).where(ResetTokenModel.expiry < current_time)
)
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
async def get_session_context(
self, session_key: bytes, host: str | None = None
) -> SessionContext | None:
"""Get complete session context including user, organization, role, and permissions.
Uses efficient JOINs to retrieve all related data in a single database query.
@@ -1156,15 +1345,18 @@ class DB(DatabaseInterface):
session_model, user_model, role_model, org_model, _ = first_row
# Create the session object
session_obj = Session(
key=session_model.key,
user_uuid=UUID(bytes=session_model.user_uuid),
credential_uuid=UUID(bytes=session_model.credential_uuid)
if session_model.credential_uuid
else None,
expires=session_model.expires,
info=session_model.info or {},
)
if host is not None:
if session_model.host is None:
await session.execute(
update(SessionModel)
.where(SessionModel.key == session_key)
.values(host=host)
)
session_model.host = host
elif session_model.host != host:
return None
session_obj = session_model.as_dataclass()
# Create the user object
user_obj = user_model.as_dataclass()