Users always belong to one Org. Implement a DB function to fetch all data relevant to a session.
This commit is contained in:
parent
2e3ce32779
commit
2e4ff30bea
@ -15,7 +15,7 @@ from uuid import UUID
|
||||
class User:
|
||||
uuid: UUID
|
||||
display_name: str
|
||||
org_uuid: UUID | None = None
|
||||
org_uuid: UUID
|
||||
role: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_seen: datetime | None = None
|
||||
@ -62,6 +62,17 @@ class Session:
|
||||
credential_uuid: UUID | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""Complete session context with user, organization, role, and permissions."""
|
||||
|
||||
session: Session
|
||||
user: User
|
||||
organization: Org
|
||||
role: str | None = None
|
||||
permissions: list[Permission] | None = None
|
||||
|
||||
|
||||
class DatabaseInterface(ABC):
|
||||
"""Abstract base class defining the database interface.
|
||||
|
||||
@ -160,11 +171,13 @@ class DatabaseInterface(ABC):
|
||||
"""Set a user's organization and role."""
|
||||
|
||||
@abstractmethod
|
||||
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
|
||||
"""Remove a user from their organization."""
|
||||
async def transfer_user_to_organization(
|
||||
self, user_uuid: UUID, new_org_id: str, new_role: str | None = None
|
||||
) -> None:
|
||||
"""Transfer a user to another organization with an optional role."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
|
||||
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]:
|
||||
"""Get the organization and role for a user."""
|
||||
|
||||
@abstractmethod
|
||||
@ -231,11 +244,16 @@ class DatabaseInterface(ABC):
|
||||
) -> None:
|
||||
"""Create a new user and their first credential in a transaction."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
|
||||
"""Get complete session context including user, organization, role, and permissions."""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"Credential",
|
||||
"Session",
|
||||
"SessionContext",
|
||||
"Org",
|
||||
"Permission",
|
||||
"DatabaseInterface",
|
||||
|
@ -24,7 +24,15 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from ..globals import db
|
||||
from . import Credential, DatabaseInterface, Org, Permission, Session, User
|
||||
from . import (
|
||||
Credential,
|
||||
DatabaseInterface,
|
||||
Org,
|
||||
Permission,
|
||||
Session,
|
||||
SessionContext,
|
||||
User,
|
||||
)
|
||||
|
||||
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
|
||||
|
||||
@ -55,20 +63,6 @@ class OrgPermission(Base):
|
||||
ForeignKey("permissions.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
"""Permissions each Org is allowed to grant to its roles."""
|
||||
|
||||
__tablename__ = "org_permissions"
|
||||
|
||||
org_uuid: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16),
|
||||
ForeignKey("orgs.uuid", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
permission_id: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
ForeignKey("permissions.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
|
||||
class PermissionModel(Base):
|
||||
@ -90,8 +84,8 @@ class UserModel(Base):
|
||||
|
||||
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||
display_name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
org_uuid: Mapped[bytes | None] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("orgs.uuid", ondelete="SET NULL"), nullable=True
|
||||
org_uuid: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("orgs.uuid", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
@ -164,9 +158,7 @@ class DB(DatabaseInterface):
|
||||
return User(
|
||||
uuid=UUID(bytes=user_model.uuid),
|
||||
display_name=user_model.display_name,
|
||||
org_uuid=UUID(bytes=user_model.org_uuid)
|
||||
if user_model.org_uuid
|
||||
else None,
|
||||
org_uuid=UUID(bytes=user_model.org_uuid),
|
||||
role=user_model.role,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
@ -179,7 +171,7 @@ class DB(DatabaseInterface):
|
||||
user_model = UserModel(
|
||||
uuid=user.uuid.bytes,
|
||||
display_name=user.display_name,
|
||||
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
|
||||
org_uuid=user.org_uuid.bytes,
|
||||
role=user.role,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
@ -280,7 +272,7 @@ class DB(DatabaseInterface):
|
||||
user_model = UserModel(
|
||||
uuid=user.uuid.bytes,
|
||||
display_name=user.display_name,
|
||||
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
|
||||
org_uuid=user.org_uuid.bytes,
|
||||
role=user.role,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
@ -432,24 +424,39 @@ class DB(DatabaseInterface):
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
|
||||
async def transfer_user_to_organization(
|
||||
self, user_uuid: UUID, new_org_id: str, new_role: str | None = None
|
||||
) -> None:
|
||||
async with self.session() as session:
|
||||
# Clear the user's organization and role
|
||||
# Convert string ID to UUID bytes for lookup
|
||||
new_org_uuid = UUID(new_org_id)
|
||||
|
||||
# Verify the new organization exists
|
||||
org_stmt = select(OrgModel).where(OrgModel.uuid == new_org_uuid.bytes)
|
||||
org_result = await session.execute(org_stmt)
|
||||
org_model = org_result.scalar_one_or_none()
|
||||
|
||||
if not org_model:
|
||||
raise ValueError("Target organization not found")
|
||||
|
||||
# Update the user's organization and role
|
||||
stmt = (
|
||||
update(UserModel)
|
||||
.where(UserModel.uuid == user_uuid.bytes)
|
||||
.values(org_uuid=None, role=None)
|
||||
.values(org_uuid=new_org_uuid.bytes, role=new_role)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
if result.rowcount == 0:
|
||||
raise ValueError("User not found")
|
||||
|
||||
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
|
||||
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str]:
|
||||
async with self.session() as session:
|
||||
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
|
||||
result = await session.execute(stmt)
|
||||
user_model = result.scalar_one_or_none()
|
||||
|
||||
if not user_model or not user_model.org_uuid:
|
||||
return None
|
||||
if not user_model:
|
||||
raise ValueError("User not found")
|
||||
|
||||
# Fetch the organization details
|
||||
org_stmt = select(OrgModel).where(OrgModel.uuid == user_model.org_uuid)
|
||||
@ -474,9 +481,7 @@ class DB(DatabaseInterface):
|
||||
user = User(
|
||||
uuid=UUID(bytes=user_model.uuid),
|
||||
display_name=user_model.display_name,
|
||||
org_uuid=UUID(bytes=user_model.org_uuid)
|
||||
if user_model.org_uuid
|
||||
else None,
|
||||
org_uuid=UUID(bytes=user_model.org_uuid),
|
||||
role=user_model.role,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
@ -643,3 +648,87 @@ class DB(DatabaseInterface):
|
||||
current_time = datetime.now()
|
||||
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
||||
await session.execute(stmt)
|
||||
|
||||
async def get_session_context(self, session_key: bytes) -> SessionContext | None:
|
||||
"""Get complete session context including user, organization, role, and permissions.
|
||||
|
||||
Uses efficient JOINs to retrieve all related data in a single database query.
|
||||
"""
|
||||
async with self.session() as session:
|
||||
# Build a query that joins sessions, users, organizations, org_permissions, and permissions
|
||||
stmt = (
|
||||
select(
|
||||
SessionModel,
|
||||
UserModel,
|
||||
OrgModel,
|
||||
PermissionModel,
|
||||
)
|
||||
.select_from(SessionModel)
|
||||
.join(UserModel, SessionModel.user_uuid == UserModel.uuid)
|
||||
.join(OrgModel, UserModel.org_uuid == OrgModel.uuid)
|
||||
.outerjoin(OrgPermission, OrgModel.uuid == OrgPermission.org_uuid)
|
||||
.outerjoin(
|
||||
PermissionModel, OrgPermission.permission_id == PermissionModel.id
|
||||
)
|
||||
.where(SessionModel.key == session_key)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
rows = result.fetchall()
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
# Extract the first row to get session and user data
|
||||
first_row = rows[0]
|
||||
session_model, user_model, org_model, _ = first_row
|
||||
|
||||
# Create the session object
|
||||
session_obj = Session(
|
||||
key=session_model.key,
|
||||
user_uuid=UUID(bytes=session_model.user_uuid),
|
||||
credential_uuid=UUID(bytes=session_model.credential_uuid)
|
||||
if session_model.credential_uuid
|
||||
else None,
|
||||
expires=session_model.expires,
|
||||
info=session_model.info or {},
|
||||
)
|
||||
|
||||
# Create the user object
|
||||
user_obj = User(
|
||||
uuid=UUID(bytes=user_model.uuid),
|
||||
display_name=user_model.display_name,
|
||||
org_uuid=UUID(bytes=user_model.org_uuid),
|
||||
role=user_model.role,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
visits=user_model.visits,
|
||||
)
|
||||
|
||||
# Create organization object (always exists now)
|
||||
organization = Org(
|
||||
id=str(UUID(bytes=org_model.uuid)),
|
||||
options=org_model.options,
|
||||
)
|
||||
|
||||
# Collect all unique permissions
|
||||
permissions = []
|
||||
seen_permission_ids = set()
|
||||
for row in rows:
|
||||
_, _, _, permission_model = row
|
||||
if permission_model and permission_model.id not in seen_permission_ids:
|
||||
permissions.append(
|
||||
Permission(
|
||||
id=permission_model.id,
|
||||
display_name=permission_model.display_name,
|
||||
)
|
||||
)
|
||||
seen_permission_ids.add(permission_model.id)
|
||||
|
||||
return SessionContext(
|
||||
session=session_obj,
|
||||
user=user_obj,
|
||||
organization=organization,
|
||||
role=user_model.role,
|
||||
permissions=permissions if permissions else None,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user