Users always belong to one Org. Implement a DB function to fetch all data relevant to a session.

This commit is contained in:
Leo Vasanko 2025-08-07 10:42:49 -06:00
parent 2e3ce32779
commit 2e4ff30bea
2 changed files with 143 additions and 36 deletions

View File

@ -15,7 +15,7 @@ from uuid import UUID
class User:
uuid: UUID
display_name: str
org_uuid: UUID | None = None
org_uuid: UUID
role: str | None = None
created_at: datetime | None = None
last_seen: datetime | None = None
@ -62,6 +62,17 @@ class Session:
credential_uuid: UUID | None = None
@dataclass
class SessionContext:
"""Complete session context with user, organization, role, and permissions."""
session: Session
user: User
organization: Org
role: str | None = None
permissions: list[Permission] | None = None
class DatabaseInterface(ABC):
"""Abstract base class defining the database interface.
@ -160,11 +171,13 @@ class DatabaseInterface(ABC):
"""Set a user's organization and role."""
@abstractmethod
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
"""Remove a user from their organization."""
async def transfer_user_to_organization(
self, user_uuid: UUID, new_org_id: str, new_role: str | None = None
) -> None:
"""Transfer a user to another organization with an optional role."""
@abstractmethod
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]:
"""Get the organization and role for a user."""
@abstractmethod
@ -231,11 +244,16 @@ class DatabaseInterface(ABC):
) -> None:
"""Create a new user and their first credential in a transaction."""
@abstractmethod
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
"""Get complete session context including user, organization, role, and permissions."""
__all__ = [
"User",
"Credential",
"Session",
"SessionContext",
"Org",
"Permission",
"DatabaseInterface",

View File

@ -24,7 +24,15 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from ..globals import db
from . import Credential, DatabaseInterface, Org, Permission, Session, User
from . import (
Credential,
DatabaseInterface,
Org,
Permission,
Session,
SessionContext,
User,
)
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
@ -55,20 +63,6 @@ class OrgPermission(Base):
ForeignKey("permissions.id", ondelete="CASCADE"),
primary_key=True,
)
"""Permissions each Org is allowed to grant to its roles."""
__tablename__ = "org_permissions"
org_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16),
ForeignKey("orgs.uuid", ondelete="CASCADE"),
primary_key=True,
)
permission_id: Mapped[str] = mapped_column(
String(32),
ForeignKey("permissions.id", ondelete="CASCADE"),
primary_key=True,
)
class PermissionModel(Base):
@ -90,8 +84,8 @@ class UserModel(Base):
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
display_name: Mapped[str] = mapped_column(String, nullable=False)
org_uuid: Mapped[bytes | None] = mapped_column(
LargeBinary(16), ForeignKey("orgs.uuid", ondelete="SET NULL"), nullable=True
org_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("orgs.uuid", ondelete="CASCADE"), nullable=False
)
role: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
@ -164,9 +158,7 @@ class DB(DatabaseInterface):
return User(
uuid=UUID(bytes=user_model.uuid),
display_name=user_model.display_name,
org_uuid=UUID(bytes=user_model.org_uuid)
if user_model.org_uuid
else None,
org_uuid=UUID(bytes=user_model.org_uuid),
role=user_model.role,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
@ -179,7 +171,7 @@ class DB(DatabaseInterface):
user_model = UserModel(
uuid=user.uuid.bytes,
display_name=user.display_name,
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
org_uuid=user.org_uuid.bytes,
role=user.role,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
@ -280,7 +272,7 @@ class DB(DatabaseInterface):
user_model = UserModel(
uuid=user.uuid.bytes,
display_name=user.display_name,
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
org_uuid=user.org_uuid.bytes,
role=user.role,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
@ -432,24 +424,39 @@ class DB(DatabaseInterface):
)
await session.execute(stmt)
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
async def transfer_user_to_organization(
self, user_uuid: UUID, new_org_id: str, new_role: str | None = None
) -> None:
async with self.session() as session:
# Clear the user's organization and role
# Convert string ID to UUID bytes for lookup
new_org_uuid = UUID(new_org_id)
# Verify the new organization exists
org_stmt = select(OrgModel).where(OrgModel.uuid == new_org_uuid.bytes)
org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one_or_none()
if not org_model:
raise ValueError("Target organization not found")
# Update the user's organization and role
stmt = (
update(UserModel)
.where(UserModel.uuid == user_uuid.bytes)
.values(org_uuid=None, role=None)
.values(org_uuid=new_org_uuid.bytes, role=new_role)
)
await session.execute(stmt)
result = await session.execute(stmt)
if result.rowcount == 0:
raise ValueError("User not found")
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]:
async with self.session() as session:
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
result = await session.execute(stmt)
user_model = result.scalar_one_or_none()
if not user_model or not user_model.org_uuid:
return None
if not user_model:
raise ValueError("User not found")
# Fetch the organization details
org_stmt = select(OrgModel).where(OrgModel.uuid == user_model.org_uuid)
@ -474,9 +481,7 @@ class DB(DatabaseInterface):
user = User(
uuid=UUID(bytes=user_model.uuid),
display_name=user_model.display_name,
org_uuid=UUID(bytes=user_model.org_uuid)
if user_model.org_uuid
else None,
org_uuid=UUID(bytes=user_model.org_uuid),
role=user_model.role,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
@ -643,3 +648,87 @@ class DB(DatabaseInterface):
current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await session.execute(stmt)
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
"""Get complete session context including user, organization, role, and permissions.
Uses efficient JOINs to retrieve all related data in a single database query.
"""
async with self.session() as session:
# Build a query that joins sessions, users, organizations, org_permissions, and permissions
stmt = (
select(
SessionModel,
UserModel,
OrgModel,
PermissionModel,
)
.select_from(SessionModel)
.join(UserModel, SessionModel.user_uuid == UserModel.uuid)
.join(OrgModel, UserModel.org_uuid == OrgModel.uuid)
.outerjoin(OrgPermission, OrgModel.uuid == OrgPermission.org_uuid)
.outerjoin(
PermissionModel, OrgPermission.permission_id == PermissionModel.id
)
.where(SessionModel.key == session_key)
)
result = await session.execute(stmt)
rows = result.fetchall()
if not rows:
return None
# Extract the first row to get session and user data
first_row = rows[0]
session_model, user_model, org_model, _ = first_row
# Create the session object
session_obj = Session(
key=session_model.key,
user_uuid=UUID(bytes=session_model.user_uuid),
credential_uuid=UUID(bytes=session_model.credential_uuid)
if session_model.credential_uuid
else None,
expires=session_model.expires,
info=session_model.info or {},
)
# Create the user object
user_obj = User(
uuid=UUID(bytes=user_model.uuid),
display_name=user_model.display_name,
org_uuid=UUID(bytes=user_model.org_uuid),
role=user_model.role,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
visits=user_model.visits,
)
# Create organization object (always exists now)
organization = Org(
id=str(UUID(bytes=org_model.uuid)),
options=org_model.options,
)
# Collect all unique permissions
permissions = []
seen_permission_ids = set()
for row in rows:
_, _, _, permission_model = row
if permission_model and permission_model.id not in seen_permission_ids:
permissions.append(
Permission(
id=permission_model.id,
display_name=permission_model.display_name,
)
)
seen_permission_ids.add(permission_model.id)
return SessionContext(
session=session_obj,
user=user_obj,
organization=organization,
role=user_model.role,
permissions=permissions if permissions else None,
)