Removal of JWT code, cleanup, using User dataclass rather than UserModel in APIs.

This commit is contained in:
Leo Vasanko 2025-07-27 23:44:26 -06:00
parent 208419c2b1
commit 0cfa622bf1
4 changed files with 22 additions and 202 deletions

View File

@ -298,27 +298,21 @@ class DB:
return await self.create_session(user_id, db_credential_id, token, info) return await self.create_session(user_id, db_credential_id, token, info)
async def get_session(self, token: str) -> dict | None: async def get_session(self, token: str) -> SessionModel | None:
"""Get session by token string.""" """Get session by token string."""
stmt = select(SessionModel).where(SessionModel.token == token) stmt = select(SessionModel).where(SessionModel.token == token)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
session_model = result.scalar_one_or_none() session = result.scalar_one_or_none()
if session_model: if session:
# Check if session is expired (24 hours) # Check if session is expired (24 hours)
expiry_time = session_model.created_at + timedelta(hours=24) expiry_time = session.created_at + timedelta(hours=24)
if datetime.now() > expiry_time: if datetime.now() > expiry_time:
# Clean up expired session # Clean up expired session
await self.delete_session(token) await self.delete_session(token)
return None return None
return { return session
"token": session_model.token,
"user_id": UUID(bytes=session_model.user_id),
"credential_id": session_model.credential_id,
"created_at": session_model.created_at,
"info": session_model.info or {},
}
return None return None
async def delete_session(self, token: str) -> None: async def delete_session(self, token: str) -> None:
@ -334,8 +328,8 @@ class DB:
async def refresh_session(self, token: str) -> str | None: async def refresh_session(self, token: str) -> str | None:
"""Refresh a session by updating its created_at timestamp.""" """Refresh a session by updating its created_at timestamp."""
session_data = await self.get_session(token) session = await self.get_session(token)
if not session_data: if not session:
return None return None
# Delete old session # Delete old session
@ -343,9 +337,9 @@ class DB:
# Create new session with same user and credential # Create new session with same user and credential
return await self.create_session( return await self.create_session(
session_data["user_id"], user_id=UUID(bytes=session.user_id),
session_data["credential_id"], credential_id=session.credential_id,
info=session_data["info"], info=session.info,
) )
@ -438,7 +432,7 @@ async def create_session_by_credential_id(
) )
async def get_session(token: str) -> dict | None: async def get_session(token: str) -> SessionModel | None:
"""Get session by token string.""" """Get session by token string."""
async with connect() as db: async with connect() as db:
return await db.get_session(token) return await db.get_session(token)

View File

@ -7,6 +7,8 @@ This module provides endpoints for authenticated users to:
- Add new passkeys to existing accounts via tokens - Add new passkeys to existing accounts via tokens
""" """
from uuid import UUID
from fastapi import FastAPI, Path, Request from fastapi import FastAPI, Path, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
@ -61,20 +63,20 @@ def register_reset_routes(app: FastAPI):
): ):
try: try:
# Get session token to validate it exists and get user_id # Get session token to validate it exists and get user_id
session_data = await sql.get_session(passphrase) session = await sql.get_session(passphrase)
if not session_data: if not session:
# Token doesn't exist, redirect to home # Token doesn't exist, redirect to home
return RedirectResponse(url="/", status_code=303) return RedirectResponse(url="/", status_code=303)
# Check if this is a device addition session (credential_id is None) # Check if this is a device addition session (credential_id is None)
if session_data["credential_id"] is not None: if session.credential_id is not None:
# Not a device addition session, redirect to home # Not a device addition session, redirect to home
return RedirectResponse(url="/", status_code=303) return RedirectResponse(url="/", status_code=303)
# Create a device addition session token for the user # Create a device addition session token for the user
client_info = get_client_info(request) client_info = get_client_info(request)
session_token = await sql.create_session( session_token = await sql.create_session(
session_data["user_id"], None, None, client_info UUID(bytes=session.user_id), None, None, client_info
) )
# Create response and set session cookie # Create response and set session cookie
@ -92,12 +94,12 @@ async def use_reset_token(token: str) -> dict:
"""Delete a device addition token after successful use.""" """Delete a device addition token after successful use."""
try: try:
# Get session token first to validate it exists and is not expired # Get session token first to validate it exists and is not expired
session_data = await sql.get_session(token) session = await sql.get_session(token)
if not session_data: if not session:
return {"error": "Invalid or expired device addition token"} return {"error": "Invalid or expired device addition token"}
# Check if this is a device addition session (credential_id is None) # Check if this is a device addition session (credential_id is None)
if session_data["credential_id"] is not None: if session.credential_id is not None:
return {"error": "Invalid device addition token"} return {"error": "Invalid device addition token"}
# Delete the token (it's now used) # Delete the token (it's now used)

View File

@ -9,7 +9,7 @@ This module contains all WebSocket endpoints for:
""" """
import logging import logging
from datetime import datetime, timedelta from datetime import datetime
from uuid import UUID from uuid import UUID
import uuid7 import uuid7
@ -139,17 +139,11 @@ async def websocket_add_device_credential(ws: WebSocket, token: str):
await ws.accept() await ws.accept()
origin = ws.headers.get("origin") origin = ws.headers.get("origin")
try: try:
reset_token = await sql.get_reset_token(token) reset_token = await sql.get_session(token)
if not reset_token: if not reset_token:
await ws.send_json({"error": "Invalid or expired device addition token"}) await ws.send_json({"error": "Invalid or expired device addition token"})
return return
# Check if token is expired (24 hours)
expiry_time = reset_token.created_at + timedelta(hours=24)
if datetime.now() > expiry_time:
await ws.send_json({"error": "Device addition token has expired"})
return
# Get user information # Get user information
user = await sql.get_user_by_id(reset_token.user_id) user = await sql.get_user_by_id(reset_token.user_id)

View File

@ -1,170 +0,0 @@
"""
JWT session management for WebAuthn authentication.
This module provides JWT token generation and validation for managing user sessions
after successful WebAuthn authentication. Tokens contain user ID and credential ID
for session validation.
"""
import secrets
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from uuid import UUID
import jwt
SECRET_FILE = Path("server-secret.bin")
def load_or_create_secret() -> bytes:
"""Load JWT secret from file or create a new one."""
if SECRET_FILE.exists():
return SECRET_FILE.read_bytes()
else:
# Generate a new 16-byte secret
secret = secrets.token_bytes(16)
SECRET_FILE.write_bytes(secret)
return secret
class JWTManager:
"""Manages JWT tokens for user sessions."""
def __init__(self, secret_key: bytes, algorithm: str = "HS256"):
self.secret_key = secret_key
self.algorithm = algorithm
self.token_expiry = timedelta(hours=24) # Tokens expire after 24 hours
def create_token(self, user_id: UUID, credential_id: bytes) -> str:
"""
Create a JWT token for a user session.
Args:
user_id: The user's UUID
credential_id: The credential ID used for authentication
Returns:
JWT token string
"""
now = datetime.now()
payload = {
"user_id": str(user_id),
"credential_id": credential_id.hex(),
"iat": now,
"exp": now + self.token_expiry,
"iss": "passkeyauth",
}
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
def create_token_without_credential(self, user_id: UUID) -> str:
"""
Create a JWT token for device addition (without credential ID).
Args:
user_id: The user's UUID
Returns:
JWT token string for device addition
"""
now = datetime.now()
payload = {
"user_id": str(user_id),
"credential_id": None, # No credential for device addition
"device_addition": True, # Flag to indicate this is for device addition
"iat": now,
"exp": now + timedelta(hours=2), # Shorter expiry for device addition
"iss": "passkeyauth",
}
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
def validate_token(self, token: str) -> Optional[dict]:
"""
Validate a JWT token and return the payload.
Args:
token: JWT token string
Returns:
Dictionary with user_id and credential_id, or None if invalid
"""
try:
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
issuer="passkeyauth",
)
result = {
"user_id": UUID(payload["user_id"]),
"issued_at": payload["iat"],
"expires_at": payload["exp"],
}
# Handle credential_id for regular tokens vs device addition tokens
if payload.get("credential_id") is not None:
result["credential_id"] = bytes.fromhex(payload["credential_id"])
else:
result["credential_id"] = None
# Add device addition flag if present
if payload.get("device_addition"):
result["device_addition"] = True
return result
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None
def refresh_token(self, token: str) -> Optional[str]:
"""
Refresh a JWT token if it's still valid.
Args:
token: Current JWT token
Returns:
New JWT token string, or None if the current token is invalid
"""
payload = self.validate_token(token)
if payload is None:
return None
return self.create_token(payload["user_id"], payload["credential_id"])
# Global JWT manager instance
_jwt_manager: JWTManager | None = None
def get_jwt_manager() -> JWTManager:
"""Get the global JWT manager instance."""
global _jwt_manager
if _jwt_manager is None:
secret = load_or_create_secret()
_jwt_manager = JWTManager(secret)
return _jwt_manager # type: ignore
def create_session_token(user_id: UUID, credential_id: bytes) -> str:
"""Create a session token for a user."""
return get_jwt_manager().create_token(user_id, credential_id)
def create_device_addition_token(user_id: UUID) -> str:
"""Create a token for device addition."""
return get_jwt_manager().create_token_without_credential(user_id)
def validate_session_token(token: str) -> Optional[dict]:
"""Validate a session token."""
return get_jwt_manager().validate_token(token)
def refresh_session_token(token: str) -> Optional[str]:
"""Refresh a session token."""
return get_jwt_manager().refresh_token(token)