Support for adding permissions on roles and orgs.

This commit is contained in:
Leo Vasanko 2025-08-12 13:13:35 -07:00
parent d2a6bfd2a5
commit 02ac4adc77
4 changed files with 273 additions and 62 deletions

View File

@ -66,7 +66,12 @@ async def bootstrap_system(
) )
await globals.db.instance.create_permission(perm1) await globals.db.instance.create_permission(perm1)
role = Role(uuid7.create(), org.uuid, "Administration") # Allow this org to grant admin permissions
await globals.db.instance.add_permission_to_organization(str(org.uuid), perm0.id)
await globals.db.instance.add_permission_to_organization(str(org.uuid), perm1.id)
# Create an Administration role granting both org and global admin
role = Role(uuid7.create(), org.uuid, "Administration", permissions=[perm0.id, perm1.id])
await globals.db.instance.create_role(role) await globals.db.instance.create_role(role)
user = User( user = User(
@ -110,9 +115,9 @@ async def check_admin_credentials() -> bool:
# Get users from the first organization with admin permission # Get users from the first organization with admin permission
org_users = await globals.db.instance.get_organization_users( org_users = await globals.db.instance.get_organization_users(
permission_orgs[0].id str(permission_orgs[0].uuid)
) )
admin_users = [user for user, role in org_users if role == "Admin"] admin_users = [user for user, role in org_users if role == "Administration"]
if not admin_users: if not admin_users:
return False return False

View File

@ -6,7 +6,7 @@ users, credentials, and sessions in a WebAuthn authentication system.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
@ -22,15 +22,18 @@ class Role:
uuid: UUID uuid: UUID
org_uuid: UUID org_uuid: UUID
display_name: str display_name: str
permissions: list[Permission] # List of permission IDs this role grants to its members
permissions: list[str] = field(default_factory=list) # permission IDs
@dataclass @dataclass
class Org: class Org:
uuid: UUID uuid: UUID
display_name: str display_name: str
permissions: list[Permission] # All that the Org can grant # All permission IDs that the Org is allowed to grant to its roles
roles: list[Role] permissions: list[str] = field(default_factory=list) # permission IDs
# Roles belonging to this org
roles: list[Role] = field(default_factory=list)
@dataclass @dataclass
@ -160,7 +163,7 @@ class DatabaseInterface(ABC):
@abstractmethod @abstractmethod
async def get_organization(self, org_id: str) -> Org: async def get_organization(self, org_id: str) -> Org:
"""Get organization by ID.""" """Get organization by ID, including its permission IDs and roles (with their permission IDs)."""
@abstractmethod @abstractmethod
async def update_organization(self, org: Org) -> None: async def update_organization(self, org: Org) -> None:
@ -239,6 +242,23 @@ class DatabaseInterface(ABC):
async def get_permission_organizations(self, permission_id: str) -> list[Org]: async def get_permission_organizations(self, permission_id: str) -> list[Org]:
"""Get all organizations that have a specific permission.""" """Get all organizations that have a specific permission."""
# Role-permission operations
@abstractmethod
async def add_permission_to_role(self, role_uuid: UUID, permission_id: str) -> None:
"""Add a permission to a role."""
@abstractmethod
async def remove_permission_from_role(self, role_uuid: UUID, permission_id: str) -> None:
"""Remove a permission from a role."""
@abstractmethod
async def get_role_permissions(self, role_uuid: UUID) -> list[Permission]:
"""List all permissions granted to a role."""
@abstractmethod
async def get_permission_roles(self, permission_id: str) -> list[Role]:
"""List all roles that grant a permission."""
# Combined operations # Combined operations
@abstractmethod @abstractmethod
async def login(self, user_uuid: UUID, credential: Credential) -> None: async def login(self, user_uuid: UUID, credential: Credential) -> None:
@ -261,6 +281,7 @@ __all__ = [
"Session", "Session",
"SessionContext", "SessionContext",
"Org", "Org",
"Role",
"Permission", "Permission",
"DatabaseInterface", "DatabaseInterface",
] ]

View File

@ -54,6 +54,7 @@ class OrgModel(Base):
display_name: Mapped[str] = mapped_column(String, nullable=False) display_name: Mapped[str] = mapped_column(String, nullable=False)
def as_dataclass(self): def as_dataclass(self):
# Base Org without permissions/roles (filled by data accessors)
return Org(UUID(bytes=self.uuid), self.display_name) return Org(UUID(bytes=self.uuid), self.display_name)
@staticmethod @staticmethod
@ -71,6 +72,7 @@ class RoleModel(Base):
display_name: Mapped[str] = mapped_column(String, nullable=False) display_name: Mapped[str] = mapped_column(String, nullable=False)
def as_dataclass(self): def as_dataclass(self):
# Base Role without permissions (filled by data accessors)
return Role( return Role(
uuid=UUID(bytes=self.uuid), uuid=UUID(bytes=self.uuid),
org_uuid=UUID(bytes=self.org_uuid), org_uuid=UUID(bytes=self.org_uuid),
@ -258,7 +260,17 @@ class DB(DatabaseInterface):
async def create_role(self, role: Role) -> None: async def create_role(self, role: Role) -> None:
async with self.session() as session: async with self.session() as session:
# Create role record
session.add(RoleModel.from_dataclass(role)) session.add(RoleModel.from_dataclass(role))
# Persist role permissions
if role.permissions:
for perm_id in role.permissions:
session.add(
RolePermission(
role_uuid=role.uuid.bytes,
permission_id=perm_id,
)
)
async def create_credential(self, credential: Credential) -> None: async def create_credential(self, credential: Credential) -> None:
async with self.session() as session: async with self.session() as session:
@ -429,6 +441,12 @@ class DB(DatabaseInterface):
display_name=org.display_name, display_name=org.display_name,
) )
session.add(org_model) session.add(org_model)
# Persist org permissions the org is allowed to grant
if org.permissions:
for perm_id in org.permissions:
session.add(
OrgPermission(org_uuid=org.uuid.bytes, permission_id=perm_id)
)
async def get_organization(self, org_id: str) -> Org: async def get_organization(self, org_id: str) -> Org:
async with self.session() as session: async with self.session() as session:
@ -441,7 +459,34 @@ class DB(DatabaseInterface):
if not org_model: if not org_model:
raise ValueError("Organization not found") raise ValueError("Organization not found")
return org_model.as_dataclass() # Build Org with permissions and roles
org_dc = org_model.as_dataclass()
# Load org permission IDs
perm_stmt = select(OrgPermission.permission_id).where(
OrgPermission.org_uuid == org_uuid.bytes
)
perm_result = await session.execute(perm_stmt)
org_dc.permissions = [row[0] for row in perm_result.fetchall()]
# Load roles for org
roles_stmt = select(RoleModel).where(RoleModel.org_uuid == org_uuid.bytes)
roles_result = await session.execute(roles_stmt)
roles_models = roles_result.scalars().all()
roles: list[Role] = []
if roles_models:
# For each role, load permission IDs
for r_model in roles_models:
r_dc = r_model.as_dataclass()
r_perm_stmt = select(RolePermission.permission_id).where(
RolePermission.role_uuid == r_model.uuid
)
r_perm_result = await session.execute(r_perm_stmt)
r_dc.permissions = [row[0] for row in r_perm_result.fetchall()]
roles.append(r_dc)
org_dc.roles = roles
return org_dc
async def update_organization(self, org: Org) -> None: async def update_organization(self, org: Org) -> None:
async with self.session() as session: async with self.session() as session:
@ -451,6 +496,19 @@ class DB(DatabaseInterface):
.values(display_name=org.display_name) .values(display_name=org.display_name)
) )
await session.execute(stmt) await session.execute(stmt)
# Synchronize org permissions join table to match org.permissions
# Delete existing rows for this org
await session.execute(
delete(OrgPermission).where(OrgPermission.org_uuid == org.uuid.bytes)
)
# Insert new rows
if org.permissions:
for perm_id in org.permissions:
await session.merge(
OrgPermission(
org_uuid=org.uuid.bytes, permission_id=perm_id
)
)
async def delete_organization(self, org_uuid: UUID) -> None: async def delete_organization(self, org_uuid: UUID) -> None:
async with self.session() as session: async with self.session() as session:
@ -459,9 +517,10 @@ class DB(DatabaseInterface):
await session.execute(stmt) await session.execute(stmt)
async def add_user_to_organization( async def add_user_to_organization(
self, user_uuid: UUID, org_uuid: UUID, role: str self, user_uuid: UUID, org_id: str, role: str
) -> None: ) -> None:
async with self.session() as session: async with self.session() as session:
org_uuid = UUID(org_id)
# Get user and organization models # Get user and organization models
user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes) user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
user_result = await session.execute(user_stmt) user_result = await session.execute(user_stmt)
@ -477,40 +536,32 @@ class DB(DatabaseInterface):
if not org_model: if not org_model:
raise ValueError("Organization not found") raise ValueError("Organization not found")
# Update the user's organization and role # Find the role within this organization by display_name
role_stmt = select(RoleModel).where(
RoleModel.org_uuid == org_uuid.bytes,
RoleModel.display_name == role,
)
role_result = await session.execute(role_stmt)
role_model = role_result.scalar_one_or_none()
if not role_model:
raise ValueError("Role not found in organization")
# Update the user's role assignment
stmt = ( stmt = (
update(UserModel) update(UserModel)
.where(UserModel.uuid == user_uuid.bytes) .where(UserModel.uuid == user_uuid.bytes)
.values(org_uuid=org_uuid.bytes, role=role) .values(role_uuid=role_model.uuid)
) )
await session.execute(stmt) await session.execute(stmt)
async def transfer_user_to_organization( async def transfer_user_to_organization(
self, user_uuid: UUID, new_org_id: str, new_role: str | None = None self, user_uuid: UUID, new_org_id: str, new_role: str | None = None
) -> None: ) -> None:
async with self.session() as session: # Users are members of an org that never changes after creation.
# Convert string ID to UUID bytes for lookup # Disallow transfers across organizations to enforce invariant.
new_org_uuid = UUID(new_org_id) raise ValueError("Users cannot be transferred to a different organization")
# Verify the new organization exists async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]:
org_stmt = select(OrgModel).where(OrgModel.uuid == new_org_uuid.bytes)
org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one_or_none()
if not org_model:
raise ValueError("Target organization not found")
# Update the user's organization and role
stmt = (
update(UserModel)
.where(UserModel.uuid == user_uuid.bytes)
.values(org_uuid=new_org_uuid.bytes, role=new_role)
)
result = await session.execute(stmt)
if result.rowcount == 0:
raise ValueError("User not found")
async def get_user_organization(self, user_uuid: UUID) -> Org:
async with self.session() as session: async with self.session() as session:
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes) stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
result = await session.execute(stmt) result = await session.execute(stmt)
@ -519,20 +570,31 @@ class DB(DatabaseInterface):
if not user_model: if not user_model:
raise ValueError("User not found") raise ValueError("User not found")
# Find user's role to get org
role_stmt = select(RoleModel).where(RoleModel.uuid == user_model.role_uuid)
role_result = await session.execute(role_stmt)
role_model = role_result.scalar_one()
# Fetch the organization details # Fetch the organization details
org_stmt = select(OrgModel).where(OrgModel.uuid == user_model.role_uuid) org_stmt = select(OrgModel).where(OrgModel.uuid == role_model.org_uuid)
org_result = await session.execute(org_stmt) org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one() org_model = org_result.scalar_one()
# Convert UUID bytes back to string for the interface # Convert UUID bytes back to string for the interface
return org_model.as_dataclass() return org_model.as_dataclass(), role_model.display_name
async def get_organization_users(self, org_id: str) -> list[User]: async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
async with self.session() as session: async with self.session() as session:
stmt = select(UserModel).where(UserModel.role_uuid == role.uuid.bytes) org_uuid = UUID(org_id)
# Join users with roles to filter by org and return role names
stmt = (
select(UserModel, RoleModel.display_name)
.join(RoleModel, UserModel.role_uuid == RoleModel.uuid)
.where(RoleModel.org_uuid == org_uuid.bytes)
)
result = await session.execute(stmt) result = await session.execute(stmt)
user_models = result.scalars().all() rows = result.fetchall()
return [u.as_dataclass() for u in user_models] return [(u.as_dataclass(), role_name) for (u, role_name) in rows]
async def get_user_role_in_organization( async def get_user_role_in_organization(
self, user_uuid: UUID, org_id: str self, user_uuid: UUID, org_id: str
@ -541,9 +603,14 @@ class DB(DatabaseInterface):
async with self.session() as session: async with self.session() as session:
# Convert string ID to UUID bytes for lookup # Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id) org_uuid = UUID(org_id)
stmt = select(UserModel.role).where( stmt = (
select(RoleModel.display_name)
.select_from(UserModel)
.join(RoleModel, UserModel.role_uuid == RoleModel.uuid)
.where(
UserModel.uuid == user_uuid.bytes, UserModel.uuid == user_uuid.bytes,
UserModel.org_uuid == org_uuid.bytes, RoleModel.org_uuid == org_uuid.bytes,
)
) )
result = await session.execute(stmt) result = await session.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
@ -553,14 +620,35 @@ class DB(DatabaseInterface):
) -> None: ) -> None:
"""Update a user's role in their organization.""" """Update a user's role in their organization."""
async with self.session() as session: async with self.session() as session:
# Find user's current org via their role
user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
user_result = await session.execute(user_stmt)
user_model = user_result.scalar_one_or_none()
if not user_model:
raise ValueError("User not found")
current_role_stmt = select(RoleModel).where(
RoleModel.uuid == user_model.role_uuid
)
current_role_result = await session.execute(current_role_stmt)
current_role = current_role_result.scalar_one()
# Find the new role within the same organization
role_stmt = select(RoleModel).where(
RoleModel.org_uuid == current_role.org_uuid,
RoleModel.display_name == new_role,
)
role_result = await session.execute(role_stmt)
role_model = role_result.scalar_one_or_none()
if not role_model:
raise ValueError("Role not found in user's organization")
stmt = ( stmt = (
update(UserModel) update(UserModel)
.where(UserModel.uuid == user_uuid.bytes) .where(UserModel.uuid == user_uuid.bytes)
.values(role=new_role) .values(role_uuid=role_model.uuid)
) )
result = await session.execute(stmt) await session.execute(stmt)
if result.rowcount == 0:
raise ValueError("User not found")
# Permission operations # Permission operations
async def create_permission(self, permission: Permission) -> None: async def create_permission(self, permission: Permission) -> None:
@ -598,6 +686,53 @@ class DB(DatabaseInterface):
stmt = delete(PermissionModel).where(PermissionModel.id == permission_id) stmt = delete(PermissionModel).where(PermissionModel.id == permission_id)
await session.execute(stmt) await session.execute(stmt)
async def add_permission_to_role(self, role_uuid: UUID, permission_id: str) -> None:
async with self.session() as session:
# Ensure role exists
role_stmt = select(RoleModel).where(RoleModel.uuid == role_uuid.bytes)
role_result = await session.execute(role_stmt)
role_model = role_result.scalar_one_or_none()
if not role_model:
raise ValueError("Role not found")
# Ensure permission exists
perm_stmt = select(PermissionModel).where(PermissionModel.id == permission_id)
perm_result = await session.execute(perm_stmt)
if not perm_result.scalar_one_or_none():
raise ValueError("Permission not found")
session.add(
RolePermission(role_uuid=role_uuid.bytes, permission_id=permission_id)
)
async def remove_permission_from_role(self, role_uuid: UUID, permission_id: str) -> None:
async with self.session() as session:
await session.execute(
delete(RolePermission)
.where(RolePermission.role_uuid == role_uuid.bytes)
.where(RolePermission.permission_id == permission_id)
)
async def get_role_permissions(self, role_uuid: UUID) -> list[Permission]:
async with self.session() as session:
stmt = (
select(PermissionModel)
.join(RolePermission, PermissionModel.id == RolePermission.permission_id)
.where(RolePermission.role_uuid == role_uuid.bytes)
)
result = await session.execute(stmt)
return [p.as_dataclass() for p in result.scalars().all()]
async def get_permission_roles(self, permission_id: str) -> list[Role]:
async with self.session() as session:
stmt = (
select(RoleModel)
.join(RolePermission, RoleModel.uuid == RolePermission.role_uuid)
.where(RolePermission.permission_id == permission_id)
)
result = await session.execute(stmt)
return [r.as_dataclass() for r in result.scalars().all()]
async def add_permission_to_organization( async def add_permission_to_organization(
self, org_id: str, permission_id: str self, org_id: str, permission_id: str
) -> None: ) -> None:
@ -679,9 +814,7 @@ class DB(DatabaseInterface):
) )
org_result = await session.execute(org_stmt) org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one() org_model = org_result.scalar_one()
organizations.append(org_model.as_dataclass())
# Convert UUID bytes back to string for the interface
organizations.append(org.as_dataclass())
return organizations return organizations
@ -697,21 +830,21 @@ class DB(DatabaseInterface):
Uses efficient JOINs to retrieve all related data in a single database query. Uses efficient JOINs to retrieve all related data in a single database query.
""" """
async with self.session() as session: async with self.session() as session:
# Build a query that joins sessions, users, organizations, org_permissions, and permissions # Build a query that joins sessions, users, roles, organizations, and role_permissions
stmt = ( stmt = (
select( select(
SessionModel, SessionModel,
UserModel, UserModel,
RoleModel,
OrgModel, OrgModel,
PermissionModel, PermissionModel,
) )
.select_from(SessionModel) .select_from(SessionModel)
.join(UserModel, SessionModel.user_uuid == UserModel.uuid) .join(UserModel, SessionModel.user_uuid == UserModel.uuid)
.join(OrgModel, UserModel.org_uuid == OrgModel.uuid) .join(RoleModel, UserModel.role_uuid == RoleModel.uuid)
.outerjoin(OrgPermission, OrgModel.uuid == OrgPermission.org_uuid) .join(OrgModel, RoleModel.org_uuid == OrgModel.uuid)
.outerjoin( .outerjoin(RolePermission, RoleModel.uuid == RolePermission.role_uuid)
PermissionModel, OrgPermission.permission_id == PermissionModel.id .outerjoin(PermissionModel, RolePermission.permission_id == PermissionModel.id)
)
.where(SessionModel.key == session_key) .where(SessionModel.key == session_key)
) )
@ -723,7 +856,7 @@ class DB(DatabaseInterface):
# Extract the first row to get session and user data # Extract the first row to get session and user data
first_row = rows[0] first_row = rows[0]
session_model, user_model, org_model, _ = first_row session_model, user_model, role_model, org_model, _ = first_row
# Create the session object # Create the session object
session_obj = Session( session_obj = Session(
@ -739,14 +872,21 @@ class DB(DatabaseInterface):
# Create the user object # Create the user object
user_obj = user_model.as_dataclass() user_obj = user_model.as_dataclass()
# Create organization object (always exists now) # Create organization object (fill permissions later if needed)
organization = Org(UUID(bytes=org_model.uuid), org_model.display_name) organization = Org(UUID(bytes=org_model.uuid), org_model.display_name)
# Collect all unique permissions # Create role object
role = Role(
uuid=UUID(bytes=role_model.uuid),
org_uuid=UUID(bytes=role_model.org_uuid),
display_name=role_model.display_name,
)
# Collect all unique permissions for the role
permissions = [] permissions = []
seen_permission_ids = set() seen_permission_ids = set()
for row in rows: for row in rows:
_, _, _, permission_model = row _, _, _, _, permission_model = row
if permission_model and permission_model.id not in seen_permission_ids: if permission_model and permission_model.id not in seen_permission_ids:
permissions.append( permissions.append(
Permission( Permission(
@ -756,10 +896,20 @@ class DB(DatabaseInterface):
) )
seen_permission_ids.add(permission_model.id) seen_permission_ids.add(permission_model.id)
# Attach permission IDs to role
role.permissions = list(seen_permission_ids)
# Load org permission IDs as well
org_perm_stmt = select(OrgPermission.permission_id).where(
OrgPermission.org_uuid == org_model.uuid
)
org_perm_result = await session.execute(org_perm_stmt)
organization.permissions = [row[0] for row in org_perm_result.fetchall()]
return SessionContext( return SessionContext(
session=session_obj, session=session_obj,
user=user_obj, user=user_obj,
org=organization, org=organization,
role=user_model.role, role=role,
permissions=permissions if permissions else None, permissions=permissions if permissions else None,
) )

View File

@ -41,6 +41,9 @@ def register_api_routes(app: FastAPI):
"""Get full user information for the authenticated user.""" """Get full user information for the authenticated user."""
reset = passphrase.is_well_formed(auth) reset = passphrase.is_well_formed(auth)
s = await (get_reset if reset else get_session)(auth) s = await (get_reset if reset else get_session)(auth)
# Session context (org, role, permissions)
ctx = await db.instance.get_session_context(session_key(auth))
# Fallback if context not available (e.g., reset session)
u = await db.instance.get_user_by_uuid(s.user_uuid) u = await db.instance.get_user_by_uuid(s.user_uuid)
# Get all credentials for the user # Get all credentials for the user
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid) credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
@ -78,6 +81,33 @@ def register_api_routes(app: FastAPI):
# Sort credentials by creation date (earliest first, most recently created last) # Sort credentials by creation date (earliest first, most recently created last)
credentials.sort(key=lambda cred: cred["created_at"]) credentials.sort(key=lambda cred: cred["created_at"])
# Permissions and roles
role_info = None
org_info = None
effective_permissions = []
is_global_admin = False
is_org_admin = False
if ctx:
role_info = {
"uuid": str(ctx.role.uuid),
"display_name": ctx.role.display_name,
"permissions": ctx.role.permissions, # IDs
}
org_info = {
"uuid": str(ctx.org.uuid),
"display_name": ctx.org.display_name,
"permissions": ctx.org.permissions, # IDs the org can grant
}
# Effective permissions are role permissions; API also returns full objects for convenience
effective_permissions = [p.id for p in (ctx.permissions or [])]
is_global_admin = "auth/admin" in role_info["permissions"]
# org admin permission is auth/org:<org_uuid>
is_org_admin = (
f"auth/org:{org_info['uuid']}" in role_info["permissions"]
if org_info
else False
)
return { return {
"authenticated": not reset, "authenticated": not reset,
"session_type": s.info["type"], "session_type": s.info["type"],
@ -88,6 +118,11 @@ def register_api_routes(app: FastAPI):
"last_seen": u.last_seen.isoformat() if u.last_seen else None, "last_seen": u.last_seen.isoformat() if u.last_seen else None,
"visits": u.visits, "visits": u.visits,
}, },
"org": org_info,
"role": role_info,
"permissions": effective_permissions,
"is_global_admin": is_global_admin,
"is_org_admin": is_org_admin,
"credentials": credentials, "credentials": credentials,
"aaguid_info": aaguid_info, "aaguid_info": aaguid_info,
} }