Checkpoint, fixing reset token handling broken in earlier edits.

This commit is contained in:
Leo Vasanko 2025-08-06 09:55:14 -06:00
parent c42864794a
commit cf138d90c5
11 changed files with 392 additions and 170 deletions

View File

@ -2,6 +2,7 @@
<div class="container">
<div class="view active">
<h1>🔑 Add Device Credential</h1>
<h3>👤 {{ authStore.currentUser.user_name }}</h3>
<button
class="btn-primary"
:disabled="authStore.isLoading"
@ -15,13 +16,13 @@
<script setup>
import { useAuthStore } from '@/stores/auth'
import { computed } from 'vue'
import { registerCredential } from '@/utils/passkey'
import { ref, computed } from 'vue'
const authStore = useAuthStore()
const hasDeviceSession = computed(() => !!authStore.currentUser)
function register() {
async function register() {
if (!hasDeviceSession.value) {
authStore.showMessage('No valid device addition session', 'error')
return
@ -29,18 +30,22 @@ function register() {
authStore.isLoading = true
authStore.showMessage('Starting registration...', 'info')
registerCredential().finally(() => {
authStore.isLoading = false
}).then(() => {
try {
const result = await registerCredential()
console.log("Result", result)
await authStore.setSessionCookie(result.session_token)
authStore.showMessage('Passkey registered successfully!', 'success', 2000)
authStore.currentView = 'profile'
}).catch((error) => {
} catch (error) {
console.error('Registration error:', error)
if (error.name === "NotAllowedError") {
authStore.showMessage('Registration cancelled', 'error')
} else {
authStore.showMessage(`Registration failed: ${error.message}`, 'error')
}
})
} finally {
authStore.isLoading = false
}
}
</script>

View File

@ -46,13 +46,12 @@ export const useAuthStore = defineStore('auth', {
try {
const result = await registerUser(user_name)
await this.setSessionCookie(result.session_token)
this.currentUser = {
user_id: result.user_id,
user_name: user_name,
}
await this.setSessionCookie(result.session_token)
return result
} finally {
this.isLoading = false

View File

@ -8,7 +8,6 @@ export async function register(url, options) {
const optionsJSON = await ws.receive_json()
const registrationResponse = await startRegistration({ optionsJSON })
ws.send_json(registrationResponse)
const result = await ws.receive_json()
ws.close()
return result;
@ -31,7 +30,6 @@ export async function authenticateUser() {
const optionsJSON = await ws.receive_json()
const authResponse = await startAuthentication({ optionsJSON })
ws.send_json(authResponse)
const result = await ws.receive_json()
ws.close()
return result

View File

@ -12,7 +12,6 @@ from datetime import datetime, timedelta
from uuid import UUID
from .db import Session, db
from .util import passphrase
from .util.tokens import create_token, reset_key, session_key
EXPIRES = timedelta(hours=24)
@ -22,29 +21,30 @@ def expires() -> datetime:
return datetime.now() + EXPIRES
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
async def create_session(user_uuid: UUID, credential_uuid: UUID, info: dict) -> str:
"""Create a new session and return a session token."""
token = create_token()
await db.instance.create_session(
user_uuid=user_uuid,
credential_uuid=credential_uuid,
key=session_key(token),
expires=datetime.now() + EXPIRES,
info=info,
credential_uuid=credential_uuid,
)
return token
async def get_session(token: str, reset_allowed=False) -> Session:
"""Validate a session token and return session data if valid."""
if passphrase.is_well_formed(token):
if not reset_allowed:
raise ValueError("Reset link is not allowed for this endpoint")
key = reset_key(token)
else:
key = session_key(token)
async def get_reset(token: str) -> Session:
"""Validate a credential reset token. Returns None if the token is not well formed (i.e. it is another type of token)."""
session = await db.instance.get_session(reset_key(token))
if not session:
raise ValueError("Invalid or expired session token")
return session
session = await db.instance.get_session(key)
async def get_session(token: str) -> Session:
"""Validate a session token and return session data if valid."""
session = await db.instance.get_session(session_key(token))
if not session:
raise ValueError("Invalid or expired session token")
return session

View File

@ -13,8 +13,10 @@ from uuid import UUID
@dataclass
class User:
user_uuid: UUID
user_name: str
uuid: UUID
display_name: str
org_uuid: UUID | None = None
role: str | None = None
created_at: datetime | None = None
last_seen: datetime | None = None
visits: int = 0
@ -41,6 +43,14 @@ class Org:
options: dict
@dataclass
class Permission:
"""Permission data structure."""
id: str # String primary key (max 32 chars)
display_name: str
@dataclass
class Session:
"""Session data structure."""
@ -68,7 +78,7 @@ class DatabaseInterface(ABC):
# User operations
@abstractmethod
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
async def get_user_by_uuid(self, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
@abstractmethod
@ -147,20 +157,69 @@ class DatabaseInterface(ABC):
async def add_user_to_organization(
self, user_uuid: UUID, org_id: str, role: str
) -> None:
"""Add a user to an organization with a specific role."""
"""Set a user's organization and role."""
@abstractmethod
async def remove_user_from_organization(self, user_uuid: UUID, org_id: str) -> None:
"""Remove a user from an organization."""
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
"""Remove a user from their organization."""
@abstractmethod
async def get_user_org_role(self, user_uuid: UUID) -> list[tuple[Org, str]]:
"""Get all organizations for a user with their roles."""
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
"""Get the organization and role for a user."""
@abstractmethod
async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
"""Get all users in an organization with their roles."""
@abstractmethod
async def get_user_role_in_organization(
self, user_uuid: UUID, org_id: str
) -> str | None:
"""Get a user's role in a specific organization."""
@abstractmethod
async def update_user_role_in_organization(
self, user_uuid: UUID, new_role: str
) -> None:
"""Update a user's role in their organization."""
# Permission operations
@abstractmethod
async def create_permission(self, permission: Permission) -> None:
"""Create a new permission."""
@abstractmethod
async def get_permission(self, permission_id: str) -> Permission:
"""Get permission by ID."""
@abstractmethod
async def update_permission(self, permission: Permission) -> None:
"""Update permission details."""
@abstractmethod
async def delete_permission(self, permission_id: str) -> None:
"""Delete permission by ID."""
@abstractmethod
async def add_permission_to_organization(
self, org_id: str, permission_id: str
) -> None:
"""Add a permission to an organization."""
@abstractmethod
async def remove_permission_from_organization(
self, org_id: str, permission_id: str
) -> None:
"""Remove a permission from an organization."""
@abstractmethod
async def get_organization_permissions(self, org_id: str) -> list[Permission]:
"""Get all permissions assigned to an organization."""
@abstractmethod
async def get_permission_organizations(self, permission_id: str) -> list[Org]:
"""Get all organizations that have a specific permission."""
# Combined operations
@abstractmethod
async def login(self, user_uuid: UUID, credential: Credential) -> None:
@ -200,6 +259,7 @@ __all__ = [
"Credential",
"Session",
"Org",
"Permission",
"DatabaseInterface",
"db",
]

View File

@ -21,9 +21,9 @@ from sqlalchemy import (
)
from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from . import Credential, DatabaseInterface, Org, Session, User, db
from . import Credential, DatabaseInterface, Org, Permission, Session, User, db
DB_PATH = "sqlite+aiosqlite:///passkey-auth.sqlite"
@ -38,25 +38,43 @@ class Base(DeclarativeBase):
pass
# Association model for many-to-many relationship between users and organizations with roles
class UserRole(Base):
__tablename__ = "user_roles"
# Association model for many-to-many relationship between organizations and permissions
class OrgPermission(Base):
"""Permissions each Org is allowed to grant to its roles."""
__tablename__ = "org_permissions"
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16),
ForeignKey("users.user_uuid", ondelete="CASCADE"),
primary_key=True,
)
org_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16),
ForeignKey("orgs.uuid", ondelete="CASCADE"),
primary_key=True,
)
role: Mapped[str] = mapped_column(String, nullable=False)
permission_id: Mapped[str] = mapped_column(
String(32),
ForeignKey("permissions.id", ondelete="CASCADE"),
primary_key=True,
)
"""Permissions each Org is allowed to grant to its roles."""
# Relationships to the actual models
user: Mapped["UserModel"] = relationship("UserModel", back_populates="user_orgs")
org: Mapped["OrgModel"] = relationship("OrgModel", back_populates="user_orgs")
__tablename__ = "org_permissions"
org_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16),
ForeignKey("orgs.uuid", ondelete="CASCADE"),
primary_key=True,
)
permission_id: Mapped[str] = mapped_column(
String(32),
ForeignKey("permissions.id", ondelete="CASCADE"),
primary_key=True,
)
class PermissionModel(Base):
__tablename__ = "permissions"
id: Mapped[str] = mapped_column(String(128), primary_key=True)
display_name: Mapped[str] = mapped_column(String, nullable=False)
class OrgModel(Base):
@ -65,31 +83,20 @@ class OrgModel(Base):
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
options: Mapped[dict] = mapped_column(JSON, nullable=False, default=dict)
# Relationship to user-org associations
user_orgs: Mapped[list["UserRole"]] = relationship(
"UserRoleModel", back_populates="org", cascade="all, delete-orphan"
)
class UserModel(Base):
__tablename__ = "users"
user_uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_name: Mapped[str] = mapped_column(String, nullable=False)
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
display_name: Mapped[str] = mapped_column(String, nullable=False)
org_uuid: Mapped[bytes | None] = mapped_column(
LargeBinary(16), ForeignKey("orgs.uuid", ondelete="SET NULL"), nullable=True
)
role: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
# Relationship to credentials
credentials: Mapped[list["CredentialModel"]] = relationship(
"CredentialModel", back_populates="user", cascade="all, delete-orphan"
)
# Relationship to user-org associations
user_orgs: Mapped[list["UserRole"]] = relationship(
"UserOrgModel", back_populates="user", cascade="all, delete-orphan"
)
class CredentialModel(Base):
__tablename__ = "credentials"
@ -98,7 +105,7 @@ class CredentialModel(Base):
LargeBinary(64), unique=True, index=True
)
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
)
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
@ -107,16 +114,13 @@ class CredentialModel(Base):
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# Relationship to user
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
class SessionModel(Base):
__tablename__ = "sessions"
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
LargeBinary(16), ForeignKey("users.uuid", ondelete="CASCADE")
)
credential_uuid: Mapped[bytes | None] = mapped_column(
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
@ -124,9 +128,6 @@ class SessionModel(Base):
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
# Relationship to user
user: Mapped["UserModel"] = relationship("UserModel")
class DB(DatabaseInterface):
"""Database class that handles its own connections."""
@ -152,16 +153,20 @@ class DB(DatabaseInterface):
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
async def get_user_by_uuid(self, user_uuid: UUID) -> User:
async with self.session() as session:
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
result = await session.execute(stmt)
user_model = result.scalar_one_or_none()
if user_model:
return User(
user_uuid=UUID(bytes=user_model.user_uuid),
user_name=user_model.user_name,
uuid=UUID(bytes=user_model.uuid),
display_name=user_model.display_name,
org_uuid=UUID(bytes=user_model.org_uuid)
if user_model.org_uuid
else None,
role=user_model.role,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
visits=user_model.visits,
@ -171,8 +176,10 @@ class DB(DatabaseInterface):
async def create_user(self, user: User) -> None:
async with self.session() as session:
user_model = UserModel(
user_uuid=user.user_uuid.bytes,
user_name=user.user_name,
uuid=user.uuid.bytes,
display_name=user.display_name,
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
role=user.role,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
visits=user.visits,
@ -256,7 +263,7 @@ class DB(DatabaseInterface):
# Update user's last_seen and increment visits
stmt = (
update(UserModel)
.where(UserModel.user_uuid == user_uuid.bytes)
.where(UserModel.uuid == user_uuid.bytes)
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
)
await session.execute(stmt)
@ -270,8 +277,10 @@ class DB(DatabaseInterface):
# Create user
user_model = UserModel(
user_uuid=user.user_uuid.bytes,
user_name=user.user_name,
uuid=user.uuid.bytes,
display_name=user.display_name,
org_uuid=user.org_uuid.bytes if user.org_uuid else None,
role=user.role,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
visits=user.visits,
@ -399,7 +408,7 @@ class DB(DatabaseInterface):
) -> None:
async with self.session() as session:
# Get user and organization models
user_stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
user_stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
user_result = await session.execute(user_stmt)
user_model = user_result.scalar_one_or_none()
@ -414,67 +423,65 @@ class DB(DatabaseInterface):
if not org_model:
raise ValueError("Organization not found")
# Create the user-org relationship with role
user_org = UserRole(
user_uuid=user_uuid.bytes, org_uuid=org_uuid.bytes, role=role
)
session.add(user_org)
async def remove_user_from_organization(self, user_uuid: UUID, org_id: str) -> None:
async with self.session() as session:
# Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id)
# Delete the user-org relationship
stmt = delete(UserRole).where(
UserRole.user_uuid == user_uuid.bytes,
UserRole.org_uuid == org_uuid.bytes,
# Update the user's organization and role
stmt = (
update(UserModel)
.where(UserModel.uuid == user_uuid.bytes)
.values(org_uuid=org_uuid.bytes, role=role)
)
await session.execute(stmt)
async def get_user_org_role(self, user_uuid: UUID) -> list[tuple[Org, str]]:
async def remove_user_from_organization(self, user_uuid: UUID) -> None:
async with self.session() as session:
stmt = select(UserRole).where(UserRole.user_uuid == user_uuid.bytes)
# Clear the user's organization and role
stmt = (
update(UserModel)
.where(UserModel.uuid == user_uuid.bytes)
.values(org_uuid=None, role=None)
)
await session.execute(stmt)
async def get_user_organization(self, user_uuid: UUID) -> tuple[Org, str] | None:
async with self.session() as session:
stmt = select(UserModel).where(UserModel.uuid == user_uuid.bytes)
result = await session.execute(stmt)
user_org_models = result.scalars().all()
user_model = result.scalar_one_or_none()
# Fetch the organization details for each user-org relationship
org_role_pairs = []
for user_org in user_org_models:
org_stmt = select(OrgModel).where(OrgModel.uuid == user_org.org_uuid)
org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one()
if not user_model or not user_model.org_uuid:
return None
# Convert UUID bytes back to string for the interface
org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options)
org_role_pairs.append((org, user_org.role))
# Fetch the organization details
org_stmt = select(OrgModel).where(OrgModel.uuid == user_model.org_uuid)
org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one()
return org_role_pairs
# Convert UUID bytes back to string for the interface
org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options)
return (org, user_model.role or "")
async def get_organization_users(self, org_id: str) -> list[tuple[User, str]]:
async with self.session() as session:
# Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id)
stmt = select(UserRole).where(UserRole.org_uuid == org_uuid.bytes)
stmt = select(UserModel).where(UserModel.org_uuid == org_uuid.bytes)
result = await session.execute(stmt)
user_org_models = result.scalars().all()
user_models = result.scalars().all()
# Fetch the user details for each user-org relationship
# Create user objects with their roles
user_role_pairs = []
for user_org in user_org_models:
user_stmt = select(UserModel).where(
UserModel.user_uuid == user_org.user_uuid
)
user_result = await session.execute(user_stmt)
user_model = user_result.scalar_one()
for user_model in user_models:
user = User(
user_uuid=UUID(bytes=user_model.user_uuid),
user_name=user_model.user_name,
uuid=UUID(bytes=user_model.uuid),
display_name=user_model.display_name,
org_uuid=UUID(bytes=user_model.org_uuid)
if user_model.org_uuid
else None,
role=user_model.role,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
visits=user_model.visits,
)
user_role_pairs.append((user, user_org.role))
user_role_pairs.append((user, user_model.role or ""))
return user_role_pairs
@ -485,31 +492,150 @@ class DB(DatabaseInterface):
async with self.session() as session:
# Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id)
stmt = select(UserRole.role).where(
UserRole.user_uuid == user_uuid.bytes,
UserRole.org_uuid == org_uuid.bytes,
stmt = select(UserModel.role).where(
UserModel.uuid == user_uuid.bytes,
UserModel.org_uuid == org_uuid.bytes,
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_user_role_in_organization(
self, user_uuid: UUID, org_id: str, new_role: str
self, user_uuid: UUID, new_role: str
) -> None:
"""Update a user's role in an organization."""
"""Update a user's role in their organization."""
async with self.session() as session:
# Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id)
stmt = (
update(UserRole)
.where(
UserRole.user_uuid == user_uuid.bytes,
UserRole.org_uuid == org_uuid.bytes,
)
update(UserModel)
.where(UserModel.uuid == user_uuid.bytes)
.values(role=new_role)
)
result = await session.execute(stmt)
if result.rowcount == 0:
raise ValueError("User is not a member of this organization")
raise ValueError("User not found")
# Permission operations
async def create_permission(self, permission: Permission) -> None:
async with self.session() as session:
permission_model = PermissionModel(
id=permission.id,
display_name=permission.display_name,
)
session.add(permission_model)
async def get_permission(self, permission_id: str) -> Permission:
async with self.session() as session:
stmt = select(PermissionModel).where(PermissionModel.id == permission_id)
result = await session.execute(stmt)
permission_model = result.scalar_one_or_none()
if permission_model:
return Permission(
id=permission_model.id,
display_name=permission_model.display_name,
)
raise ValueError("Permission not found")
async def update_permission(self, permission: Permission) -> None:
async with self.session() as session:
stmt = (
update(PermissionModel)
.where(PermissionModel.id == permission.id)
.values(display_name=permission.display_name)
)
await session.execute(stmt)
async def delete_permission(self, permission_id: str) -> None:
async with self.session() as session:
stmt = delete(PermissionModel).where(PermissionModel.id == permission_id)
await session.execute(stmt)
async def add_permission_to_organization(
self, org_id: str, permission_id: str
) -> None:
async with self.session() as session:
# Get organization and permission models
org_uuid = UUID(org_id)
org_stmt = select(OrgModel).where(OrgModel.uuid == org_uuid.bytes)
org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one_or_none()
permission_stmt = select(PermissionModel).where(
PermissionModel.id == permission_id
)
permission_result = await session.execute(permission_stmt)
permission_model = permission_result.scalar_one_or_none()
if not org_model:
raise ValueError("Organization not found")
if not permission_model:
raise ValueError("Permission not found")
# Create the org-permission relationship
org_permission = OrgPermission(
org_uuid=org_uuid.bytes, permission_id=permission_id
)
session.add(org_permission)
async def remove_permission_from_organization(
self, org_id: str, permission_id: str
) -> None:
async with self.session() as session:
# Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id)
# Delete the org-permission relationship
stmt = delete(OrgPermission).where(
OrgPermission.org_uuid == org_uuid.bytes,
OrgPermission.permission_id == permission_id,
)
await session.execute(stmt)
async def get_organization_permissions(self, org_id: str) -> list[Permission]:
async with self.session() as session:
# Convert string ID to UUID bytes for lookup
org_uuid = UUID(org_id)
stmt = select(OrgPermission).where(OrgPermission.org_uuid == org_uuid.bytes)
result = await session.execute(stmt)
org_permission_models = result.scalars().all()
# Fetch the permission details for each org-permission relationship
permissions = []
for org_permission in org_permission_models:
permission_stmt = select(PermissionModel).where(
PermissionModel.id == org_permission.permission_id
)
permission_result = await session.execute(permission_stmt)
permission_model = permission_result.scalar_one()
permission = Permission(
id=permission_model.id,
display_name=permission_model.display_name,
)
permissions.append(permission)
return permissions
async def get_permission_organizations(self, permission_id: str) -> list[Org]:
async with self.session() as session:
stmt = select(OrgPermission).where(
OrgPermission.permission_id == permission_id
)
result = await session.execute(stmt)
org_permission_models = result.scalars().all()
# Fetch the organization details for each org-permission relationship
organizations = []
for org_permission in org_permission_models:
org_stmt = select(OrgModel).where(
OrgModel.uuid == org_permission.org_uuid
)
org_result = await session.execute(org_stmt)
org_model = org_result.scalar_one()
# Convert UUID bytes back to string for the interface
org = Org(id=str(UUID(bytes=org_model.uuid)), options=org_model.options)
organizations.append(org)
return organizations
async def cleanup(self) -> None:
async with self.session() as session:

View File

@ -24,7 +24,7 @@ def main():
host=args.host,
port=args.port,
reload=args.dev,
log_level="debug" if args.dev else "info",
log_level="info",
)

View File

@ -10,11 +10,13 @@ This module contains all the HTTP API endpoints for:
from uuid import UUID
from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi import Cookie, Depends, FastAPI, Response
from fastapi.security import HTTPBearer
from passkey.util import passphrase
from .. import aaguid
from ..authsession import delete_credential, get_session
from ..authsession import delete_credential, get_reset, get_session
from ..db import db
from ..util.tokens import session_key
from . import session
@ -26,7 +28,7 @@ def register_api_routes(app: FastAPI):
"""Register all API routes on the FastAPI app."""
@app.post("/auth/validate")
async def validate_token(request: Request, response: Response, auth=Cookie(None)):
async def validate_token(auth=Cookie(None)):
"""Lightweight token validation endpoint."""
try:
s = await get_session(auth)
@ -39,11 +41,12 @@ def register_api_routes(app: FastAPI):
return {"status": "error", "valid": False}
@app.post("/auth/user-info")
async def api_user_info(auth=Cookie(None)):
async def api_user_info(response: Response, auth=Cookie(None)):
"""Get full user information for the authenticated user."""
try:
s = await get_session(auth, reset_allowed=True)
u = await db.instance.get_user_by_user_uuid(s.user_uuid)
reset = passphrase.is_well_formed(auth)
s = await (get_reset if reset else get_session)(auth)
u = await db.instance.get_user_by_uuid(s.user_uuid)
# Get all credentials for the user
credential_ids = await db.instance.get_credentials_by_user_uuid(s.user_uuid)
@ -82,10 +85,11 @@ def register_api_routes(app: FastAPI):
return {
"status": "success",
"authenticated": not reset,
"session_type": s.info["type"],
"user": {
"user_uuid": str(u.user_uuid),
"user_name": u.user_name,
"user_uuid": str(u.uuid),
"user_name": u.display_name,
"created_at": u.created_at.isoformat() if u.created_at else None,
"last_seen": u.last_seen.isoformat() if u.last_seen else None,
"visits": u.visits,
@ -94,8 +98,11 @@ def register_api_routes(app: FastAPI):
"aaguid_info": aaguid_info,
}
except ValueError as e:
response.status_code = 400
return {"error": f"Failed to get user info: {e}"}
except Exception:
response.status_code = 500
return {"error": "Failed to get user info"}
@app.post("/auth/logout")
@ -103,7 +110,11 @@ def register_api_routes(app: FastAPI):
"""Log out the current user by clearing the session cookie and deleting from database."""
if not auth:
return {"status": "success", "message": "Already logged out"}
await db.instance.delete_session(session_key(auth))
# Remove from database if possible
try:
await db.instance.delete_session(session_key(auth))
except Exception:
...
response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"}
@ -123,18 +134,24 @@ def register_api_routes(app: FastAPI):
}
except ValueError as e:
response.status_code = 400
return {"error": str(e)}
except Exception as e:
return {"error": f"Failed to set session: {e}"}
except Exception:
response.status_code = 500
return {"error": "Failed to set session"}
@app.delete("/auth/credential/{uuid}")
async def api_delete_credential(uuid: UUID, auth: str = Cookie(None)):
async def api_delete_credential(
response: Response, uuid: UUID, auth: str = Cookie(None)
):
"""Delete a specific credential for the current user."""
try:
await delete_credential(uuid, auth)
return {"status": "success", "message": "Credential deleted successfully"}
except ValueError as e:
response.status_code = 400
return {"error": str(e)}
except Exception:
response.status_code = 500
return {"error": "Failed to delete credential"}

View File

@ -3,9 +3,7 @@ from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import Cookie, FastAPI, Request, Response
from fastapi.responses import (
FileResponse,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from ..authsession import get_session
@ -35,24 +33,21 @@ app = FastAPI(lifespan=lifespan)
# Mount the WebSocket subapp
app.mount("/auth/ws", ws.app)
# Register API routes
register_api_routes(app)
register_reset_routes(app)
@app.get("/auth/forward-auth")
async def forward_authentication(request: Request, auth=Cookie(None)):
"""A validation endpoint to use with Caddy forward_auth or Nginx auth_request."""
with contextlib.suppress(ValueError):
s = await get_session(auth)
# If authenticated, return a success response
if s.info and s.info["type"] == "authenticated":
return Response(
status_code=204,
headers={
"x-auth-user-uuid": str(s.user_uuid),
},
)
if auth:
with contextlib.suppress(ValueError):
s = await get_session(auth)
# If authenticated, return a success response
if s.info and s.info["type"] == "authenticated":
return Response(
status_code=204,
headers={
"x-auth-user-uuid": str(s.user_uuid),
},
)
# Serve the index.html of the authentication app if not authenticated
return FileResponse(
@ -72,3 +67,8 @@ app.mount(
async def redirect_to_index():
"""Serve the main authentication app."""
return FileResponse(STATIC_DIR / "index.html")
# Register API routes
register_api_routes(app)
register_reset_routes(app)

View File

@ -97,30 +97,39 @@ async def websocket_register_new(
@app.websocket("/add_credential")
async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
"""Register a new credential for an existing user."""
print(auth)
await ws.accept()
origin = ws.headers.get("origin")
origin = ws.headers["origin"]
try:
s = await get_session(auth, reset_allowed=True)
user_uuid = s.user_uuid
# Get user information to get the user_name
user = await db.instance.get_user_by_user_uuid(user_uuid)
user_name = user.user_name
user = await db.instance.get_user_by_uuid(user_uuid)
user_name = user.display_name
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
# WebAuthn registration
credential = await register_chat(
ws, user_uuid, user_name, challenge_ids, origin
)
if s.info["type"] == "authenticated":
token = auth
else:
# Replace reset session with a new session
await db.instance.delete_session(s.key)
token = await create_session(
user_uuid, credential.uuid, infodict(ws, "authenticated")
)
assert isinstance(token, str) and len(token) == 16
# Store the new credential in the database
await db.instance.create_credential(credential)
await ws.send_json(
{
"status": "success",
"user_uuid": str(user_uuid),
"credential_id": credential.credential_id.hex(),
"user_uuid": str(user.uuid),
"credential_uuid": str(credential.uuid),
"session_token": token,
"message": "New credential added successfully",
}
)
@ -136,7 +145,7 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
@app.websocket("/authenticate")
async def websocket_authenticate(ws: WebSocket):
await ws.accept()
origin = ws.headers.get("origin")
origin = ws.headers["origin"]
try:
options, challenge = passkey.auth_generate_options()
await ws.send_json(options)

View File

@ -2,6 +2,8 @@ import base64
import hashlib
import secrets
from .passphrase import is_well_formed
def create_token() -> str:
return secrets.token_urlsafe(12) # 16 characters Base64
@ -14,4 +16,10 @@ def session_key(token: str) -> bytes:
def reset_key(passphrase: str) -> bytes:
if not is_well_formed(passphrase):
raise ValueError(
"Trying to reset with a session token in place of a passphrase"
if len(passphrase) == 16
else "Invalid passphrase format"
)
return b"rset" + hashlib.sha512(passphrase.encode()).digest()[:12]