From 2e4ff30beaf556fe58cbfddf3e4f9daefdec07d3 Mon Sep 17 00:00:00 2001 From: Leo Vasanko Date: Thu, 7 Aug 2025 10:42:49 -0600 Subject: [PATCH] Users always belong to one Org. Implement a DB function to fetch all data relevant to a session. --- passkey/db/__init__.py | 26 +++++-- passkey/db/sql.py | 153 ++++++++++++++++++++++++++++++++--------- 2 files changed, 143 insertions(+), 36 deletions(-) diff --git a/passkey/db/__init__.py b/passkey/db/__init__.py index 72a106a..97819f1 100644 --- a/passkey/db/__init__.py +++ b/passkey/db/__init__.py @@ -15,7 +15,7 @@ from uuid import UUID class User: uuid: UUID display_name: str - org_uuid: UUID | None = None + org_uuid: UUID role: str | None = None created_at: datetime | None = None last_seen: datetime | None = None @@ -62,6 +62,17 @@ class Session: credential_uuid: UUID | None = None +@dataclass +class SessionContext: + """Complete session context with user, organization, role, and permissions.""" + + session: Session + user: User + organization: Org + role: str | None = None + permissions: list[Permission] | None = None + + class DatabaseInterface(ABC): """Abstract base class defining the database interface. @@ -160,11 +171,13 @@ class DatabaseInterface(ABC): """Set a user's organization and role.""" @abstractmethod - async def remove_user_from_organization(self, user_uuid: UUID) -> None: - """Remove a user from their organization.""" + async def transfer_user_to_organization( + self, user_uuid: UUID, new_org_id: str, new_role: str | None = None + ) -> None: + """Transfer a user to another organization with an optional role.""" @abstractmethod - async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None: + async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]: """Get the organization and role for a user.""" @abstractmethod @@ -231,11 +244,16 @@ class DatabaseInterface(ABC): ) -> None: """Create a new user and their first credential in a transaction.""" + @abstractmethod + async def get_session_context(self, session_key: bytes) -> SessionContext | None: + """Get complete session context including user, organization, role, and permissions.""" + __all__ = [ "User", "Credential", "Session", + "SessionContext", "Org", "Permission", "DatabaseInterface", diff --git a/passkey/db/sql.py b/passkey/db/sql.py index 695efd2..209977c 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -24,7 +24,15 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from ..globals import db -from . import Credential, DatabaseInterface, Org, Permission, Session, User +from . import ( + Credential, + DatabaseInterface, + Org, + Permission, + Session, + SessionContext, + User, +) DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite" @@ -55,20 +63,6 @@ class OrgPermission(Base): ForeignKey("permissions.id", ondelete="CASCADE"), primary_key=True, ) - """Permissions each Org is allowed to grant to its roles.""" - - __tablename__ = "org_permissions" - - org_uuid: Mapped[bytes] = mapped_column( - LargeBinary(16), - ForeignKey("orgs.uuid", ondelete="CASCADE"), - primary_key=True, - ) - permission_id: Mapped[str] = mapped_column( - String(32), - ForeignKey("permissions.id", ondelete="CASCADE"), - primary_key=True, - ) class PermissionModel(Base): @@ -90,8 +84,8 @@ class UserModel(Base): uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) display_name: Mapped[str] = mapped_column(String, nullable=False) - org_uuid: Mapped[bytes | None] = mapped_column( - LargeBinary(16), ForeignKey("orgs.uuid", ondelete="SET NULL"), nullable=True + org_uuid: Mapped[bytes] = mapped_column( + LargeBinary(16), ForeignKey("orgs.uuid", ondelete="CASCADE"), nullable=False ) role: Mapped[str | None] = mapped_column(String, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) @@ -164,9 +158,7 @@ class DB(DatabaseInterface): return User( uuid=UUID(bytes=user_model.uuid), display_name=user_model.display_name, - org_uuid=UUID(bytes=user_model.org_uuid) - if user_model.org_uuid - else None, + org_uuid=UUID(bytes=user_model.org_uuid), role=user_model.role, created_at=user_model.created_at, last_seen=user_model.last_seen, @@ -179,7 +171,7 @@ class DB(DatabaseInterface): user_model = UserModel( uuid=user.uuid.bytes, display_name=user.display_name, - org_uuid=user.org_uuid.bytes if user.org_uuid else None, + org_uuid=user.org_uuid.bytes, role=user.role, created_at=user.created_at or datetime.now(), last_seen=user.last_seen, @@ -280,7 +272,7 @@ class DB(DatabaseInterface): user_model = UserModel( uuid=user.uuid.bytes, display_name=user.display_name, - org_uuid=user.org_uuid.bytes if user.org_uuid else None, + org_uuid=user.org_uuid.bytes, role=user.role, created_at=user.created_at or datetime.now(), last_seen=user.last_seen, @@ -432,24 +424,39 @@ class DB(DatabaseInterface): ) await session.execute(stmt) - async def remove_user_from_organization(self, user_uuid: UUID) -> None: + async def transfer_user_to_organization( + self, user_uuid: UUID, new_org_id: str, new_role: str | None = None + ) -> None: async with self.session() as session: - # Clear the user's organization and role + # Convert string ID to UUID bytes for lookup + new_org_uuid = UUID(new_org_id) + + # Verify the new organization exists + org_stmt = select(OrgModel).where(OrgModel.uuid == new_org_uuid.bytes) + org_result = await session.execute(org_stmt) + org_model = org_result.scalar_one_or_none() + + if not org_model: + raise ValueError("Target organization not found") + + # Update the user's organization and role stmt = ( update(UserModel) .where(UserModel.uuid == user_uuid.bytes) - .values(org_uuid=None, role=None) + .values(org_uuid=new_org_uuid.bytes, role=new_role) ) - await session.execute(stmt) + result = await session.execute(stmt) + if result.rowcount == 0: + raise ValueError("User not found") - async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None: + async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]: async with self.session() as session: stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes) result = await session.execute(stmt) user_model = result.scalar_one_or_none() - if not user_model or not user_model.org_uuid: - return None + if not user_model: + raise ValueError("User not found") # Fetch the organization details org_stmt = select(OrgModel).where(OrgModel.uuid == user_model.org_uuid) @@ -474,9 +481,7 @@ class DB(DatabaseInterface): user = User( uuid=UUID(bytes=user_model.uuid), display_name=user_model.display_name, - org_uuid=UUID(bytes=user_model.org_uuid) - if user_model.org_uuid - else None, + org_uuid=UUID(bytes=user_model.org_uuid), role=user_model.role, created_at=user_model.created_at, last_seen=user_model.last_seen, @@ -643,3 +648,87 @@ class DB(DatabaseInterface): current_time = datetime.now() stmt = delete(SessionModel).where(SessionModel.expires < current_time) await session.execute(stmt) + + async def get_session_context(self, session_key: bytes) -> SessionContext | None: + """Get complete session context including user, organization, role, and permissions. + + Uses efficient JOINs to retrieve all related data in a single database query. + """ + async with self.session() as session: + # Build a query that joins sessions, users, organizations, org_permissions, and permissions + stmt = ( + select( + SessionModel, + UserModel, + OrgModel, + PermissionModel, + ) + .select_from(SessionModel) + .join(UserModel, SessionModel.user_uuid == UserModel.uuid) + .join(OrgModel, UserModel.org_uuid == OrgModel.uuid) + .outerjoin(OrgPermission, OrgModel.uuid == OrgPermission.org_uuid) + .outerjoin( + PermissionModel, OrgPermission.permission_id == PermissionModel.id + ) + .where(SessionModel.key == session_key) + ) + + result = await session.execute(stmt) + rows = result.fetchall() + + if not rows: + return None + + # Extract the first row to get session and user data + first_row = rows[0] + session_model, user_model, org_model, _ = first_row + + # Create the session object + session_obj = Session( + key=session_model.key, + user_uuid=UUID(bytes=session_model.user_uuid), + credential_uuid=UUID(bytes=session_model.credential_uuid) + if session_model.credential_uuid + else None, + expires=session_model.expires, + info=session_model.info or {}, + ) + + # Create the user object + user_obj = User( + uuid=UUID(bytes=user_model.uuid), + display_name=user_model.display_name, + org_uuid=UUID(bytes=user_model.org_uuid), + role=user_model.role, + created_at=user_model.created_at, + last_seen=user_model.last_seen, + visits=user_model.visits, + ) + + # Create organization object (always exists now) + organization = Org( + id=str(UUID(bytes=org_model.uuid)), + options=org_model.options, + ) + + # Collect all unique permissions + permissions = [] + seen_permission_ids = set() + for row in rows: + _, _, _, permission_model = row + if permission_model and permission_model.id not in seen_permission_ids: + permissions.append( + Permission( + id=permission_model.id, + display_name=permission_model.display_name, + ) + ) + seen_permission_ids.add(permission_model.id) + + return SessionContext( + session=session_obj, + user=user_obj, + organization=organization, + role=user_model.role, + permissions=permissions if permissions else None, + )