Major cleanup and refactoring of the backend (frontend not fully updated).

This commit is contained in:
Leo Vasanko 2025-08-01 12:32:27 -06:00
parent 0cfa622bf1
commit c5e5fe23e3
16 changed files with 451 additions and 920 deletions

View File

@ -22,7 +22,12 @@ import AddCredentialView from '@/components/AddCredentialView.vue'
const store = useAuthStore()
onMounted(async () => {
// Check for device addition session first
// Was an error message passed in the URL?
const message = location.hash.substring(1)
if (message) {
store.showMessage(decodeURIComponent(message), 'error')
history.replaceState(null, '', location.pathname)
}
try {
await store.loadUserInfo()
} catch (error) {

View File

@ -15,7 +15,7 @@
<script setup>
import { useAuthStore } from '@/stores/auth'
import { registerWithSession } from '@/utils/passkey'
import { registerCredential } from '@/utils/passkey'
import { ref, onMounted } from 'vue'
const authStore = useAuthStore()
@ -25,9 +25,7 @@ const hasDeviceSession = ref(false)
onMounted(async () => {
try {
// Check if we have a device addition session
const response = await fetch('/auth/device-session-check', {
credentials: 'include'
})
const response = await fetch('/auth/device-session-check')
const data = await response.json()
if (data.device_addition_session) {
@ -50,7 +48,7 @@ function register() {
authStore.isLoading = true
authStore.showMessage('Starting registration...', 'info')
registerWithSession().finally(() => {
registerCredential().finally(() => {
authStore.isLoading = false
}).then(() => {
authStore.showMessage('Passkey registered successfully!', 'success', 2000)

View File

@ -46,7 +46,7 @@ const copyLink = async (event) => {
onMounted(async () => {
try {
const response = await fetch('/auth/create-device-link', { method: 'POST' })
const response = await fetch('/auth/create-link', { method: 'POST' })
const result = await response.json()
if (result.error) throw new Error(result.error)
@ -63,7 +63,7 @@ onMounted(async () => {
})
}
} catch (error) {
console.error('Failed to fetch device link:', error)
console.error('Failed to create link:', error)
}
})
</script>

View File

@ -33,10 +33,7 @@ export const useAuthStore = defineStore('auth', {
async setSessionCookie(sessionToken) {
const response = await fetch('/auth/set-session', {
method: 'POST',
headers: {
'Authorization': `Bearer ${sessionToken}`,
'Content-Type': 'application/json'
},
headers: {'Authorization': `Bearer ${sessionToken}`},
})
const result = await response.json()
if (result.error) {

View File

@ -22,10 +22,7 @@ export async function registerCredential() {
return register('/auth/ws/add_credential')
}
export async function registerWithToken(token) {
return register('/auth/ws/add_device_credential', { token })
}
export async function registerWithSession() {
return register('/auth/ws/add_device_credential_session')
return register('/auth/ws/add_credential', { token })
}
export async function authenticateUser() {

50
passkey/db/__init__.py Normal file
View File

@ -0,0 +1,50 @@
"""
Database module for WebAuthn passkey authentication.
This module provides dataclasses and database abstractions for managing
users, credentials, and sessions in a WebAuthn authentication system.
"""
from dataclasses import dataclass
from datetime import datetime
from uuid import UUID
@dataclass
class User:
"""User data structure."""
user_uuid: UUID
user_name: str
created_at: datetime | None = None
last_seen: datetime | None = None
visits: int = 0
@dataclass
class Credential:
"""Credential data structure."""
uuid: UUID
credential_id: bytes
user_uuid: UUID
aaguid: UUID
public_key: bytes
sign_count: int
created_at: datetime
last_used: datetime | None = None
last_verified: datetime | None = None
@dataclass
class Session:
"""Session data structure."""
key: bytes
user_uuid: UUID
expires: datetime
credential_uuid: UUID | None = None
info: dict | None = None
__all__ = ["User", "Credential", "Session"]

View File

@ -5,10 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
for managing users and credentials in a WebAuthn authentication system.
"""
import secrets
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import datetime
from uuid import UUID
from sqlalchemy import (
@ -25,7 +23,7 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from ..sansio import StoredCredential
from . import Credential, Session, User
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
@ -38,7 +36,7 @@ class Base(DeclarativeBase):
class UserModel(Base):
__tablename__ = "users"
user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_name: Mapped[str] = mapped_column(String, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@ -52,10 +50,12 @@ class UserModel(Base):
class CredentialModel(Base):
__tablename__ = "credentials"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True)
user_id: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
credential_id: Mapped[bytes] = mapped_column(
LargeBinary(64), unique=True, index=True
)
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
)
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
@ -71,29 +71,20 @@ class CredentialModel(Base):
class SessionModel(Base):
__tablename__ = "sessions"
token: Mapped[str] = mapped_column(String(32), primary_key=True)
user_id: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
user_uuid: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
)
credential_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("credentials.id", ondelete="SET NULL")
credential_uuid: Mapped[bytes | None] = mapped_column(
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
# Relationship to user
user: Mapped["UserModel"] = relationship("UserModel")
@dataclass
class User:
user_id: UUID
user_name: str
created_at: datetime | None = None
last_seen: datetime | None = None
visits: int = 0
# Global engine and session factory
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
@ -116,15 +107,15 @@ class DB:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_user_by_user_id(self, user_id: UUID) -> User:
"""Get user record by WebAuthn user ID."""
stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
result = await self.session.execute(stmt)
user_model = result.scalar_one_or_none()
if user_model:
return User(
user_id=UUID(bytes=user_model.user_id),
user_uuid=UUID(bytes=user_model.user_uuid),
user_name=user_model.user_name,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
@ -135,7 +126,7 @@ class DB:
async def create_user(self, user: User) -> None:
"""Create a new user."""
user_model = UserModel(
user_id=user.user_id.bytes,
user_uuid=user.user_uuid.bytes,
user_name=user.user_name,
created_at=user.created_at or datetime.now(),
last_seen=user.last_seen,
@ -144,11 +135,12 @@ class DB:
self.session.add(user_model)
await self.session.flush()
async def create_credential(self, credential: StoredCredential) -> None:
async def create_credential(self, credential: Credential) -> None:
"""Store a credential for a user."""
credential_model = CredentialModel(
uuid=credential.uuid.bytes,
credential_id=credential.credential_id,
user_id=credential.user_id.bytes,
user_uuid=credential.user_uuid.bytes,
aaguid=credential.aaguid.bytes,
public_key=credential.public_key,
sign_count=credential.sign_count,
@ -159,7 +151,7 @@ class DB:
self.session.add(credential_model)
await self.session.flush()
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
stmt = select(CredentialModel).where(
CredentialModel.credential_id == credential_id
@ -167,10 +159,12 @@ class DB:
result = await self.session.execute(stmt)
credential_model = result.scalar_one_or_none()
if credential_model:
return StoredCredential(
if not credential_model:
raise ValueError("Credential not registered")
return Credential(
uuid=UUID(bytes=credential_model.uuid),
credential_id=credential_model.credential_id,
user_id=UUID(bytes=credential_model.user_id),
user_uuid=UUID(bytes=credential_model.user_uuid),
aaguid=UUID(bytes=credential_model.aaguid),
public_key=credential_model.public_key,
sign_count=credential_model.sign_count,
@ -178,17 +172,16 @@ class DB:
last_used=credential_model.last_used,
last_verified=credential_model.last_verified,
)
raise ValueError("Credential not registered")
async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
stmt = select(CredentialModel.credential_id).where(
CredentialModel.user_id == user_id.bytes
CredentialModel.user_uuid == user_uuid.bytes
)
result = await self.session.execute(stmt)
return [row[0] for row in result.fetchall()]
async def update_credential(self, credential: StoredCredential) -> None:
async def update_credential(self, credential: Credential) -> None:
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
stmt = (
update(CredentialModel)
@ -202,7 +195,7 @@ class DB:
)
await self.session.execute(stmt)
async def login(self, user_id: UUID, credential: StoredCredential) -> None:
async def login(self, user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with self.session.begin():
# Update credential
@ -211,137 +204,77 @@ class DB:
# Update user's last_seen and increment visits
stmt = (
update(UserModel)
.where(UserModel.user_id == user_id.bytes)
.where(UserModel.user_uuid == user_uuid.bytes)
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
)
await self.session.execute(stmt)
async def create_new_session(
self, user_id: UUID, credential: StoredCredential
) -> None:
"""Create a new session for a user by incrementing visits and updating last_seen."""
async with self.session.begin():
# Update credential
await self.update_credential(credential)
# Update user's last_seen and increment visits
stmt = (
update(UserModel)
.where(UserModel.user_id == user_id.bytes)
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
)
await self.session.execute(stmt)
async def delete_credential(self, credential_id: bytes) -> None:
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID."""
stmt = delete(CredentialModel).where(
CredentialModel.credential_id == credential_id
stmt = (
delete(CredentialModel)
.where(CredentialModel.uuid == uuid.bytes)
.where(CredentialModel.user_uuid == user_uuid.bytes)
)
await self.session.execute(stmt)
await self.session.commit()
async def get_user_by_username(self, user_name: str) -> User | None:
"""Get user by username."""
stmt = select(UserModel).where(UserModel.user_name == user_name)
result = await self.session.execute(stmt)
user_model = result.scalar_one_or_none()
if user_model:
return User(
user_id=UUID(bytes=user_model.user_id),
user_name=user_model.user_name,
created_at=user_model.created_at,
last_seen=user_model.last_seen,
visits=user_model.visits,
)
return None
async def create_session(
self,
user_id: UUID,
credential_id: int | None = None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
if token is None:
token = secrets.token_urlsafe(12)
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
) -> bytes:
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
session_model = SessionModel(
token=token,
user_id=user_id.bytes,
credential_id=credential_id,
created_at=datetime.now(),
key=key,
user_uuid=user_uuid.bytes,
credential_uuid=credential_uuid.bytes if credential_uuid else None,
expires=expires,
info=info,
)
self.session.add(session_model)
await self.session.flush()
return token
return key
async def create_session_by_credential_id(
self,
user_id: UUID,
credential_id: bytes | None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
if credential_id is None:
return await self.create_session(user_id, None, token, info)
async def get_session(self, key: bytes) -> Session | None:
"""Get session by 16-byte key."""
stmt = select(SessionModel).where(SessionModel.key == key)
result = await self.session.execute(stmt)
session_model = result.scalar_one_or_none()
# Get the database ID from the credential
stmt = select(CredentialModel.id).where(
CredentialModel.credential_id == credential_id
if session_model:
return Session(
key=session_model.key,
user_uuid=UUID(bytes=session_model.user_uuid),
credential_uuid=UUID(bytes=session_model.credential_uuid)
if session_model.credential_uuid
else None,
expires=session_model.expires,
info=session_model.info,
)
result = await self.session.execute(stmt)
db_credential_id = result.scalar_one()
return await self.create_session(user_id, db_credential_id, token, info)
async def get_session(self, token: str) -> SessionModel | None:
"""Get session by token string."""
stmt = select(SessionModel).where(SessionModel.token == token)
result = await self.session.execute(stmt)
session = result.scalar_one_or_none()
if session:
# Check if session is expired (24 hours)
expiry_time = session.created_at + timedelta(hours=24)
if datetime.now() > expiry_time:
# Clean up expired session
await self.delete_session(token)
return None
return session
return None
async def delete_session(self, key: bytes) -> None:
"""Delete a session by 16-byte key."""
await self.session.execute(delete(SessionModel).where(SessionModel.key == key))
async def delete_session(self, token: str) -> None:
"""Delete a session by token."""
stmt = delete(SessionModel).where(SessionModel.token == token)
await self.session.execute(stmt)
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
"""Update session expiration time and/or info."""
await self.session.execute(
update(SessionModel)
.where(SessionModel.key == key)
.values(expires=expires, info=info)
)
async def cleanup_expired_sessions(self) -> None:
"""Remove expired sessions (older than 24 hours)."""
expiry_time = datetime.now() - timedelta(hours=24)
stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time)
"""Remove expired sessions."""
current_time = datetime.now()
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
await self.session.execute(stmt)
async def refresh_session(self, token: str) -> str | None:
"""Refresh a session by updating its created_at timestamp."""
session = await self.get_session(token)
if not session:
return None
# Delete old session
await self.delete_session(token)
# Create new session with same user and credential
return await self.create_session(
user_id=UUID(bytes=session.user_id),
credential_id=session.credential_id,
info=session.info,
)
# Standalone functions that handle database connections internally
async def init_database() -> None:
@ -350,7 +283,7 @@ async def init_database() -> None:
await db.init_db()
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
async def create_user_and_credential(user: User, credential: Credential) -> None:
"""Create a new user and their first credential in a single transaction."""
async with connect() as db:
await db.session.begin()
@ -360,97 +293,73 @@ async def create_user_and_credential(user: User, credential: StoredCredential) -
await db.create_credential(credential)
async def get_user_by_id(user_id: UUID) -> User:
"""Get user record by WebAuthn user ID."""
async def get_user_by_uuid(user_uuid: UUID) -> User:
"""Get user record by WebAuthn user UUID."""
async with connect() as db:
return await db.get_user_by_user_id(user_id)
return await db.get_user_by_user_uuid(user_uuid)
async def create_credential_for_user(credential: StoredCredential) -> None:
async def create_credential_for_user(credential: Credential) -> None:
"""Store a credential for an existing user."""
async with connect() as db:
await db.create_credential(credential)
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
async def get_credential_by_id(credential_id: bytes) -> Credential:
"""Get credential by credential ID."""
async with connect() as db:
return await db.get_credential_by_id(credential_id)
async def get_user_credentials(user_id: UUID) -> list[bytes]:
async def get_user_credentials(user_uuid: UUID) -> list[bytes]:
"""Get all credential IDs for a user."""
async with connect() as db:
return await db.get_credentials_by_user_id(user_id)
return await db.get_credentials_by_user_uuid(user_uuid)
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
async def login_user(user_uuid: UUID, credential: Credential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
async with connect() as db:
await db.login(user_id, credential)
await db.login(user_uuid, credential)
async def delete_user_credential(credential_id: bytes) -> None:
async def delete_credential(uuid: UUID, user_uuid: UUID) -> None:
"""Delete a credential by its ID."""
async with connect() as db:
await db.delete_credential(credential_id)
async def create_new_session(user_id: UUID, credential: StoredCredential) -> None:
"""Create a new session for a user by incrementing visits and updating last_seen."""
async with connect() as db:
await db.create_new_session(user_id, credential)
async def get_user_by_username(user_name: str) -> User | None:
"""Get user by username."""
async with connect() as db:
return await db.get_user_by_username(user_name)
await db.delete_credential(uuid, user_uuid)
async def create_session(
user_id: UUID,
credential_id: int | None = None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
user_uuid: UUID,
key: bytes,
expires: datetime,
info: dict,
credential_uuid: UUID | None = None,
) -> bytes:
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
async with connect() as db:
return await db.create_session(user_id, credential_id, token, info)
return await db.create_session(user_uuid, key, expires, info, credential_uuid)
async def create_session_by_credential_id(
user_id: UUID,
credential_id: bytes | None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
async def get_session(key: bytes) -> Session | None:
"""Get session by 16-byte key."""
async with connect() as db:
return await db.create_session_by_credential_id(
user_id, credential_id, token, info
)
return await db.get_session(key)
async def get_session(token: str) -> SessionModel | None:
"""Get session by token string."""
async def delete_session(key: bytes) -> None:
"""Delete a session by 16-byte key."""
async with connect() as db:
return await db.get_session(token)
await db.delete_session(key)
async def delete_session(token: str) -> None:
"""Delete a session by token."""
async def update_session(key: bytes, expires: datetime, info: dict) -> None:
"""Update session expiration time and/or info."""
async with connect() as db:
await db.delete_session(token)
await db.update_session(key, expires, info)
async def cleanup_expired_sessions() -> None:
"""Remove expired sessions (older than 24 hours)."""
"""Remove expired sessions."""
async with connect() as db:
await db.cleanup_expired_sessions()
async def refresh_session(token: str) -> str | None:
"""Refresh a session by updating its created_at timestamp."""
async with connect() as db:
return await db.refresh_session(token)

View File

@ -8,67 +8,67 @@ This module contains all the HTTP API endpoints for:
- Login/logout functionality
"""
from fastapi import FastAPI, Request, Response
from uuid import UUID
from fastapi import Cookie, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBearer
from .. import aaguid
from ..db import sql
from ..util.session import refresh_session_token, validate_session_token
from .session import (
clear_session_cookie,
get_current_user,
get_session_token_from_bearer,
get_session_token_from_cookie,
set_session_cookie,
)
from ..util.tokens import session_key
from . import session
bearer_auth = HTTPBearer(auto_error=True)
def register_api_routes(app: FastAPI):
"""Register all API routes on the FastAPI app."""
@app.post("/auth/user-info")
async def api_user_info(request: Request, response: Response):
"""Get user information and credentials from session cookie."""
@app.post("/auth/validate")
async def validate_token(request: Request, response: Response, auth=Cookie(None)):
"""Lightweight token validation endpoint."""
try:
user = await get_current_user(request)
if not user:
return {"error": "Not authenticated"}
# Get current session credential ID
current_credential_id = None
session_token = get_session_token_from_cookie(request)
if session_token:
token_data = await validate_session_token(session_token)
if token_data:
current_credential_id = token_data.get("credential_id")
s = await session.get_session(auth)
return {
"status": "success",
"valid": True,
"user_uuid": str(s.user_uuid),
}
except ValueError:
return {"status": "error", "valid": False}
@app.post("/auth/user-info")
async def api_user_info(request: Request, response: Response, auth=Cookie(None)):
"""Get full user information for the authenticated user."""
try:
s = await session.get_session(auth, reset_allowed=True)
u = await sql.get_user_by_uuid(s.user_uuid)
# Get all credentials for the user
credential_ids = await sql.get_user_credentials(user.user_id)
credential_ids = await sql.get_user_credentials(s.user_uuid)
credentials = []
user_aaguids = set()
for cred_id in credential_ids:
stored_cred = await sql.get_credential_by_id(cred_id)
c = await sql.get_credential_by_id(cred_id)
# Convert AAGUID to string format
aaguid_str = str(stored_cred.aaguid)
aaguid_str = str(c.aaguid)
user_aaguids.add(aaguid_str)
# Check if this is the current session credential
is_current_session = current_credential_id == stored_cred.credential_id
is_current_session = s.credential_uuid == c.uuid
credentials.append(
{
"credential_id": stored_cred.credential_id.hex(),
"credential_uuid": str(c.uuid),
"aaguid": aaguid_str,
"created_at": stored_cred.created_at.isoformat(),
"last_used": stored_cred.last_used.isoformat()
if stored_cred.last_used
"created_at": c.created_at.isoformat(),
"last_used": c.last_used.isoformat() if c.last_used else None,
"last_verified": c.last_verified.isoformat()
if c.last_verified
else None,
"last_verified": stored_cred.last_verified.isoformat()
if stored_cred.last_verified
else None,
"sign_count": stored_cred.sign_count,
"sign_count": c.sign_count,
"is_current_session": is_current_session,
}
)
@ -82,13 +82,11 @@ def register_api_routes(app: FastAPI):
return {
"status": "success",
"user": {
"user_id": str(user.user_id),
"user_name": user.user_name,
"created_at": user.created_at.isoformat()
if user.created_at
else None,
"last_seen": user.last_seen.isoformat() if user.last_seen else None,
"visits": user.visits,
"user_uuid": str(u.user_uuid),
"user_name": u.user_name,
"created_at": u.created_at.isoformat() if u.created_at else None,
"last_seen": u.last_seen.isoformat() if u.last_seen else None,
"visits": u.visits,
},
"credentials": credentials,
"aaguid_info": aaguid_info,
@ -97,196 +95,44 @@ def register_api_routes(app: FastAPI):
return {"error": f"Failed to get user info: {str(e)}"}
@app.post("/auth/logout")
async def api_logout(request: Request, response: Response):
async def api_logout(response: Response, auth=Cookie(None)):
"""Log out the current user by clearing the session cookie and deleting from database."""
# Get the session token before clearing the cookie
session_token = get_session_token_from_cookie(request)
# Clear the cookie
clear_session_cookie(response)
# Delete the session from the database if it exists
if session_token:
from ..util.session import logout_session
try:
await logout_session(session_token)
except Exception:
# Continue even if session deletion fails
pass
if not auth:
return {"status": "success", "message": "Already logged out"}
await sql.delete_session(session_key(auth))
response.delete_cookie("auth")
return {"status": "success", "message": "Logged out successfully"}
@app.post("/auth/set-session")
async def api_set_session(request: Request, response: Response):
"""Set session cookie using JWT token from request body or Authorization header."""
async def api_set_session(
request: Request, response: Response, auth=Depends(bearer_auth)
):
"""Set session cookie from Authorization header. Fetched after login by WebSocket."""
try:
session_token = await get_session_token_from_bearer(request)
if not session_token:
return {"error": "No session token provided"}
# Validate the session token
token_data = await validate_session_token(session_token)
if not token_data:
return {"error": "Invalid or expired session token"}
# Set the HTTP-only cookie
set_session_cookie(response, session_token)
user = await session.get_session(auth.credentials)
if not user:
raise ValueError("Invalid Authorization header.")
session.set_session_cookie(response, auth.credentials)
return {
"status": "success",
"message": "Session cookie set successfully",
"user_id": str(token_data["user_id"]),
"user_uuid": str(user.user_uuid),
}
except ValueError as e:
return {"error": str(e)}
except Exception as e:
return {"error": f"Failed to set session: {str(e)}"}
@app.post("/auth/delete-credential")
async def api_delete_credential(request: Request):
@app.delete("/auth/credential/{uuid}")
async def api_delete_credential(uuid: UUID, auth: str = Cookie(None)):
"""Delete a specific credential for the current user."""
try:
user = await get_current_user(request)
if not user:
return {"error": "Not authenticated"}
# Get the credential ID from the request body
try:
body = await request.json()
credential_id = body.get("credential_id")
if not credential_id:
return {"error": "credential_id is required"}
except Exception:
return {"error": "Invalid request body"}
# Convert credential_id from hex string to bytes
try:
credential_id_bytes = bytes.fromhex(credential_id)
except ValueError:
return {"error": "Invalid credential_id format"}
# First, verify the credential belongs to the current user
try:
stored_cred = await sql.get_credential_by_id(credential_id_bytes)
if stored_cred.user_id != user.user_id:
return {"error": "Credential not found or access denied"}
except ValueError:
return {"error": "Credential not found"}
# Check if this is the current session credential
session_token = get_session_token_from_cookie(request)
if session_token:
token_data = await validate_session_token(session_token)
if (
token_data
and token_data.get("credential_id") == credential_id_bytes
):
return {"error": "Cannot delete current session credential"}
# Get user's remaining credentials count
remaining_credentials = await sql.get_user_credentials(user.user_id)
if len(remaining_credentials) <= 1:
return {"error": "Cannot delete last remaining credential"}
# Delete the credential
await sql.delete_user_credential(credential_id_bytes)
await session.delete_credential(uuid, auth)
return {"status": "success", "message": "Credential deleted successfully"}
except Exception as e:
return {"error": f"Failed to delete credential: {str(e)}"}
@app.get("/auth/sessions")
async def api_get_sessions(request: Request):
"""Get all active sessions for the current user."""
try:
user = await get_current_user(request)
if not user:
return {"error": "Authentication required"}
# Get all sessions for this user
from sqlalchemy import select
from ..db.sql import SessionModel, connect
async with connect() as db:
stmt = select(SessionModel).where(
SessionModel.user_id == user.user_id.bytes
)
result = await db.session.execute(stmt)
session_models = result.scalars().all()
sessions = []
current_token = get_session_token_from_cookie(request)
for session in session_models:
# Check if session is expired
from datetime import datetime, timedelta
expiry_time = session.created_at + timedelta(hours=24)
is_expired = datetime.now() > expiry_time
sessions.append(
{
"token": session.token[:8]
+ "...", # Only show first 8 chars for security
"created_at": session.created_at.isoformat(),
"client_ip": session.info.get("client_ip")
if session.info
else None,
"user_agent": session.info.get("user_agent")
if session.info
else None,
"connection_type": session.info.get(
"connection_type", "http"
)
if session.info
else "http",
"is_current": session.token == current_token,
"is_reset_token": session.credential_id is None,
"is_expired": is_expired,
}
)
return {
"status": "success",
"sessions": sessions,
"total_sessions": len(sessions),
}
except Exception as e:
return {"error": f"Failed to get sessions: {str(e)}"}
async def validate_token(request: Request, response: Response) -> dict:
"""Validate a session token and return user info. Also refreshes the token if valid."""
try:
session_token = get_session_token_from_cookie(request)
if not session_token:
return {"error": "No session token found"}
# Validate the session token
token_data = await validate_session_token(session_token)
if not token_data:
clear_session_cookie(response)
return {"error": "Invalid or expired session token"}
# Refresh the token if valid
new_token = await refresh_session_token(session_token)
if new_token:
set_session_cookie(response, new_token)
return {
"status": "success",
"valid": True,
"refreshed": bool(new_token),
"user_id": str(token_data["user_id"]),
"credential_id": token_data["credential_id"].hex()
if token_data["credential_id"]
else None,
"created_at": token_data["created_at"].isoformat(),
}
except Exception as e:
return {"error": f"Failed to validate token: {str(e)}"}
except ValueError as e:
return {"error": str(e)}
except Exception:
return {"error": "Failed to delete credential"}

View File

@ -9,15 +9,12 @@ This module provides a simple WebAuthn implementation that:
- Enables true passwordless authentication where users don't need to enter a user_name
"""
import contextlib
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import (
FastAPI,
Request,
Response,
)
from fastapi import Cookie, FastAPI, Request, Response
from fastapi.responses import (
FileResponse,
JSONResponse,
@ -25,12 +22,9 @@ from fastapi.responses import (
from fastapi.staticfiles import StaticFiles
from ..db import sql
from .api import (
register_api_routes,
validate_token,
)
from . import session, ws
from .api import register_api_routes
from .reset import register_reset_routes
from .ws import ws_app
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
@ -44,7 +38,7 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
# Mount the WebSocket subapp
app.mount("/auth/ws", ws_app)
app.mount("/auth/ws", ws.app)
# Register API routes
register_api_routes(app)
@ -52,13 +46,14 @@ register_reset_routes(app)
@app.get("/auth/forward-auth")
async def forward_authentication(request: Request):
"""A verification endpoint to use with Caddy forward_auth or Nginx auth_request."""
# Create a dummy response object for internal validation (we won't use it for cookies)
response = Response()
async def forward_authentication(request: Request, auth=Cookie(None)):
"""A validation endpoint to use with Caddy forward_auth or Nginx auth_request."""
with contextlib.suppress(ValueError):
s = await session.get_session(auth)
# If authenticated, return a success response
if s.info and s.info["type"] == "authenticated":
return Response(status_code=204, headers={"x-auth-user": str(s.user_uuid)})
result = await validate_token(request, response)
if result.get("status") != "success":
# Serve the index.html of the authentication app if not authenticated
return FileResponse(
STATIC_DIR / "index.html",
@ -66,12 +61,6 @@ async def forward_authentication(request: Request):
headers={"www-authenticate": "PrivateToken"},
)
# If authenticated, return a success response
return Response(
status_code=204,
headers={"x-auth-user-id": result["user_id"]},
)
# Serve static files
app.mount(

View File

@ -1,114 +1,74 @@
"""
Device addition API handlers for WebAuthn authentication.
import logging
This module provides endpoints for authenticated users to:
- Generate device addition links with human-readable tokens
- Validate device addition tokens
- Add new passkeys to existing accounts via tokens
"""
from uuid import UUID
from fastapi import FastAPI, Path, Request
from fastapi import Cookie, HTTPException, Request
from fastapi.responses import RedirectResponse
from ..db import sql
from ..util.passphrase import generate
from ..util.session import get_client_info
from .session import get_current_user, is_device_addition_session, set_session_cookie
from ..util import passphrase, tokens
from . import session
def register_reset_routes(app: FastAPI):
def register_reset_routes(app):
"""Register all device addition/reset routes on the FastAPI app."""
@app.post("/auth/create-device-link")
async def api_create_device_link(request: Request):
@app.post("/auth/create-link")
async def api_create_link(request: Request, auth=Cookie(None)):
"""Create a device addition link for the authenticated user."""
try:
# Require authentication
user = await get_current_user(request)
if not user:
return {"error": "Authentication required"}
s = await session.get_session(auth)
# Generate a human-readable token
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
# Create session token in database with credential_id=None for device addition
client_info = get_client_info(request)
await sql.create_session(user.user_id, None, token, client_info)
# Generate the device addition link with pretty URL
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
return {
"status": "success",
"message": "Device addition link generated successfully",
"addition_link": addition_link,
"expires_in_hours": 24,
}
except Exception as e:
return {"error": f"Failed to create device addition link: {str(e)}"}
@app.get("/auth/device-session-check")
async def check_device_session(request: Request):
"""Check if the current session is for device addition."""
is_device_session = await is_device_addition_session(request)
return {"device_addition_session": is_device_session}
@app.get("/auth/{passphrase}")
async def reset_authentication(
request: Request,
passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"),
):
try:
# Get session token to validate it exists and get user_id
session = await sql.get_session(passphrase)
if not session:
# Token doesn't exist, redirect to home
return RedirectResponse(url="/", status_code=303)
# Check if this is a device addition session (credential_id is None)
if session.credential_id is not None:
# Not a device addition session, redirect to home
return RedirectResponse(url="/", status_code=303)
# Create a device addition session token for the user
client_info = get_client_info(request)
session_token = await sql.create_session(
UUID(bytes=session.user_id), None, None, client_info
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
await sql.create_session(
user_uuid=s.user_uuid,
key=tokens.reset_key(token),
expires=session.expires(),
info=session.infodict(request, "device addition"),
)
# Create response and set session cookie
response = RedirectResponse(url="/auth/", status_code=303)
set_session_cookie(response, session_token)
return response
except Exception:
# On any error, redirect to home
return RedirectResponse(url="/", status_code=303)
async def use_reset_token(token: str) -> dict:
"""Delete a device addition token after successful use."""
try:
# Get session token first to validate it exists and is not expired
session = await sql.get_session(token)
if not session:
return {"error": "Invalid or expired device addition token"}
# Check if this is a device addition session (credential_id is None)
if session.credential_id is not None:
return {"error": "Invalid device addition token"}
# Delete the token (it's now used)
await sql.delete_session(token)
# Generate the device addition link with pretty URL
path = request.url.path.removesuffix("create-link") + token
url = f"{request.headers['origin']}{path}"
return {
"status": "success",
"message": "Device addition token used successfully",
"message": "Registration link generated successfully",
"url": url,
"expires": session.expires().isoformat(),
}
except ValueError:
return {"error": "Authentication required"}
except Exception as e:
return {"error": f"Failed to use device addition token: {str(e)}"}
return {"error": f"Failed to create registration link: {str(e)}"}
@app.get("/auth/{reset_token}")
async def reset_authentication(
request: Request,
reset_token: str,
):
"""Verifies the token and redirects to auth app for credential registration."""
# This route should only match to exact passphrases
print(f"Reset handler called with url: {request.url.path}")
if not passphrase.is_well_formed(reset_token):
raise HTTPException(status_code=404)
try:
# Get session token to validate it exists and get user_id
key = tokens.reset_key(reset_token)
sess = await sql.get_session(key)
if not sess:
raise ValueError("Invalid or expired registration token")
response = RedirectResponse(url="/auth/", status_code=303)
session.set_session_cookie(response, reset_token)
return response
except Exception as e:
# On any error, redirect to auth app
if isinstance(e, ValueError):
msg = str(e)
else:
logging.exception("Internal Server Error in reset_authentication")
msg = "Internal Server Error"
return RedirectResponse(url=f"/auth/#{msg}", status_code=303)

View File

@ -5,144 +5,85 @@ This module provides session management functionality including:
- Getting current user from session cookies
- Setting and clearing HTTP-only cookies
- Session validation and token handling
- Device addition token management
- Device addition route handlers
"""
from datetime import datetime, timedelta
from uuid import UUID
from fastapi import Request, Response
from ..db.sql import User, get_user_by_id
from ..util.session import validate_session_token
from ..db import Session, sql
from ..util import passphrase
from ..util.tokens import create_token, reset_key, session_key
COOKIE_NAME = "auth"
COOKIE_MAX_AGE = 86400 # 24 hours
EXPIRES = timedelta(hours=24)
async def get_current_user(request: Request) -> User | None:
"""Get the current user from the session cookie."""
session_token = request.cookies.get(COOKIE_NAME)
if not session_token:
return None
token_data = await validate_session_token(session_token)
if not token_data:
return None
try:
user = await get_user_by_id(token_data["user_id"])
return user
except Exception:
return None
def expires() -> datetime:
return datetime.now() + EXPIRES
def set_session_cookie(response: Response, session_token: str) -> None:
def infodict(request: Request, type: str) -> dict:
"""Extract client information from request."""
return {
"ip": request.client.host if request.client else "",
"user_agent": request.headers.get("user-agent", "")[:500],
"type": type,
}
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
"""Create a new session and return a session token."""
token = create_token()
await sql.create_session(
user_uuid=user_uuid,
key=session_key(token),
expires=datetime.now() + EXPIRES,
info=info,
credential_uuid=credential_uuid,
)
return token
async def get_session(token: str, reset_allowed=False) -> Session:
"""Validate a session token and return session data if valid."""
if passphrase.is_well_formed(token):
if not reset_allowed:
raise ValueError("Reset link is not allowed for this endpoint")
key = reset_key(token)
else:
key = session_key(token)
session = await sql.get_session(key)
if not session:
raise ValueError("Invalid or expired session token")
return session
async def refresh_session_token(token: str):
"""Refresh a session extending its expiry."""
# Get the current session
s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {})
if not s:
raise ValueError("Session not found or expired")
def set_session_cookie(response: Response, token: str) -> None:
"""Set the session token as an HTTP-only cookie."""
response.set_cookie(
key=COOKIE_NAME,
value=session_token,
max_age=COOKIE_MAX_AGE,
key="auth",
value=token,
max_age=int(EXPIRES.total_seconds()),
httponly=True,
secure=True,
samesite="lax",
path="/auth/",
)
def clear_session_cookie(response: Response) -> None:
"""Clear the session cookie."""
response.delete_cookie(key=COOKIE_NAME)
def get_session_token_from_cookie(request: Request) -> str | None:
"""Extract session token from request cookies."""
return request.cookies.get(COOKIE_NAME)
async def validate_session_from_request(request: Request) -> dict | None:
"""Validate session token from request and return token data."""
session_token = get_session_token_from_cookie(request)
if not session_token:
return None
return await validate_session_token(session_token)
async def get_session_token_from_bearer(request: Request) -> str | None:
"""Extract session token from Authorization header or request body."""
# Try to get token from Authorization header first
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.removeprefix("Bearer ")
async def get_user_from_cookie_string(cookie_header: str) -> UUID | None:
"""Parse cookie header and return user ID if valid session exists."""
if not cookie_header:
return None
# Parse cookies from header (simple implementation)
cookies = {}
for cookie in cookie_header.split(";"):
cookie = cookie.strip()
if "=" in cookie:
name, value = cookie.split("=", 1)
cookies[name] = value
session_token = cookies.get(COOKIE_NAME)
if not session_token:
return None
token_data = await validate_session_token(session_token)
if not token_data:
return None
return token_data["user_id"]
async def is_device_addition_session(request: Request) -> bool:
"""Check if the current session is for device addition."""
session_token = request.cookies.get(COOKIE_NAME)
if not session_token:
return False
token_data = await validate_session_token(session_token)
if not token_data:
return False
return token_data.get("device_addition", False)
async def get_device_addition_user_id(request: Request) -> UUID | None:
"""Get user ID from device addition session."""
session_token = request.cookies.get(COOKIE_NAME)
if not session_token:
return None
token_data = await validate_session_token(session_token)
if not token_data or not token_data.get("device_addition"):
return None
return token_data.get("user_id")
async def get_device_addition_user_id_from_cookie(cookie_header: str) -> UUID | None:
"""Parse cookie header and return user ID if valid device addition session exists."""
if not cookie_header:
return None
# Parse cookies from header (simple implementation)
cookies = {}
for cookie in cookie_header.split(";"):
cookie = cookie.strip()
if "=" in cookie:
name, value = cookie.split("=", 1)
cookies[name] = value
session_token = cookies.get(COOKIE_NAME)
if not session_token:
return None
token_data = await validate_session_token(session_token)
if not token_data or not token_data.get("device_addition"):
return None
return token_data["user_id"]
async def delete_credential(credential_uuid: UUID, auth: str):
"""Delete a specific credential for the current user."""
s = await get_session(auth)
await sql.delete_credential(credential_uuid, s.user_uuid)

View File

@ -13,17 +13,18 @@ from datetime import datetime
from uuid import UUID
import uuid7
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import Cookie, FastAPI, Query, Request, WebSocket, WebSocketDisconnect
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from ..db import sql
from ..db.sql import User
from passkey.fastapi import session
from ..db import User, sql
from ..sansio import Passkey
from ..util.session import create_session_token, get_client_info_from_websocket
from .session import get_user_from_cookie_string
from ..util.tokens import create_token, reset_key, session_key
from .session import create_session, infodict
# Create a FastAPI subapp for WebSocket endpoints
ws_app = FastAPI()
app = FastAPI()
# Initialize the passkey instance
passkey = Passkey(
@ -34,51 +35,55 @@ passkey = Passkey(
async def register_chat(
ws: WebSocket,
user_id: UUID,
user_uuid: UUID,
user_name: str,
credential_ids: list[bytes] | None = None,
origin: str | None = None,
):
"""Generate registration options and send them to the client."""
options, challenge = passkey.reg_generate_options(
user_id=user_id,
user_id=user_uuid,
user_name=user_name,
credential_ids=credential_ids,
origin=origin,
)
await ws.send_json(options)
response = await ws.receive_json()
return passkey.reg_verify(response, challenge, user_id, origin=origin)
return passkey.reg_verify(response, challenge, user_uuid, origin=origin)
@ws_app.websocket("/register_new")
async def websocket_register_new(ws: WebSocket, user_name: str):
@app.websocket("/register")
async def websocket_register_new(
request: Request, ws: WebSocket, user_name: str = Query(""), auth=Cookie(None)
):
"""Register a new user and with a new passkey credential."""
await ws.accept()
origin = ws.headers.get("origin")
try:
user_id = uuid7.create()
user_uuid = uuid7.create()
# WebAuthn registration
credential = await register_chat(ws, user_id, user_name, origin=origin)
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database
await sql.create_user_and_credential(
User(user_id, user_name, created_at=datetime.now()),
User(user_uuid, user_name, created_at=datetime.now()),
credential,
)
# Create a session token for the new user
client_info = get_client_info_from_websocket(ws)
session_token = await create_session_token(
user_id, credential.credential_id, client_info
token = create_token()
await sql.create_session(
user_uuid=user_uuid,
key=session_key(token),
expires=datetime.now() + session.EXPIRES,
info=infodict(request, "authenticated"),
credential_uuid=credential.uuid,
)
await ws.send_json(
{
"status": "success",
"user_id": str(user_id),
"session_token": session_token,
"user_uuid": str(user_uuid),
"session_token": token,
}
)
except ValueError as e:
@ -90,28 +95,31 @@ async def websocket_register_new(ws: WebSocket, user_name: str):
await ws.send_json({"error": "Internal Server Error"})
@ws_app.websocket("/add_credential")
async def websocket_register_add(ws: WebSocket):
@app.websocket("/add_credential")
async def websocket_register_add(ws: WebSocket, token: str | None = None):
"""Register a new credential for an existing user."""
await ws.accept()
origin = ws.headers.get("origin")
try:
# Authenticate user via cookie
cookie_header = ws.headers.get("cookie", "")
user_id = await get_user_from_cookie_string(cookie_header)
if not user_id:
await ws.send_json({"error": "Authentication required"})
if not token:
await ws.send_json({"error": "Token is required"})
return
# If a token is provided, use it to look up the session
key = reset_key(token)
s = await sql.get_session(key)
if not s:
await ws.send_json({"error": "Invalid or expired token"})
return
user_uuid = s.user_uuid
# Get user information to get the user_name
user = await sql.get_user_by_id(user_id)
user = await sql.get_user_by_uuid(user_uuid)
user_name = user.user_name
challenge_ids = await sql.get_user_credentials(user_id)
challenge_ids = await sql.get_user_credentials(user_uuid)
# WebAuthn registration
credential = await register_chat(
ws, user_id, user_name, challenge_ids, origin=origin
ws, user_uuid, user_name, challenge_ids, origin=origin
)
# Store the new credential in the database
await sql.create_credential_for_user(credential)
@ -119,7 +127,7 @@ async def websocket_register_add(ws: WebSocket):
await ws.send_json(
{
"status": "success",
"user_id": str(user_id),
"user_uuid": str(user_uuid),
"credential_id": credential.credential_id.hex(),
"message": "New credential added successfully",
}
@ -133,103 +141,8 @@ async def websocket_register_add(ws: WebSocket):
await ws.send_json({"error": "Internal Server Error"})
@ws_app.websocket("/add_device_credential")
async def websocket_add_device_credential(ws: WebSocket, token: str):
"""Add a new credential for an existing user via device addition token."""
await ws.accept()
origin = ws.headers.get("origin")
try:
reset_token = await sql.get_session(token)
if not reset_token:
await ws.send_json({"error": "Invalid or expired device addition token"})
return
# Get user information
user = await sql.get_user_by_id(reset_token.user_id)
# WebAuthn registration
# Fetch challenge IDs for the user
challenge_ids = await sql.get_user_credentials(reset_token.user_id)
credential = await register_chat(
ws, reset_token.user_id, user.user_name, challenge_ids, origin=origin
)
# Store the new credential in the database
await sql.create_credential_for_user(credential)
# Delete the device addition token (it's now used)
await sql.delete_reset_token(token)
await ws.send_json(
{
"status": "success",
"user_id": str(reset_token.user_id),
"credential_id": credential.credential_id.hex(),
"message": "New credential added successfully via device addition token",
}
)
except ValueError as e:
await ws.send_json({"error": str(e)})
except WebSocketDisconnect:
pass
except Exception:
logging.exception("Internal Server Error")
await ws.send_json({"error": "Internal Server Error"})
@ws_app.websocket("/add_device_credential_session")
async def websocket_add_device_credential_session(ws: WebSocket):
"""Add a new credential for an existing user via device addition session."""
await ws.accept()
origin = ws.headers.get("origin")
try:
# Get device addition user ID from session cookie
cookie_header = ws.headers.get("cookie", "")
from .session import get_device_addition_user_id_from_cookie
user_id = await get_device_addition_user_id_from_cookie(cookie_header)
if not user_id:
await ws.send_json({"error": "No valid device addition session found"})
return
# Get user information
user = await sql.get_user_by_id(user_id)
if not user:
await ws.send_json({"error": "User not found"})
return
# WebAuthn registration
# Fetch challenge IDs for the user
challenge_ids = await sql.get_user_credentials(user_id)
credential = await register_chat(
ws, user_id, user.user_name, challenge_ids, origin=origin
)
# Store the new credential in the database
await sql.create_credential_for_user(credential)
await ws.send_json(
{
"status": "success",
"user_id": str(user_id),
"credential_id": credential.credential_id.hex(),
"message": "New credential added successfully via device addition session",
}
)
except ValueError as e:
await ws.send_json({"error": str(e)})
except WebSocketDisconnect:
pass
except Exception:
logging.exception("Internal Server Error")
await ws.send_json({"error": "Internal Server Error"})
@ws_app.websocket("/authenticate")
async def websocket_authenticate(ws: WebSocket):
@app.websocket("/authenticate")
async def websocket_authenticate(request: Request, ws: WebSocket):
await ws.accept()
origin = ws.headers.get("origin")
try:
@ -242,19 +155,21 @@ async def websocket_authenticate(ws: WebSocket):
# Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp
await sql.login_user(stored_cred.user_id, stored_cred)
await sql.login_user(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user
client_info = get_client_info_from_websocket(ws)
session_token = await create_session_token(
stored_cred.user_id, stored_cred.credential_id, client_info
assert stored_cred.uuid is not None
token = await create_session(
user_uuid=stored_cred.user_uuid,
info=infodict(request, "auth"),
credential_uuid=stored_cred.uuid,
)
await ws.send_json(
{
"status": "success",
"user_id": str(stored_cred.user_id),
"session_token": session_token,
"user_uuid": str(stored_cred.user_uuid),
"session_token": token,
}
)
except (ValueError, InvalidAuthenticationResponse) as e:

View File

@ -8,7 +8,6 @@ This module provides a unified interface for WebAuthn operations including:
"""
import json
from dataclasses import dataclass
from datetime import datetime
from uuid import UUID
@ -36,21 +35,7 @@ from webauthn.helpers.structs import (
UserVerificationRequirement,
)
@dataclass
class StoredCredential:
"""Credential data stored in the database."""
# Fields set only at registration time
credential_id: bytes
user_id: UUID
aaguid: UUID
public_key: bytes
# Mutable fields that may be updated during authentication
sign_count: int
created_at: datetime
last_used: datetime | None = None
last_verified: datetime | None = None
from .db import Credential
class Passkey:
@ -129,7 +114,7 @@ class Passkey:
expected_challenge: bytes,
user_id: UUID,
origin: str | None = None,
) -> StoredCredential:
) -> Credential:
"""
Verify registration response.
@ -147,7 +132,7 @@ class Passkey:
expected_origin=origin or self.origin,
expected_rp_id=self.rp_id,
)
return StoredCredential(
return Credential(
credential_id=credential.raw_id,
user_id=user_id,
aaguid=UUID(registration.aaguid),
@ -195,7 +180,7 @@ class Passkey:
self,
credential: AuthenticationCredential,
expected_challenge: bytes,
stored_cred: StoredCredential,
stored_cred: Credential,
origin: str | None = None,
) -> VerifiedAuthentication:
"""

View File

@ -2,8 +2,18 @@ import secrets
from .wordlist import words
N_WORDS = 5
def generate(n=4, sep="."):
wset = set(words)
def generate(n=N_WORDS, sep="."):
"""Generate a password of random words without repeating any word."""
wl = list(words)
wl = words.copy()
return sep.join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
def is_well_formed(passphrase: str, n=N_WORDS, sep=".") -> bool:
"""Check if the passphrase is well-formed according to the regex pattern."""
p = passphrase.split(sep)
return len(p) == n and all(w in wset for w in passphrase.split("."))

View File

@ -1,88 +0,0 @@
"""
Database session management for WebAuthn authentication.
This module provides session management using database tokens instead of JWT tokens.
Session tokens are stored in the database and validated on each request.
"""
from datetime import datetime
from typing import Optional
from uuid import UUID
from fastapi import Request
from ..db import sql
def get_client_info(request: Request) -> dict:
"""Extract client information from FastAPI request and return as dict."""
# Get client IP (handle X-Forwarded-For for proxies)
# Get user agent
return {
"client_ip": request.client.host if request.client else "",
"user_agent": request.headers.get("user-agent", "")[:500],
}
def get_client_info_from_websocket(ws) -> dict:
"""Extract client information from WebSocket connection and return as dict."""
# Get client IP from WebSocket
client_ip = None
if hasattr(ws, "client") and ws.client:
client_ip = ws.client.host
# Check for forwarded headers
if hasattr(ws, "headers"):
forwarded_for = ws.headers.get("x-forwarded-for")
if forwarded_for:
client_ip = forwarded_for.split(",")[0].strip()
# Get user agent from WebSocket headers
user_agent = None
if hasattr(ws, "headers"):
user_agent = ws.headers.get("user-agent")
# Truncate user agent if too long
if user_agent and len(user_agent) > 500: # Keep some margin
user_agent = user_agent[:500]
return {
"client_ip": client_ip,
"user_agent": user_agent,
"timestamp": datetime.now().isoformat(),
"connection_type": "websocket",
}
async def create_session_token(
user_id: UUID, credential_id: bytes, info: dict | None = None
) -> str:
"""Create a session token for a user."""
return await sql.create_session_by_credential_id(user_id, credential_id, None, info)
async def validate_session_token(token: str) -> Optional[dict]:
"""Validate a session token."""
session_data = await sql.get_session(token)
if not session_data:
return None
return {
"user_id": session_data["user_id"],
"credential_id": session_data["credential_id"],
"created_at": session_data["created_at"],
}
async def refresh_session_token(token: str) -> Optional[str]:
"""Refresh a session token."""
return await sql.refresh_session(token)
async def delete_session_token(token: str) -> None:
"""Delete a session token."""
await sql.delete_session(token)
async def logout_session(token: str) -> None:
"""Log out a user by deleting their session token."""
await sql.delete_session(token)

17
passkey/util/tokens.py Normal file
View File

@ -0,0 +1,17 @@
import base64
import hashlib
import secrets
def create_token() -> str:
return secrets.token_urlsafe(12) # 16 characters Base64
def session_key(token: str) -> bytes:
if len(token) != 16:
raise ValueError("Session token must be exactly 16 characters long")
return b"sess" + base64.urlsafe_b64decode(token)
def reset_key(passphrase: str) -> bytes:
return b"rset" + hashlib.sha512(passphrase.encode()).digest()[:12]