Add host-based authentication, UTC timestamps, session management, and secure cookies; fix styling issues; refactor to remove module; update database schema for sessions and reset tokens.
This commit is contained in:
@@ -63,9 +63,27 @@ class Credential:
|
||||
class Session:
|
||||
key: bytes
|
||||
user_uuid: UUID
|
||||
expires: datetime
|
||||
info: dict
|
||||
credential_uuid: UUID | None = None
|
||||
credential_uuid: UUID
|
||||
host: str
|
||||
ip: str
|
||||
user_agent: str
|
||||
renewed: datetime
|
||||
|
||||
def metadata(self) -> dict:
|
||||
"""Return session metadata for backwards compatibility."""
|
||||
return {
|
||||
"ip": self.ip,
|
||||
"user_agent": self.user_agent,
|
||||
"renewed": self.renewed.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetToken:
|
||||
key: bytes
|
||||
user_uuid: UUID
|
||||
expiry: datetime
|
||||
token_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -146,9 +164,11 @@ class DatabaseInterface(ABC):
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
key: bytes,
|
||||
expires: datetime,
|
||||
info: dict,
|
||||
credential_uuid: UUID | None = None,
|
||||
credential_uuid: UUID,
|
||||
host: str,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
renewed: datetime,
|
||||
) -> None:
|
||||
"""Create a new session."""
|
||||
|
||||
@@ -162,14 +182,50 @@ class DatabaseInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def update_session(
|
||||
self, key: bytes, expires: datetime, info: dict
|
||||
self,
|
||||
key: bytes,
|
||||
*,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
renewed: datetime,
|
||||
) -> Session | None:
|
||||
"""Update session expiry and info."""
|
||||
"""Update session metadata and touch renewed timestamp."""
|
||||
|
||||
@abstractmethod
|
||||
async def set_session_host(self, key: bytes, host: str) -> None:
|
||||
"""Bind a session to a specific host if not already set."""
|
||||
|
||||
@abstractmethod
|
||||
async def list_sessions_for_user(self, user_uuid: UUID) -> list[Session]:
|
||||
"""Return all sessions for a user (including other hosts)."""
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Called periodically to clean up expired records."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_sessions_for_user(self, user_uuid: UUID) -> None:
|
||||
"""Delete all sessions belonging to the provided user."""
|
||||
|
||||
# Reset token operations
|
||||
@abstractmethod
|
||||
async def create_reset_token(
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
key: bytes,
|
||||
expiry: datetime,
|
||||
token_type: str,
|
||||
) -> None:
|
||||
"""Create a reset token for a user."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_reset_token(self, key: bytes) -> ResetToken | None:
|
||||
"""Retrieve a reset token by key."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_reset_token(self, key: bytes) -> None:
|
||||
"""Delete a reset token by key."""
|
||||
|
||||
# Organization operations
|
||||
@abstractmethod
|
||||
async def create_organization(self, org: Org) -> None:
|
||||
@@ -315,36 +371,41 @@ class DatabaseInterface(ABC):
|
||||
"""Create a new user and their first credential in a transaction."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
|
||||
async def get_session_context(
|
||||
self, session_key: bytes, host: str | None = None
|
||||
) -> SessionContext | None:
|
||||
"""Get complete session context including user, organization, role, and permissions."""
|
||||
|
||||
# Combined atomic operations
|
||||
@abstractmethod
|
||||
async def create_credential_session(
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
credential: Credential,
|
||||
reset_key: bytes | None,
|
||||
session_key: bytes,
|
||||
session_expires: datetime,
|
||||
session_info: dict,
|
||||
display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Atomically add a credential and create a session.
|
||||
# Combined atomic operations
|
||||
@abstractmethod
|
||||
async def create_credential_session(
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
credential: Credential,
|
||||
reset_key: bytes | None,
|
||||
session_key: bytes,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
host: str | None = None,
|
||||
ip: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
"""Atomically add a credential and create a session.
|
||||
|
||||
Steps (single transaction):
|
||||
1. Insert credential
|
||||
2. Optionally delete old session (e.g. reset token) if provided
|
||||
3. Optionally update user's display name
|
||||
4. Insert new session referencing the credential
|
||||
5. Update user's last_seen and increment visits (treat as a login)
|
||||
"""
|
||||
Steps (single transaction):
|
||||
1. Insert credential
|
||||
2. Optionally delete old reset token if provided
|
||||
3. Optionally update user's display name
|
||||
4. Insert new session referencing the credential
|
||||
5. Update user's last_seen and increment visits (treat as a login)
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"Credential",
|
||||
"Session",
|
||||
"ResetToken",
|
||||
"SessionContext",
|
||||
"Org",
|
||||
"Role",
|
||||
|
||||
@@ -6,7 +6,7 @@ for managing users and credentials in a WebAuthn authentication system.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import (
|
||||
@@ -19,18 +19,21 @@ from sqlalchemy import (
|
||||
event,
|
||||
insert,
|
||||
select,
|
||||
text,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.dialects.sqlite import BLOB, JSON
|
||||
from sqlalchemy.dialects.sqlite import BLOB
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from ..config import SESSION_LIFETIME
|
||||
from ..globals import db
|
||||
from . import (
|
||||
Credential,
|
||||
DatabaseInterface,
|
||||
Org,
|
||||
Permission,
|
||||
ResetToken,
|
||||
Role,
|
||||
Session,
|
||||
SessionContext,
|
||||
@@ -40,6 +43,14 @@ from . import (
|
||||
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
|
||||
|
||||
|
||||
def _normalize_dt(value: datetime | None) -> datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc)
|
||||
|
||||
|
||||
async def init(*args, **kwargs):
|
||||
db.instance = DB()
|
||||
await db.instance.init_db()
|
||||
@@ -98,8 +109,12 @@ class UserModel(Base):
|
||||
role_uuid: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("roles.uuid", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
last_seen: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
def as_dataclass(self) -> User:
|
||||
@@ -107,8 +122,8 @@ class UserModel(Base):
|
||||
uuid=UUID(bytes=self.uuid),
|
||||
display_name=self.display_name,
|
||||
role_uuid=UUID(bytes=self.role_uuid),
|
||||
created_at=self.created_at,
|
||||
last_seen=self.last_seen,
|
||||
created_at=_normalize_dt(self.created_at) or self.created_at,
|
||||
last_seen=_normalize_dt(self.last_seen) or self.last_seen,
|
||||
visits=self.visits,
|
||||
)
|
||||
|
||||
@@ -118,7 +133,7 @@ class UserModel(Base):
|
||||
uuid=user.uuid.bytes,
|
||||
display_name=user.display_name,
|
||||
role_uuid=user.role_uuid.bytes,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
created_at=user.created_at or datetime.now(timezone.utc),
|
||||
last_seen=user.last_seen,
|
||||
visits=user.visits,
|
||||
)
|
||||
@@ -137,9 +152,29 @@ class CredentialModel(Base):
|
||||
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
||||
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
||||
sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
# Columns declared timezone-aware going forward; legacy rows may still be naive in storage
|
||||
last_used: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
last_verified: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
def as_dataclass(self): # type: ignore[override]
|
||||
return Credential(
|
||||
uuid=UUID(bytes=self.uuid),
|
||||
credential_id=self.credential_id,
|
||||
user_uuid=UUID(bytes=self.user_uuid),
|
||||
aaguid=UUID(bytes=self.aaguid),
|
||||
public_key=self.public_key,
|
||||
sign_count=self.sign_count,
|
||||
created_at=_normalize_dt(self.created_at) or self.created_at,
|
||||
last_used=_normalize_dt(self.last_used) or self.last_used,
|
||||
last_verified=_normalize_dt(self.last_verified) or self.last_verified,
|
||||
)
|
||||
|
||||
|
||||
class SessionModel(Base):
|
||||
@@ -147,23 +182,31 @@ class SessionModel(Base):
|
||||
|
||||
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||
user_uuid: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
|
||||
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
credential_uuid: Mapped[bytes | None] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
|
||||
credential_uuid: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16),
|
||||
ForeignKey("credentials.uuid", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
host: Mapped[str] = mapped_column(String, nullable=False)
|
||||
ip: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
user_agent: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
renewed: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
info: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
def as_dataclass(self):
|
||||
return Session(
|
||||
key=self.key,
|
||||
user_uuid=UUID(bytes=self.user_uuid),
|
||||
credential_uuid=(
|
||||
UUID(bytes=self.credential_uuid) if self.credential_uuid else None
|
||||
),
|
||||
expires=self.expires,
|
||||
info=self.info,
|
||||
credential_uuid=UUID(bytes=self.credential_uuid),
|
||||
host=self.host,
|
||||
ip=self.ip,
|
||||
user_agent=self.user_agent,
|
||||
renewed=_normalize_dt(self.renewed) or self.renewed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -171,9 +214,30 @@ class SessionModel(Base):
|
||||
return SessionModel(
|
||||
key=session.key,
|
||||
user_uuid=session.user_uuid.bytes,
|
||||
credential_uuid=session.credential_uuid and session.credential_uuid.bytes,
|
||||
expires=session.expires,
|
||||
info=session.info,
|
||||
credential_uuid=session.credential_uuid.bytes,
|
||||
host=session.host,
|
||||
ip=session.ip,
|
||||
user_agent=session.user_agent,
|
||||
renewed=session.renewed,
|
||||
)
|
||||
|
||||
|
||||
class ResetTokenModel(Base):
|
||||
__tablename__ = "reset_tokens"
|
||||
|
||||
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||
user_uuid: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
token_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
expiry: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
def as_dataclass(self) -> ResetToken:
|
||||
return ResetToken(
|
||||
key=self.key,
|
||||
user_uuid=UUID(bytes=self.user_uuid),
|
||||
token_type=self.token_type,
|
||||
expiry=_normalize_dt(self.expiry) or self.expiry,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,6 +321,58 @@ class DB(DatabaseInterface):
|
||||
"""Initialize database tables."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
result = await conn.execute(text("PRAGMA table_info('sessions')"))
|
||||
columns = {row[1] for row in result}
|
||||
expected = {
|
||||
"key",
|
||||
"user_uuid",
|
||||
"credential_uuid",
|
||||
"host",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"renewed",
|
||||
}
|
||||
needs_recreate = False
|
||||
if columns and columns != expected:
|
||||
await conn.execute(text("DROP TABLE sessions"))
|
||||
needs_recreate = True
|
||||
result = await conn.execute(text("PRAGMA table_info('reset_tokens')"))
|
||||
if not list(result):
|
||||
needs_recreate = True
|
||||
if needs_recreate:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# Run one-time migration to add UTC tzinfo to any naive datetimes
|
||||
await self._migrate_naive_datetimes()
|
||||
|
||||
async def _migrate_naive_datetimes(self) -> None:
|
||||
"""Attach UTC tzinfo to any legacy naive datetime rows.
|
||||
|
||||
SQLite stores datetimes as text; older rows may have been inserted naive.
|
||||
We treat naive timestamps as already UTC and rewrite them in ISO8601 with Z.
|
||||
"""
|
||||
# Helper SQL fragment for detecting naive (no timezone offset) for ISO strings
|
||||
# We only update rows whose textual representation lacks a 'Z' or '+' sign.
|
||||
async with self.session() as session:
|
||||
# Users
|
||||
for model, fields in [
|
||||
(UserModel, ["created_at", "last_seen"]),
|
||||
(CredentialModel, ["created_at", "last_used", "last_verified"]),
|
||||
(SessionModel, ["renewed"]),
|
||||
(ResetTokenModel, ["expiry"]),
|
||||
]:
|
||||
stmt = select(model)
|
||||
result = await session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
dirty = False
|
||||
for row in rows:
|
||||
for fname in fields:
|
||||
value = getattr(row, fname, None)
|
||||
if isinstance(value, datetime) and value.tzinfo is None:
|
||||
setattr(row, fname, value.replace(tzinfo=timezone.utc))
|
||||
dirty = True
|
||||
if dirty:
|
||||
# SQLAlchemy autoflush/commit in context manager will persist
|
||||
pass
|
||||
|
||||
async def get_user_by_uuid(self, user_uuid: UUID) -> User:
|
||||
async with self.session() as session:
|
||||
@@ -409,9 +525,11 @@ class DB(DatabaseInterface):
|
||||
credential: Credential,
|
||||
reset_key: bytes | None,
|
||||
session_key: bytes,
|
||||
session_expires: datetime,
|
||||
session_info: dict,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
host: str | None = None,
|
||||
ip: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
"""Atomic credential + (optional old session delete) + (optional rename) + new session."""
|
||||
async with self.session() as session:
|
||||
@@ -434,10 +552,10 @@ class DB(DatabaseInterface):
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
)
|
||||
# Delete old session if provided
|
||||
# Delete old reset token if provided
|
||||
if reset_key:
|
||||
await session.execute(
|
||||
delete(SessionModel).where(SessionModel.key == reset_key)
|
||||
delete(ResetTokenModel).where(ResetTokenModel.key == reset_key)
|
||||
)
|
||||
# Optional rename
|
||||
if display_name:
|
||||
@@ -452,8 +570,9 @@ class DB(DatabaseInterface):
|
||||
key=session_key,
|
||||
user_uuid=user_uuid.bytes,
|
||||
credential_uuid=credential.uuid.bytes,
|
||||
expires=session_expires,
|
||||
info=session_info,
|
||||
host=host,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
)
|
||||
# Login side-effects: update user analytics (last_seen + visits increment)
|
||||
@@ -476,17 +595,21 @@ class DB(DatabaseInterface):
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
key: bytes,
|
||||
expires: datetime,
|
||||
info: dict,
|
||||
credential_uuid: UUID | None = None,
|
||||
credential_uuid: UUID,
|
||||
host: str,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
renewed: datetime,
|
||||
) -> None:
|
||||
async with self.session() as session:
|
||||
session_model = SessionModel(
|
||||
key=key,
|
||||
user_uuid=user_uuid.bytes,
|
||||
credential_uuid=credential_uuid.bytes if credential_uuid else None,
|
||||
expires=expires,
|
||||
info=info,
|
||||
credential_uuid=credential_uuid.bytes,
|
||||
host=host,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
renewed=renewed,
|
||||
)
|
||||
session.add(session_model)
|
||||
|
||||
@@ -497,29 +620,88 @@ class DB(DatabaseInterface):
|
||||
session_model = result.scalar_one_or_none()
|
||||
|
||||
if session_model:
|
||||
return Session(
|
||||
key=session_model.key,
|
||||
user_uuid=UUID(bytes=session_model.user_uuid),
|
||||
credential_uuid=UUID(bytes=session_model.credential_uuid)
|
||||
if session_model.credential_uuid
|
||||
else None,
|
||||
expires=session_model.expires,
|
||||
info=session_model.info or {},
|
||||
)
|
||||
return session_model.as_dataclass()
|
||||
return None
|
||||
|
||||
async def delete_session(self, key: bytes) -> None:
|
||||
async with self.session() as session:
|
||||
await session.execute(delete(SessionModel).where(SessionModel.key == key))
|
||||
|
||||
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
|
||||
async def delete_sessions_for_user(self, user_uuid: UUID) -> None:
|
||||
async with self.session() as session:
|
||||
await session.execute(
|
||||
update(SessionModel)
|
||||
.where(SessionModel.key == key)
|
||||
.values(expires=expires, info=info)
|
||||
delete(SessionModel).where(SessionModel.user_uuid == user_uuid.bytes)
|
||||
)
|
||||
|
||||
async def create_reset_token(
|
||||
self,
|
||||
user_uuid: UUID,
|
||||
key: bytes,
|
||||
expiry: datetime,
|
||||
token_type: str,
|
||||
) -> None:
|
||||
async with self.session() as session:
|
||||
model = ResetTokenModel(
|
||||
key=key,
|
||||
user_uuid=user_uuid.bytes,
|
||||
token_type=token_type,
|
||||
expiry=expiry,
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
async def get_reset_token(self, key: bytes) -> ResetToken | None:
|
||||
async with self.session() as session:
|
||||
stmt = select(ResetTokenModel).where(ResetTokenModel.key == key)
|
||||
result = await session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
return model.as_dataclass() if model else None
|
||||
|
||||
async def delete_reset_token(self, key: bytes) -> None:
|
||||
async with self.session() as session:
|
||||
await session.execute(
|
||||
delete(ResetTokenModel).where(ResetTokenModel.key == key)
|
||||
)
|
||||
|
||||
async def update_session(
|
||||
self,
|
||||
key: bytes,
|
||||
*,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
renewed: datetime,
|
||||
) -> Session | None:
|
||||
async with self.session() as session:
|
||||
model = await session.get(SessionModel, key)
|
||||
if not model:
|
||||
return None
|
||||
model.ip = ip
|
||||
model.user_agent = user_agent
|
||||
model.renewed = renewed
|
||||
await session.flush()
|
||||
return model.as_dataclass()
|
||||
|
||||
async def set_session_host(self, key: bytes, host: str) -> None:
|
||||
async with self.session() as session:
|
||||
model = await session.get(SessionModel, key)
|
||||
if model and model.host is None:
|
||||
model.host = host
|
||||
await session.flush()
|
||||
|
||||
async def list_sessions_for_user(self, user_uuid: UUID) -> list[Session]:
|
||||
async with self.session() as session:
|
||||
stmt = (
|
||||
select(SessionModel)
|
||||
.where(SessionModel.user_uuid == user_uuid.bytes)
|
||||
.order_by(SessionModel.renewed.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
session_models = [
|
||||
model
|
||||
for model in result.scalars().all()
|
||||
if model.key.startswith(b"sess")
|
||||
]
|
||||
return [model.as_dataclass() for model in session_models]
|
||||
|
||||
# Organization operations
|
||||
async def create_organization(self, org: Org) -> None:
|
||||
async with self.session() as session:
|
||||
@@ -1115,11 +1297,18 @@ class DB(DatabaseInterface):
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
async with self.session() as session:
|
||||
current_time = datetime.now()
|
||||
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
||||
await session.execute(stmt)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
session_threshold = current_time - SESSION_LIFETIME
|
||||
await session.execute(
|
||||
delete(SessionModel).where(SessionModel.renewed < session_threshold)
|
||||
)
|
||||
await session.execute(
|
||||
delete(ResetTokenModel).where(ResetTokenModel.expiry < current_time)
|
||||
)
|
||||
|
||||
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
|
||||
async def get_session_context(
|
||||
self, session_key: bytes, host: str | None = None
|
||||
) -> SessionContext | None:
|
||||
"""Get complete session context including user, organization, role, and permissions.
|
||||
|
||||
Uses efficient JOINs to retrieve all related data in a single database query.
|
||||
@@ -1156,15 +1345,18 @@ class DB(DatabaseInterface):
|
||||
session_model, user_model, role_model, org_model, _ = first_row
|
||||
|
||||
# Create the session object
|
||||
session_obj = Session(
|
||||
key=session_model.key,
|
||||
user_uuid=UUID(bytes=session_model.user_uuid),
|
||||
credential_uuid=UUID(bytes=session_model.credential_uuid)
|
||||
if session_model.credential_uuid
|
||||
else None,
|
||||
expires=session_model.expires,
|
||||
info=session_model.info or {},
|
||||
)
|
||||
if host is not None:
|
||||
if session_model.host is None:
|
||||
await session.execute(
|
||||
update(SessionModel)
|
||||
.where(SessionModel.key == session_key)
|
||||
.values(host=host)
|
||||
)
|
||||
session_model.host = host
|
||||
elif session_model.host != host:
|
||||
return None
|
||||
|
||||
session_obj = session_model.as_dataclass()
|
||||
|
||||
# Create the user object
|
||||
user_obj = user_model.as_dataclass()
|
||||
|
||||
Reference in New Issue
Block a user