From c42864794aa27e1cfe8b7fd8f771e3bfc1e13e03 Mon Sep 17 00:00:00 2001 From: Leo Vasanko Date: Tue, 5 Aug 2025 12:16:02 -0600 Subject: [PATCH] Add organisations on DB --- passkey/db/__init__.py | 50 +++++++++- passkey/db/sql.py | 202 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 246 insertions(+), 6 deletions(-) diff --git a/passkey/db/__init__.py b/passkey/db/__init__.py index 24655a1..7749fb5 100644 --- a/passkey/db/__init__.py +++ b/passkey/db/__init__.py @@ -13,8 +13,6 @@ from uuid import UUID @dataclass class User: - """User data structure.""" - user_uuid: UUID user_name: str created_at: datetime | None = None @@ -24,10 +22,8 @@ class User: @dataclass class Credential: - """Credential data structure.""" - uuid: UUID - credential_id: bytes + credential_id: bytes # Long binary ID passed from the authenticator user_uuid: UUID aaguid: UUID public_key: bytes @@ -37,6 +33,14 @@ class Credential: last_verified: datetime | None = None +@dataclass +class Org: + """Organization data structure.""" + + id: str # ASCII primary key + options: dict + + @dataclass class Session: """Session data structure.""" @@ -122,6 +126,41 @@ class DatabaseInterface(ABC): async def cleanup(self) -> None: """Called periodically to clean up expired records.""" + # Organization operations + @abstractmethod + async def create_organization(self, organization: Org) -> None: + """Create a new organization.""" + + @abstractmethod + async def get_organization(self, org_id: str) -> Org: + """Get organization by ID.""" + + @abstractmethod + async def update_organization(self, organization: Org) -> None: + """Update organization options.""" + + @abstractmethod + async def delete_organization(self, org_id: str) -> None: + """Delete organization by ID.""" + + @abstractmethod + async def add_user_to_organization( + self, user_uuid: UUID, org_id: str, role: str + ) -> None: + """Add a user to an organization with a specific role.""" + + @abstractmethod + async def remove_user_from_organization(self, user_uuid: UUID, org_id: str) -> None: + """Remove a user from an organization.""" + + @abstractmethod + async def get_user_org_role(self, user_uuid: UUID) -> list[tuple[Org, str]]: + """Get all organizations for a user with their roles.""" + + @abstractmethod + async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]: + """Get all users in an organization with their roles.""" + # Combined operations @abstractmethod async def login(self, user_uuid: UUID, credential: Credential) -> None: @@ -160,6 +199,7 @@ __all__ = [ "User", "Credential", "Session", + "Org", "DatabaseInterface", "db", ] diff --git a/passkey/db/sql.py b/passkey/db/sql.py index 3bb8dbe..9d6bcef 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -23,7 +23,7 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from . import Credential, DatabaseInterface, Session, User, db +from . import Credential, DatabaseInterface, Org, Session, User, db DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite" @@ -38,6 +38,39 @@ class Base(DeclarativeBase): pass +# Association model for many-to-many relationship between users and organizations with roles +class UserRole(Base): + __tablename__ = "user_roles" + + user_uuid: Mapped[bytes] = mapped_column( + LargeBinary(16), + ForeignKey("users.user_uuid", ondelete="CASCADE"), + primary_key=True, + ) + org_uuid: Mapped[bytes] = mapped_column( + LargeBinary(16), + ForeignKey("orgs.uuid", ondelete="CASCADE"), + primary_key=True, + ) + role: Mapped[str] = mapped_column(String, nullable=False) + + # Relationships to the actual models + user: Mapped["UserModel"] = relationship("UserModel", back_populates="user_orgs") + org: Mapped["OrgModel"] = relationship("OrgModel", back_populates="user_orgs") + + +class OrgModel(Base): + __tablename__ = "orgs" + + uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) + options: Mapped[dict] = mapped_column(JSON, nullable=False, default=dict) + + # Relationship to user-org associations + user_orgs: Mapped[list["UserRole"]] = relationship( + "UserRoleModel", back_populates="org", cascade="all, delete-orphan" + ) + + class UserModel(Base): __tablename__ = "users" @@ -52,6 +85,11 @@ class UserModel(Base): "CredentialModel", back_populates="user", cascade="all, delete-orphan" ) + # Relationship to user-org associations + user_orgs: Mapped[list["UserRole"]] = relationship( + "UserOrgModel", back_populates="user", cascade="all, delete-orphan" + ) + class CredentialModel(Base): __tablename__ = "credentials" @@ -311,6 +349,168 @@ class DB(DatabaseInterface): .values(expires=expires, info=info) ) + # Organization operations + async def create_organization(self, organization: Org) -> None: + async with self.session() as session: + # Convert string ID to UUID bytes for storage + org_uuid = UUID(organization.id) + org_model = OrgModel( + uuid=org_uuid.bytes, + options=organization.options, + ) + session.add(org_model) + + async def get_organization(self, org_id: str) -> Org: + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes) + result = await session.execute(stmt) + org_model = result.scalar_one_or_none() + + if org_model: + # Convert UUID bytes back to string for the interface + return Org( + id=str(UUID(bytes=org_model.uuid)), + options=org_model.options, + ) + raise ValueError("Organization not found") + + async def update_organization(self, organization: Org) -> None: + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(organization.id) + stmt = ( + update(OrgModel) + .where(OrgModel.uuid == org_uuid.bytes) + .values(options=organization.options) + ) + await session.execute(stmt) + + async def delete_organization(self, org_id: str) -> None: + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + stmt = delete(OrgModel).where(OrgModel.uuid == org_uuid.bytes) + await session.execute(stmt) + + async def add_user_to_organization( + self, user_uuid: UUID, org_id: str, role: str + ) -> None: + async with self.session() as session: + # Get user and organization models + user_stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) + user_result = await session.execute(user_stmt) + user_model = user_result.scalar_one_or_none() + + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + org_stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes) + org_result = await session.execute(org_stmt) + org_model = org_result.scalar_one_or_none() + + if not user_model: + raise ValueError("User not found") + if not org_model: + raise ValueError("Organization not found") + + # Create the user-org relationship with role + user_org = UserRole( + user_uuid=user_uuid.bytes, org_uuid=org_uuid.bytes, role=role + ) + session.add(user_org) + + async def remove_user_from_organization(self, user_uuid: UUID, org_id: str) -> None: + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + # Delete the user-org relationship + stmt = delete(UserRole).where( + UserRole.user_uuid == user_uuid.bytes, + UserRole.org_uuid == org_uuid.bytes, + ) + await session.execute(stmt) + + async def get_user_org_role(self, user_uuid: UUID) -> list[tuple[Org, str]]: + async with self.session() as session: + stmt = select(UserRole).where(UserRole.user_uuid == user_uuid.bytes) + result = await session.execute(stmt) + user_org_models = result.scalars().all() + + # Fetch the organization details for each user-org relationship + org_role_pairs = [] + for user_org in user_org_models: + org_stmt = select(OrgModel).where(OrgModel.uuid == user_org.org_uuid) + org_result = await session.execute(org_stmt) + org_model = org_result.scalar_one() + + # Convert UUID bytes back to string for the interface + org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options) + org_role_pairs.append((org, user_org.role)) + + return org_role_pairs + + async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]: + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + stmt = select(UserRole).where(UserRole.org_uuid == org_uuid.bytes) + result = await session.execute(stmt) + user_org_models = result.scalars().all() + + # Fetch the user details for each user-org relationship + user_role_pairs = [] + for user_org in user_org_models: + user_stmt = select(UserModel).where( + UserModel.user_uuid == user_org.user_uuid + ) + user_result = await session.execute(user_stmt) + user_model = user_result.scalar_one() + + user = User( + user_uuid=UUID(bytes=user_model.user_uuid), + user_name=user_model.user_name, + created_at=user_model.created_at, + last_seen=user_model.last_seen, + visits=user_model.visits, + ) + user_role_pairs.append((user, user_org.role)) + + return user_role_pairs + + async def get_user_role_in_organization( + self, user_uuid: UUID, org_id: str + ) -> str | None: + """Get a user's role in a specific organization.""" + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + stmt = select(UserRole.role).where( + UserRole.user_uuid == user_uuid.bytes, + UserRole.org_uuid == org_uuid.bytes, + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + async def update_user_role_in_organization( + self, user_uuid: UUID, org_id: str, new_role: str + ) -> None: + """Update a user's role in an organization.""" + async with self.session() as session: + # Convert string ID to UUID bytes for lookup + org_uuid = UUID(org_id) + stmt = ( + update(UserRole) + .where( + UserRole.user_uuid == user_uuid.bytes, + UserRole.org_uuid == org_uuid.bytes, + ) + .values(role=new_role) + ) + result = await session.execute(stmt) + if result.rowcount == 0: + raise ValueError("User is not a member of this organization") + async def cleanup(self) -> None: async with self.session() as session: current_time = datetime.now()