Checkpoint, fixing reset token handling broken in earlier edits.
This commit is contained in:
parent
c42864794a
commit
cf138d90c5
@ -2,6 +2,7 @@
|
|||||||
<div class="container">
|
<div class="container">
|
||||||
<div class="view active">
|
<div class="view active">
|
||||||
<h1>🔑 Add Device Credential</h1>
|
<h1>🔑 Add Device Credential</h1>
|
||||||
|
<h3>👤 {{ authStore.currentUser.user_name }}</h3>
|
||||||
<button
|
<button
|
||||||
class="btn-primary"
|
class="btn-primary"
|
||||||
:disabled="authStore.isLoading"
|
:disabled="authStore.isLoading"
|
||||||
@ -15,13 +16,13 @@
|
|||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import { useAuthStore } from '@/stores/auth'
|
import { useAuthStore } from '@/stores/auth'
|
||||||
|
import { computed } from 'vue'
|
||||||
import { registerCredential } from '@/utils/passkey'
|
import { registerCredential } from '@/utils/passkey'
|
||||||
import { ref, computed } from 'vue'
|
|
||||||
|
|
||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
const hasDeviceSession = computed(() => !!authStore.currentUser)
|
const hasDeviceSession = computed(() => !!authStore.currentUser)
|
||||||
|
|
||||||
function register() {
|
async function register() {
|
||||||
if (!hasDeviceSession.value) {
|
if (!hasDeviceSession.value) {
|
||||||
authStore.showMessage('No valid device addition session', 'error')
|
authStore.showMessage('No valid device addition session', 'error')
|
||||||
return
|
return
|
||||||
@ -29,18 +30,22 @@ function register() {
|
|||||||
|
|
||||||
authStore.isLoading = true
|
authStore.isLoading = true
|
||||||
authStore.showMessage('Starting registration...', 'info')
|
authStore.showMessage('Starting registration...', 'info')
|
||||||
registerCredential().finally(() => {
|
|
||||||
authStore.isLoading = false
|
try {
|
||||||
}).then(() => {
|
const result = await registerCredential()
|
||||||
|
console.log("Result", result)
|
||||||
|
await authStore.setSessionCookie(result.session_token)
|
||||||
authStore.showMessage('Passkey registered successfully!', 'success', 2000)
|
authStore.showMessage('Passkey registered successfully!', 'success', 2000)
|
||||||
authStore.currentView = 'profile'
|
authStore.currentView = 'profile'
|
||||||
}).catch((error) => {
|
} catch (error) {
|
||||||
console.error('Registration error:', error)
|
console.error('Registration error:', error)
|
||||||
if (error.name === "NotAllowedError") {
|
if (error.name === "NotAllowedError") {
|
||||||
authStore.showMessage('Registration cancelled', 'error')
|
authStore.showMessage('Registration cancelled', 'error')
|
||||||
} else {
|
} else {
|
||||||
authStore.showMessage(`Registration failed: ${error.message}`, 'error')
|
authStore.showMessage(`Registration failed: ${error.message}`, 'error')
|
||||||
}
|
}
|
||||||
})
|
} finally {
|
||||||
|
authStore.isLoading = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
@ -46,13 +46,12 @@ export const useAuthStore = defineStore('auth', {
|
|||||||
try {
|
try {
|
||||||
const result = await registerUser(user_name)
|
const result = await registerUser(user_name)
|
||||||
|
|
||||||
await this.setSessionCookie(result.session_token)
|
|
||||||
|
|
||||||
this.currentUser = {
|
this.currentUser = {
|
||||||
user_id: result.user_id,
|
user_id: result.user_id,
|
||||||
user_name: user_name,
|
user_name: user_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await this.setSessionCookie(result.session_token)
|
||||||
return result
|
return result
|
||||||
} finally {
|
} finally {
|
||||||
this.isLoading = false
|
this.isLoading = false
|
||||||
|
@ -8,7 +8,6 @@ export async function register(url, options) {
|
|||||||
const optionsJSON = await ws.receive_json()
|
const optionsJSON = await ws.receive_json()
|
||||||
const registrationResponse = await startRegistration({ optionsJSON })
|
const registrationResponse = await startRegistration({ optionsJSON })
|
||||||
ws.send_json(registrationResponse)
|
ws.send_json(registrationResponse)
|
||||||
|
|
||||||
const result = await ws.receive_json()
|
const result = await ws.receive_json()
|
||||||
ws.close()
|
ws.close()
|
||||||
return result;
|
return result;
|
||||||
@ -31,7 +30,6 @@ export async function authenticateUser() {
|
|||||||
const optionsJSON = await ws.receive_json()
|
const optionsJSON = await ws.receive_json()
|
||||||
const authResponse = await startAuthentication({ optionsJSON })
|
const authResponse = await startAuthentication({ optionsJSON })
|
||||||
ws.send_json(authResponse)
|
ws.send_json(authResponse)
|
||||||
|
|
||||||
const result = await ws.receive_json()
|
const result = await ws.receive_json()
|
||||||
ws.close()
|
ws.close()
|
||||||
return result
|
return result
|
||||||
|
@ -12,7 +12,6 @@ from datetime import datetime, timedelta
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from .db import Session, db
|
from .db import Session, db
|
||||||
from .util import passphrase
|
|
||||||
from .util.tokens import create_token, reset_key, session_key
|
from .util.tokens import create_token, reset_key, session_key
|
||||||
|
|
||||||
EXPIRES = timedelta(hours=24)
|
EXPIRES = timedelta(hours=24)
|
||||||
@ -22,29 +21,30 @@ def expires() -> datetime:
|
|||||||
return datetime.now() + EXPIRES
|
return datetime.now() + EXPIRES
|
||||||
|
|
||||||
|
|
||||||
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
|
async def create_session(user_uuid: UUID, credential_uuid: UUID, info: dict) -> str:
|
||||||
"""Create a new session and return a session token."""
|
"""Create a new session and return a session token."""
|
||||||
token = create_token()
|
token = create_token()
|
||||||
await db.instance.create_session(
|
await db.instance.create_session(
|
||||||
user_uuid=user_uuid,
|
user_uuid=user_uuid,
|
||||||
|
credential_uuid=credential_uuid,
|
||||||
key=session_key(token),
|
key=session_key(token),
|
||||||
expires=datetime.now() + EXPIRES,
|
expires=datetime.now() + EXPIRES,
|
||||||
info=info,
|
info=info,
|
||||||
credential_uuid=credential_uuid,
|
|
||||||
)
|
)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
async def get_session(token: str, reset_allowed=False) -> Session:
|
async def get_reset(token: str) -> Session:
|
||||||
"""Validate a session token and return session data if valid."""
|
"""Validate a credential reset token. Returns None if the token is not well formed (i.e. it is another type of token)."""
|
||||||
if passphrase.is_well_formed(token):
|
session = await db.instance.get_session(reset_key(token))
|
||||||
if not reset_allowed:
|
if not session:
|
||||||
raise ValueError("Reset link is not allowed for this endpoint")
|
raise ValueError("Invalid or expired session token")
|
||||||
key = reset_key(token)
|
return session
|
||||||
else:
|
|
||||||
key = session_key(token)
|
|
||||||
|
|
||||||
session = await db.instance.get_session(key)
|
|
||||||
|
async def get_session(token: str) -> Session:
|
||||||
|
"""Validate a session token and return session data if valid."""
|
||||||
|
session = await db.instance.get_session(session_key(token))
|
||||||
if not session:
|
if not session:
|
||||||
raise ValueError("Invalid or expired session token")
|
raise ValueError("Invalid or expired session token")
|
||||||
return session
|
return session
|
||||||
|
@ -13,8 +13,10 @@ from uuid import UUID
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class User:
|
class User:
|
||||||
user_uuid: UUID
|
uuid: UUID
|
||||||
user_name: str
|
display_name: str
|
||||||
|
org_uuid: UUID | None = None
|
||||||
|
role: str | None = None
|
||||||
created_at: datetime | None = None
|
created_at: datetime | None = None
|
||||||
last_seen: datetime | None = None
|
last_seen: datetime | None = None
|
||||||
visits: int = 0
|
visits: int = 0
|
||||||
@ -41,6 +43,14 @@ class Org:
|
|||||||
options: dict
|
options: dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Permission:
|
||||||
|
"""Permission data structure."""
|
||||||
|
|
||||||
|
id: str # String primary key (max 32 chars)
|
||||||
|
display_name: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Session:
|
class Session:
|
||||||
"""Session data structure."""
|
"""Session data structure."""
|
||||||
@ -68,7 +78,7 @@ class DatabaseInterface(ABC):
|
|||||||
|
|
||||||
# User operations
|
# User operations
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
async def get_user_by_uuid(self, user_uuid: UUID) -> User:
|
||||||
"""Get user record by WebAuthn user UUID."""
|
"""Get user record by WebAuthn user UUID."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -147,20 +157,69 @@ class DatabaseInterface(ABC):
|
|||||||
async def add_user_to_organization(
|
async def add_user_to_organization(
|
||||||
self, user_uuid: UUID, org_id: str, role: str
|
self, user_uuid: UUID, org_id: str, role: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a user to an organization with a specific role."""
|
"""Set a user's organization and role."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def remove_user_from_organization(self, user_uuid: UUID, org_id: str) -> None:
|
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
|
||||||
"""Remove a user from an organization."""
|
"""Remove a user from their organization."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_org_role(self, user_uuid: UUID) -> list[tuple[Org, str]]:
|
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
|
||||||
"""Get all organizations for a user with their roles."""
|
"""Get the organization and role for a user."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
|
async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
|
||||||
"""Get all users in an organization with their roles."""
|
"""Get all users in an organization with their roles."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user_role_in_organization(
|
||||||
|
self, user_uuid: UUID, org_id: str
|
||||||
|
) -> str | None:
|
||||||
|
"""Get a user's role in a specific organization."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_user_role_in_organization(
|
||||||
|
self, user_uuid: UUID, new_role: str
|
||||||
|
) -> None:
|
||||||
|
"""Update a user's role in their organization."""
|
||||||
|
|
||||||
|
# Permission operations
|
||||||
|
@abstractmethod
|
||||||
|
async def create_permission(self, permission: Permission) -> None:
|
||||||
|
"""Create a new permission."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_permission(self, permission_id: str) -> Permission:
|
||||||
|
"""Get permission by ID."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_permission(self, permission: Permission) -> None:
|
||||||
|
"""Update permission details."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_permission(self, permission_id: str) -> None:
|
||||||
|
"""Delete permission by ID."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def add_permission_to_organization(
|
||||||
|
self, org_id: str, permission_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Add a permission to an organization."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def remove_permission_from_organization(
|
||||||
|
self, org_id: str, permission_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Remove a permission from an organization."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_organization_permissions(self, org_id: str) -> list[Permission]:
|
||||||
|
"""Get all permissions assigned to an organization."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_permission_organizations(self, permission_id: str) -> list[Org]:
|
||||||
|
"""Get all organizations that have a specific permission."""
|
||||||
|
|
||||||
# Combined operations
|
# Combined operations
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
||||||
@ -200,6 +259,7 @@ __all__ = [
|
|||||||
"Credential",
|
"Credential",
|
||||||
"Session",
|
"Session",
|
||||||
"Org",
|
"Org",
|
||||||
|
"Permission",
|
||||||
"DatabaseInterface",
|
"DatabaseInterface",
|
||||||
"db",
|
"db",
|
||||||
]
|
]
|
||||||
|
@ -21,9 +21,9 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.dialects.sqlite import BLOB, JSON
|
from sqlalchemy.dialects.sqlite import BLOB, JSON
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
from . import Credential, DatabaseInterface, Org, Session, User, db
|
from . import Credential, DatabaseInterface, Org, Permission, Session, User, db
|
||||||
|
|
||||||
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
|
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
|
||||||
|
|
||||||
@ -38,25 +38,43 @@ class Base(DeclarativeBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Association model for many-to-many relationship between users and organizations with roles
|
# Association model for many-to-many relationship between organizations and permissions
|
||||||
class UserRole(Base):
|
class OrgPermission(Base):
|
||||||
__tablename__ = "user_roles"
|
"""Permissions each Org is allowed to grant to its roles."""
|
||||||
|
|
||||||
|
__tablename__ = "org_permissions"
|
||||||
|
|
||||||
user_uuid: Mapped[bytes] = mapped_column(
|
|
||||||
LargeBinary(16),
|
|
||||||
ForeignKey("users.user_uuid", ondelete="CASCADE"),
|
|
||||||
primary_key=True,
|
|
||||||
)
|
|
||||||
org_uuid: Mapped[bytes] = mapped_column(
|
org_uuid: Mapped[bytes] = mapped_column(
|
||||||
LargeBinary(16),
|
LargeBinary(16),
|
||||||
ForeignKey("orgs.uuid", ondelete="CASCADE"),
|
ForeignKey("orgs.uuid", ondelete="CASCADE"),
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
)
|
)
|
||||||
role: Mapped[str] = mapped_column(String, nullable=False)
|
permission_id: Mapped[str] = mapped_column(
|
||||||
|
String(32),
|
||||||
|
ForeignKey("permissions.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
)
|
||||||
|
"""Permissions each Org is allowed to grant to its roles."""
|
||||||
|
|
||||||
# Relationships to the actual models
|
__tablename__ = "org_permissions"
|
||||||
user: Mapped["UserModel"] = relationship("UserModel", back_populates="user_orgs")
|
|
||||||
org: Mapped["OrgModel"] = relationship("OrgModel", back_populates="user_orgs")
|
org_uuid: Mapped[bytes] = mapped_column(
|
||||||
|
LargeBinary(16),
|
||||||
|
ForeignKey("orgs.uuid", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
)
|
||||||
|
permission_id: Mapped[str] = mapped_column(
|
||||||
|
String(32),
|
||||||
|
ForeignKey("permissions.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionModel(Base):
|
||||||
|
__tablename__ = "permissions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||||
|
display_name: Mapped[str] = mapped_column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class OrgModel(Base):
|
class OrgModel(Base):
|
||||||
@ -65,31 +83,20 @@ class OrgModel(Base):
|
|||||||
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
options: Mapped[dict] = mapped_column(JSON, nullable=False, default=dict)
|
options: Mapped[dict] = mapped_column(JSON, nullable=False, default=dict)
|
||||||
|
|
||||||
# Relationship to user-org associations
|
|
||||||
user_orgs: Mapped[list["UserRole"]] = relationship(
|
|
||||||
"UserRoleModel", back_populates="org", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserModel(Base):
|
class UserModel(Base):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
user_uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
user_name: Mapped[str] = mapped_column(String, nullable=False)
|
display_name: Mapped[str] = mapped_column(String, nullable=False)
|
||||||
|
org_uuid: Mapped[bytes | None] = mapped_column(
|
||||||
|
LargeBinary(16), ForeignKey("orgs.uuid", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||||
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
# Relationship to credentials
|
|
||||||
credentials: Mapped[list["CredentialModel"]] = relationship(
|
|
||||||
"CredentialModel", back_populates="user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationship to user-org associations
|
|
||||||
user_orgs: Mapped[list["UserRole"]] = relationship(
|
|
||||||
"UserOrgModel", back_populates="user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CredentialModel(Base):
|
class CredentialModel(Base):
|
||||||
__tablename__ = "credentials"
|
__tablename__ = "credentials"
|
||||||
@ -98,7 +105,7 @@ class CredentialModel(Base):
|
|||||||
LargeBinary(64), unique=True, index=True
|
LargeBinary(64), unique=True, index=True
|
||||||
)
|
)
|
||||||
user_uuid: Mapped[bytes] = mapped_column(
|
user_uuid: Mapped[bytes] = mapped_column(
|
||||||
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
|
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
||||||
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
||||||
@ -107,16 +114,13 @@ class CredentialModel(Base):
|
|||||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
|
|
||||||
# Relationship to user
|
|
||||||
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
|
|
||||||
|
|
||||||
|
|
||||||
class SessionModel(Base):
|
class SessionModel(Base):
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions"
|
||||||
|
|
||||||
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
user_uuid: Mapped[bytes] = mapped_column(
|
user_uuid: Mapped[bytes] = mapped_column(
|
||||||
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
|
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
credential_uuid: Mapped[bytes | None] = mapped_column(
|
credential_uuid: Mapped[bytes | None] = mapped_column(
|
||||||
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
|
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
|
||||||
@ -124,9 +128,6 @@ class SessionModel(Base):
|
|||||||
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||||
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
|
|
||||||
# Relationship to user
|
|
||||||
user: Mapped["UserModel"] = relationship("UserModel")
|
|
||||||
|
|
||||||
|
|
||||||
class DB(DatabaseInterface):
|
class DB(DatabaseInterface):
|
||||||
"""Database class that handles its own connections."""
|
"""Database class that handles its own connections."""
|
||||||
@ -152,16 +153,20 @@ class DB(DatabaseInterface):
|
|||||||
async with self.engine.begin() as conn:
|
async with self.engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
async def get_user_by_uuid(self, user_uuid: UUID) -> User:
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
user_model = result.scalar_one_or_none()
|
user_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
if user_model:
|
if user_model:
|
||||||
return User(
|
return User(
|
||||||
user_uuid=UUID(bytes=user_model.user_uuid),
|
uuid=UUID(bytes=user_model.uuid),
|
||||||
user_name=user_model.user_name,
|
display_name=user_model.display_name,
|
||||||
|
org_uuid=UUID(bytes=user_model.org_uuid)
|
||||||
|
if user_model.org_uuid
|
||||||
|
else None,
|
||||||
|
role=user_model.role,
|
||||||
created_at=user_model.created_at,
|
created_at=user_model.created_at,
|
||||||
last_seen=user_model.last_seen,
|
last_seen=user_model.last_seen,
|
||||||
visits=user_model.visits,
|
visits=user_model.visits,
|
||||||
@ -171,8 +176,10 @@ class DB(DatabaseInterface):
|
|||||||
async def create_user(self, user: User) -> None:
|
async def create_user(self, user: User) -> None:
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
user_model = UserModel(
|
user_model = UserModel(
|
||||||
user_uuid=user.user_uuid.bytes,
|
uuid=user.uuid.bytes,
|
||||||
user_name=user.user_name,
|
display_name=user.display_name,
|
||||||
|
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
|
||||||
|
role=user.role,
|
||||||
created_at=user.created_at or datetime.now(),
|
created_at=user.created_at or datetime.now(),
|
||||||
last_seen=user.last_seen,
|
last_seen=user.last_seen,
|
||||||
visits=user.visits,
|
visits=user.visits,
|
||||||
@ -256,7 +263,7 @@ class DB(DatabaseInterface):
|
|||||||
# Update user's last_seen and increment visits
|
# Update user's last_seen and increment visits
|
||||||
stmt = (
|
stmt = (
|
||||||
update(UserModel)
|
update(UserModel)
|
||||||
.where(UserModel.user_uuid == user_uuid.bytes)
|
.where(UserModel.uuid == user_uuid.bytes)
|
||||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
||||||
)
|
)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
@ -270,8 +277,10 @@ class DB(DatabaseInterface):
|
|||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
user_model = UserModel(
|
user_model = UserModel(
|
||||||
user_uuid=user.user_uuid.bytes,
|
uuid=user.uuid.bytes,
|
||||||
user_name=user.user_name,
|
display_name=user.display_name,
|
||||||
|
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
|
||||||
|
role=user.role,
|
||||||
created_at=user.created_at or datetime.now(),
|
created_at=user.created_at or datetime.now(),
|
||||||
last_seen=user.last_seen,
|
last_seen=user.last_seen,
|
||||||
visits=user.visits,
|
visits=user.visits,
|
||||||
@ -399,7 +408,7 @@ class DB(DatabaseInterface):
|
|||||||
) -> None:
|
) -> None:
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
# Get user and organization models
|
# Get user and organization models
|
||||||
user_stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
|
||||||
user_result = await session.execute(user_stmt)
|
user_result = await session.execute(user_stmt)
|
||||||
user_model = user_result.scalar_one_or_none()
|
user_model = user_result.scalar_one_or_none()
|
||||||
|
|
||||||
@ -414,67 +423,65 @@ class DB(DatabaseInterface):
|
|||||||
if not org_model:
|
if not org_model:
|
||||||
raise ValueError("Organization not found")
|
raise ValueError("Organization not found")
|
||||||
|
|
||||||
# Create the user-org relationship with role
|
# Update the user's organization and role
|
||||||
user_org = UserRole(
|
stmt = (
|
||||||
user_uuid=user_uuid.bytes, org_uuid=org_uuid.bytes, role=role
|
update(UserModel)
|
||||||
)
|
.where(UserModel.uuid == user_uuid.bytes)
|
||||||
session.add(user_org)
|
.values(org_uuid=org_uuid.bytes, role=role)
|
||||||
|
|
||||||
async def remove_user_from_organization(self, user_uuid: UUID, org_id: str) -> None:
|
|
||||||
async with self.session() as session:
|
|
||||||
# Convert string ID to UUID bytes for lookup
|
|
||||||
org_uuid = UUID(org_id)
|
|
||||||
# Delete the user-org relationship
|
|
||||||
stmt = delete(UserRole).where(
|
|
||||||
UserRole.user_uuid == user_uuid.bytes,
|
|
||||||
UserRole.org_uuid == org_uuid.bytes,
|
|
||||||
)
|
)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
|
||||||
async def get_user_org_role(self, user_uuid: UUID) -> list[tuple[Org, str]]:
|
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
stmt = select(UserRole).where(UserRole.user_uuid == user_uuid.bytes)
|
# Clear the user's organization and role
|
||||||
|
stmt = (
|
||||||
|
update(UserModel)
|
||||||
|
.where(UserModel.uuid == user_uuid.bytes)
|
||||||
|
.values(org_uuid=None, role=None)
|
||||||
|
)
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
|
||||||
|
async with self.session() as session:
|
||||||
|
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
user_org_models = result.scalars().all()
|
user_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
# Fetch the organization details for each user-org relationship
|
if not user_model or not user_model.org_uuid:
|
||||||
org_role_pairs = []
|
return None
|
||||||
for user_org in user_org_models:
|
|
||||||
org_stmt = select(OrgModel).where(OrgModel.uuid == user_org.org_uuid)
|
|
||||||
org_result = await session.execute(org_stmt)
|
|
||||||
org_model = org_result.scalar_one()
|
|
||||||
|
|
||||||
# Convert UUID bytes back to string for the interface
|
# Fetch the organization details
|
||||||
org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options)
|
org_stmt = select(OrgModel).where(OrgModel.uuid == user_model.org_uuid)
|
||||||
org_role_pairs.append((org, user_org.role))
|
org_result = await session.execute(org_stmt)
|
||||||
|
org_model = org_result.scalar_one()
|
||||||
|
|
||||||
return org_role_pairs
|
# Convert UUID bytes back to string for the interface
|
||||||
|
org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options)
|
||||||
|
return (org, user_model.role or "")
|
||||||
|
|
||||||
async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
|
async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
# Convert string ID to UUID bytes for lookup
|
# Convert string ID to UUID bytes for lookup
|
||||||
org_uuid = UUID(org_id)
|
org_uuid = UUID(org_id)
|
||||||
stmt = select(UserRole).where(UserRole.org_uuid == org_uuid.bytes)
|
stmt = select(UserModel).where(UserModel.org_uuid == org_uuid.bytes)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
user_org_models = result.scalars().all()
|
user_models = result.scalars().all()
|
||||||
|
|
||||||
# Fetch the user details for each user-org relationship
|
# Create user objects with their roles
|
||||||
user_role_pairs = []
|
user_role_pairs = []
|
||||||
for user_org in user_org_models:
|
for user_model in user_models:
|
||||||
user_stmt = select(UserModel).where(
|
|
||||||
UserModel.user_uuid == user_org.user_uuid
|
|
||||||
)
|
|
||||||
user_result = await session.execute(user_stmt)
|
|
||||||
user_model = user_result.scalar_one()
|
|
||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
user_uuid=UUID(bytes=user_model.user_uuid),
|
uuid=UUID(bytes=user_model.uuid),
|
||||||
user_name=user_model.user_name,
|
display_name=user_model.display_name,
|
||||||
|
org_uuid=UUID(bytes=user_model.org_uuid)
|
||||||
|
if user_model.org_uuid
|
||||||
|
else None,
|
||||||
|
role=user_model.role,
|
||||||
created_at=user_model.created_at,
|
created_at=user_model.created_at,
|
||||||
last_seen=user_model.last_seen,
|
last_seen=user_model.last_seen,
|
||||||
visits=user_model.visits,
|
visits=user_model.visits,
|
||||||
)
|
)
|
||||||
user_role_pairs.append((user, user_org.role))
|
user_role_pairs.append((user, user_model.role or ""))
|
||||||
|
|
||||||
return user_role_pairs
|
return user_role_pairs
|
||||||
|
|
||||||
@ -485,31 +492,150 @@ class DB(DatabaseInterface):
|
|||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
# Convert string ID to UUID bytes for lookup
|
# Convert string ID to UUID bytes for lookup
|
||||||
org_uuid = UUID(org_id)
|
org_uuid = UUID(org_id)
|
||||||
stmt = select(UserRole.role).where(
|
stmt = select(UserModel.role).where(
|
||||||
UserRole.user_uuid == user_uuid.bytes,
|
UserModel.uuid == user_uuid.bytes,
|
||||||
UserRole.org_uuid == org_uuid.bytes,
|
UserModel.org_uuid == org_uuid.bytes,
|
||||||
)
|
)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
async def update_user_role_in_organization(
|
async def update_user_role_in_organization(
|
||||||
self, user_uuid: UUID, org_id: str, new_role: str
|
self, user_uuid: UUID, new_role: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a user's role in an organization."""
|
"""Update a user's role in their organization."""
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
# Convert string ID to UUID bytes for lookup
|
|
||||||
org_uuid = UUID(org_id)
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(UserRole)
|
update(UserModel)
|
||||||
.where(
|
.where(UserModel.uuid == user_uuid.bytes)
|
||||||
UserRole.user_uuid == user_uuid.bytes,
|
|
||||||
UserRole.org_uuid == org_uuid.bytes,
|
|
||||||
)
|
|
||||||
.values(role=new_role)
|
.values(role=new_role)
|
||||||
)
|
)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
if result.rowcount == 0:
|
if result.rowcount == 0:
|
||||||
raise ValueError("User is not a member of this organization")
|
raise ValueError("User not found")
|
||||||
|
|
||||||
|
# Permission operations
|
||||||
|
async def create_permission(self, permission: Permission) -> None:
|
||||||
|
async with self.session() as session:
|
||||||
|
permission_model = PermissionModel(
|
||||||
|
id=permission.id,
|
||||||
|
display_name=permission.display_name,
|
||||||
|
)
|
||||||
|
session.add(permission_model)
|
||||||
|
|
||||||
|
async def get_permission(self, permission_id: str) -> Permission:
|
||||||
|
async with self.session() as session:
|
||||||
|
stmt = select(PermissionModel).where(PermissionModel.id == permission_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
permission_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if permission_model:
|
||||||
|
return Permission(
|
||||||
|
id=permission_model.id,
|
||||||
|
display_name=permission_model.display_name,
|
||||||
|
)
|
||||||
|
raise ValueError("Permission not found")
|
||||||
|
|
||||||
|
async def update_permission(self, permission: Permission) -> None:
|
||||||
|
async with self.session() as session:
|
||||||
|
stmt = (
|
||||||
|
update(PermissionModel)
|
||||||
|
.where(PermissionModel.id == permission.id)
|
||||||
|
.values(display_name=permission.display_name)
|
||||||
|
)
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
async def delete_permission(self, permission_id: str) -> None:
|
||||||
|
async with self.session() as session:
|
||||||
|
stmt = delete(PermissionModel).where(PermissionModel.id == permission_id)
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
async def add_permission_to_organization(
|
||||||
|
self, org_id: str, permission_id: str
|
||||||
|
) -> None:
|
||||||
|
async with self.session() as session:
|
||||||
|
# Get organization and permission models
|
||||||
|
org_uuid = UUID(org_id)
|
||||||
|
org_stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes)
|
||||||
|
org_result = await session.execute(org_stmt)
|
||||||
|
org_model = org_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
permission_stmt = select(PermissionModel).where(
|
||||||
|
PermissionModel.id == permission_id
|
||||||
|
)
|
||||||
|
permission_result = await session.execute(permission_stmt)
|
||||||
|
permission_model = permission_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not org_model:
|
||||||
|
raise ValueError("Organization not found")
|
||||||
|
if not permission_model:
|
||||||
|
raise ValueError("Permission not found")
|
||||||
|
|
||||||
|
# Create the org-permission relationship
|
||||||
|
org_permission = OrgPermission(
|
||||||
|
org_uuid=org_uuid.bytes, permission_id=permission_id
|
||||||
|
)
|
||||||
|
session.add(org_permission)
|
||||||
|
|
||||||
|
async def remove_permission_from_organization(
|
||||||
|
self, org_id: str, permission_id: str
|
||||||
|
) -> None:
|
||||||
|
async with self.session() as session:
|
||||||
|
# Convert string ID to UUID bytes for lookup
|
||||||
|
org_uuid = UUID(org_id)
|
||||||
|
# Delete the org-permission relationship
|
||||||
|
stmt = delete(OrgPermission).where(
|
||||||
|
OrgPermission.org_uuid == org_uuid.bytes,
|
||||||
|
OrgPermission.permission_id == permission_id,
|
||||||
|
)
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
async def get_organization_permissions(self, org_id: str) -> list[Permission]:
|
||||||
|
async with self.session() as session:
|
||||||
|
# Convert string ID to UUID bytes for lookup
|
||||||
|
org_uuid = UUID(org_id)
|
||||||
|
stmt = select(OrgPermission).where(OrgPermission.org_uuid == org_uuid.bytes)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
org_permission_models = result.scalars().all()
|
||||||
|
|
||||||
|
# Fetch the permission details for each org-permission relationship
|
||||||
|
permissions = []
|
||||||
|
for org_permission in org_permission_models:
|
||||||
|
permission_stmt = select(PermissionModel).where(
|
||||||
|
PermissionModel.id == org_permission.permission_id
|
||||||
|
)
|
||||||
|
permission_result = await session.execute(permission_stmt)
|
||||||
|
permission_model = permission_result.scalar_one()
|
||||||
|
|
||||||
|
permission = Permission(
|
||||||
|
id=permission_model.id,
|
||||||
|
display_name=permission_model.display_name,
|
||||||
|
)
|
||||||
|
permissions.append(permission)
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
|
||||||
|
async def get_permission_organizations(self, permission_id: str) -> list[Org]:
|
||||||
|
async with self.session() as session:
|
||||||
|
stmt = select(OrgPermission).where(
|
||||||
|
OrgPermission.permission_id == permission_id
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
org_permission_models = result.scalars().all()
|
||||||
|
|
||||||
|
# Fetch the organization details for each org-permission relationship
|
||||||
|
organizations = []
|
||||||
|
for org_permission in org_permission_models:
|
||||||
|
org_stmt = select(OrgModel).where(
|
||||||
|
OrgModel.uuid == org_permission.org_uuid
|
||||||
|
)
|
||||||
|
org_result = await session.execute(org_stmt)
|
||||||
|
org_model = org_result.scalar_one()
|
||||||
|
|
||||||
|
# Convert UUID bytes back to string for the interface
|
||||||
|
org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options)
|
||||||
|
organizations.append(org)
|
||||||
|
|
||||||
|
return organizations
|
||||||
|
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
async with self.session() as session:
|
async with self.session() as session:
|
||||||
|
@ -24,7 +24,7 @@ def main():
|
|||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
reload=args.dev,
|
reload=args.dev,
|
||||||
log_level="debug" if args.dev else "info",
|
log_level="info",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,11 +10,13 @@ This module contains all the HTTP API endpoints for:
|
|||||||
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Cookie, Depends, FastAPI, Request, Response
|
from fastapi import Cookie, Depends, FastAPI, Response
|
||||||
from fastapi.security import HTTPBearer
|
from fastapi.security import HTTPBearer
|
||||||
|
|
||||||
|
from passkey.util import passphrase
|
||||||
|
|
||||||
from .. import aaguid
|
from .. import aaguid
|
||||||
from ..authsession import delete_credential, get_session
|
from ..authsession import delete_credential, get_reset, get_session
|
||||||
from ..db import db
|
from ..db import db
|
||||||
from ..util.tokens import session_key
|
from ..util.tokens import session_key
|
||||||
from . import session
|
from . import session
|
||||||
@ -26,7 +28,7 @@ def register_api_routes(app: FastAPI):
|
|||||||
"""Register all API routes on the FastAPI app."""
|
"""Register all API routes on the FastAPI app."""
|
||||||
|
|
||||||
@app.post("/auth/validate")
|
@app.post("/auth/validate")
|
||||||
async def validate_token(request: Request, response: Response, auth=Cookie(None)):
|
async def validate_token(auth=Cookie(None)):
|
||||||
"""Lightweight token validation endpoint."""
|
"""Lightweight token validation endpoint."""
|
||||||
try:
|
try:
|
||||||
s = await get_session(auth)
|
s = await get_session(auth)
|
||||||
@ -39,11 +41,12 @@ def register_api_routes(app: FastAPI):
|
|||||||
return {"status": "error", "valid": False}
|
return {"status": "error", "valid": False}
|
||||||
|
|
||||||
@app.post("/auth/user-info")
|
@app.post("/auth/user-info")
|
||||||
async def api_user_info(auth=Cookie(None)):
|
async def api_user_info(response: Response, auth=Cookie(None)):
|
||||||
"""Get full user information for the authenticated user."""
|
"""Get full user information for the authenticated user."""
|
||||||
try:
|
try:
|
||||||
s = await get_session(auth, reset_allowed=True)
|
reset = passphrase.is_well_formed(auth)
|
||||||
u = await db.instance.get_user_by_user_uuid(s.user_uuid)
|
s = await (get_reset if reset else get_session)(auth)
|
||||||
|
u = await db.instance.get_user_by_uuid(s.user_uuid)
|
||||||
# Get all credentials for the user
|
# Get all credentials for the user
|
||||||
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
|
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
|
||||||
|
|
||||||
@ -82,10 +85,11 @@ def register_api_routes(app: FastAPI):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
|
"authenticated": not reset,
|
||||||
"session_type": s.info["type"],
|
"session_type": s.info["type"],
|
||||||
"user": {
|
"user": {
|
||||||
"user_uuid": str(u.user_uuid),
|
"user_uuid": str(u.uuid),
|
||||||
"user_name": u.user_name,
|
"user_name": u.display_name,
|
||||||
"created_at": u.created_at.isoformat() if u.created_at else None,
|
"created_at": u.created_at.isoformat() if u.created_at else None,
|
||||||
"last_seen": u.last_seen.isoformat() if u.last_seen else None,
|
"last_seen": u.last_seen.isoformat() if u.last_seen else None,
|
||||||
"visits": u.visits,
|
"visits": u.visits,
|
||||||
@ -94,8 +98,11 @@ def register_api_routes(app: FastAPI):
|
|||||||
"aaguid_info": aaguid_info,
|
"aaguid_info": aaguid_info,
|
||||||
}
|
}
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
response.status_code = 400
|
||||||
return {"error": f"Failed to get user info: {e}"}
|
return {"error": f"Failed to get user info: {e}"}
|
||||||
except Exception:
|
except Exception:
|
||||||
|
response.status_code = 500
|
||||||
|
|
||||||
return {"error": "Failed to get user info"}
|
return {"error": "Failed to get user info"}
|
||||||
|
|
||||||
@app.post("/auth/logout")
|
@app.post("/auth/logout")
|
||||||
@ -103,7 +110,11 @@ def register_api_routes(app: FastAPI):
|
|||||||
"""Log out the current user by clearing the session cookie and deleting from database."""
|
"""Log out the current user by clearing the session cookie and deleting from database."""
|
||||||
if not auth:
|
if not auth:
|
||||||
return {"status": "success", "message": "Already logged out"}
|
return {"status": "success", "message": "Already logged out"}
|
||||||
await db.instance.delete_session(session_key(auth))
|
# Remove from database if possible
|
||||||
|
try:
|
||||||
|
await db.instance.delete_session(session_key(auth))
|
||||||
|
except Exception:
|
||||||
|
...
|
||||||
response.delete_cookie("auth")
|
response.delete_cookie("auth")
|
||||||
return {"status": "success", "message": "Logged out successfully"}
|
return {"status": "success", "message": "Logged out successfully"}
|
||||||
|
|
||||||
@ -123,18 +134,24 @@ def register_api_routes(app: FastAPI):
|
|||||||
}
|
}
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
response.status_code = 400
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return {"error": f"Failed to set session: {e}"}
|
response.status_code = 500
|
||||||
|
return {"error": "Failed to set session"}
|
||||||
|
|
||||||
@app.delete("/auth/credential/{uuid}")
|
@app.delete("/auth/credential/{uuid}")
|
||||||
async def api_delete_credential(uuid: UUID, auth: str = Cookie(None)):
|
async def api_delete_credential(
|
||||||
|
response: Response, uuid: UUID, auth: str = Cookie(None)
|
||||||
|
):
|
||||||
"""Delete a specific credential for the current user."""
|
"""Delete a specific credential for the current user."""
|
||||||
try:
|
try:
|
||||||
await delete_credential(uuid, auth)
|
await delete_credential(uuid, auth)
|
||||||
return {"status": "success", "message": "Credential deleted successfully"}
|
return {"status": "success", "message": "Credential deleted successfully"}
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
response.status_code = 400
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
except Exception:
|
except Exception:
|
||||||
|
response.status_code = 500
|
||||||
return {"error": "Failed to delete credential"}
|
return {"error": "Failed to delete credential"}
|
||||||
|
@ -3,9 +3,7 @@ from contextlib import asynccontextmanager
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import Cookie, FastAPI, Request, Response
|
from fastapi import Cookie, FastAPI, Request, Response
|
||||||
from fastapi.responses import (
|
from fastapi.responses import FileResponse
|
||||||
FileResponse,
|
|
||||||
)
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from ..authsession import get_session
|
from ..authsession import get_session
|
||||||
@ -35,24 +33,21 @@ app = FastAPI(lifespan=lifespan)
|
|||||||
# Mount the WebSocket subapp
|
# Mount the WebSocket subapp
|
||||||
app.mount("/auth/ws", ws.app)
|
app.mount("/auth/ws", ws.app)
|
||||||
|
|
||||||
# Register API routes
|
|
||||||
register_api_routes(app)
|
|
||||||
register_reset_routes(app)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/auth/forward-auth")
|
@app.get("/auth/forward-auth")
|
||||||
async def forward_authentication(request: Request, auth=Cookie(None)):
|
async def forward_authentication(request: Request, auth=Cookie(None)):
|
||||||
"""A validation endpoint to use with Caddy forward_auth or Nginx auth_request."""
|
"""A validation endpoint to use with Caddy forward_auth or Nginx auth_request."""
|
||||||
with contextlib.suppress(ValueError):
|
if auth:
|
||||||
s = await get_session(auth)
|
with contextlib.suppress(ValueError):
|
||||||
# If authenticated, return a success response
|
s = await get_session(auth)
|
||||||
if s.info and s.info["type"] == "authenticated":
|
# If authenticated, return a success response
|
||||||
return Response(
|
if s.info and s.info["type"] == "authenticated":
|
||||||
status_code=204,
|
return Response(
|
||||||
headers={
|
status_code=204,
|
||||||
"x-auth-user-uuid": str(s.user_uuid),
|
headers={
|
||||||
},
|
"x-auth-user-uuid": str(s.user_uuid),
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Serve the index.html of the authentication app if not authenticated
|
# Serve the index.html of the authentication app if not authenticated
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
@ -72,3 +67,8 @@ app.mount(
|
|||||||
async def redirect_to_index():
|
async def redirect_to_index():
|
||||||
"""Serve the main authentication app."""
|
"""Serve the main authentication app."""
|
||||||
return FileResponse(STATIC_DIR / "index.html")
|
return FileResponse(STATIC_DIR / "index.html")
|
||||||
|
|
||||||
|
|
||||||
|
# Register API routes
|
||||||
|
register_api_routes(app)
|
||||||
|
register_reset_routes(app)
|
||||||
|
@ -97,30 +97,39 @@ async def websocket_register_new(
|
|||||||
@app.websocket("/add_credential")
|
@app.websocket("/add_credential")
|
||||||
async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
|
async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
|
||||||
"""Register a new credential for an existing user."""
|
"""Register a new credential for an existing user."""
|
||||||
print(auth)
|
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
origin = ws.headers.get("origin")
|
origin = ws.headers["origin"]
|
||||||
try:
|
try:
|
||||||
s = await get_session(auth, reset_allowed=True)
|
s = await get_session(auth, reset_allowed=True)
|
||||||
user_uuid = s.user_uuid
|
user_uuid = s.user_uuid
|
||||||
|
|
||||||
# Get user information to get the user_name
|
# Get user information to get the user_name
|
||||||
user = await db.instance.get_user_by_user_uuid(user_uuid)
|
user = await db.instance.get_user_by_uuid(user_uuid)
|
||||||
user_name = user.user_name
|
user_name = user.display_name
|
||||||
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
|
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
|
||||||
|
|
||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(
|
credential = await register_chat(
|
||||||
ws, user_uuid, user_name, challenge_ids, origin
|
ws, user_uuid, user_name, challenge_ids, origin
|
||||||
)
|
)
|
||||||
|
if s.info["type"] == "authenticated":
|
||||||
|
token = auth
|
||||||
|
else:
|
||||||
|
# Replace reset session with a new session
|
||||||
|
await db.instance.delete_session(s.key)
|
||||||
|
token = await create_session(
|
||||||
|
user_uuid, credential.uuid, infodict(ws, "authenticated")
|
||||||
|
)
|
||||||
|
assert isinstance(token, str) and len(token) == 16
|
||||||
# Store the new credential in the database
|
# Store the new credential in the database
|
||||||
await db.instance.create_credential(credential)
|
await db.instance.create_credential(credential)
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"user_uuid": str(user_uuid),
|
"user_uuid": str(user.uuid),
|
||||||
"credential_id": credential.credential_id.hex(),
|
"credential_uuid": str(credential.uuid),
|
||||||
|
"session_token": token,
|
||||||
"message": "New credential added successfully",
|
"message": "New credential added successfully",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -136,7 +145,7 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
|
|||||||
@app.websocket("/authenticate")
|
@app.websocket("/authenticate")
|
||||||
async def websocket_authenticate(ws: WebSocket):
|
async def websocket_authenticate(ws: WebSocket):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
origin = ws.headers.get("origin")
|
origin = ws.headers["origin"]
|
||||||
try:
|
try:
|
||||||
options, challenge = passkey.auth_generate_options()
|
options, challenge = passkey.auth_generate_options()
|
||||||
await ws.send_json(options)
|
await ws.send_json(options)
|
||||||
|
@ -2,6 +2,8 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
|
from .passphrase import is_well_formed
|
||||||
|
|
||||||
|
|
||||||
def create_token() -> str:
|
def create_token() -> str:
|
||||||
return secrets.token_urlsafe(12) # 16 characters Base64
|
return secrets.token_urlsafe(12) # 16 characters Base64
|
||||||
@ -14,4 +16,10 @@ def session_key(token: str) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
def reset_key(passphrase: str) -> bytes:
|
def reset_key(passphrase: str) -> bytes:
|
||||||
|
if not is_well_formed(passphrase):
|
||||||
|
raise ValueError(
|
||||||
|
"Trying to reset with a session token in place of a passphrase"
|
||||||
|
if len(passphrase) == 16
|
||||||
|
else "Invalid passphrase format"
|
||||||
|
)
|
||||||
return b"rset" + hashlib.sha512(passphrase.encode()).digest()[:12]
|
return b"rset" + hashlib.sha512(passphrase.encode()).digest()[:12]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user