From 0cfa622bf1df35c012776dc9c3a2d2a93a533db5 Mon Sep 17 00:00:00 2001 From: Leo Vasanko Date: Sun, 27 Jul 2025 23:44:26 -0600 Subject: [PATCH] Removal of JWT code, cleanup, using User dataclass rather than UserModel in APIs. --- passkey/db/sql.py | 28 +++---- passkey/fastapi/reset.py | 16 ++-- passkey/fastapi/ws.py | 10 +-- passkey/util/jwt.py | 170 --------------------------------------- 4 files changed, 22 insertions(+), 202 deletions(-) delete mode 100644 passkey/util/jwt.py diff --git a/passkey/db/sql.py b/passkey/db/sql.py index e493656..fc0d4b9 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -298,27 +298,21 @@ class DB: return await self.create_session(user_id, db_credential_id, token, info) - async def get_session(self, token: str) -> dict | None: + async def get_session(self, token: str) -> SessionModel | None: """Get session by token string.""" stmt = select(SessionModel).where(SessionModel.token == token) result = await self.session.execute(stmt) - session_model = result.scalar_one_or_none() + session = result.scalar_one_or_none() - if session_model: + if session: # Check if session is expired (24 hours) - expiry_time = session_model.created_at + timedelta(hours=24) + expiry_time = session.created_at + timedelta(hours=24) if datetime.now() > expiry_time: # Clean up expired session await self.delete_session(token) return None - return { - "token": session_model.token, - "user_id": UUID(bytes=session_model.user_id), - "credential_id": session_model.credential_id, - "created_at": session_model.created_at, - "info": session_model.info or {}, - } + return session return None async def delete_session(self, token: str) -> None: @@ -334,8 +328,8 @@ class DB: async def refresh_session(self, token: str) -> str | None: """Refresh a session by updating its created_at timestamp.""" - session_data = await self.get_session(token) - if not session_data: + session = await self.get_session(token) + if not session: return None # Delete old session @@ -343,9 +337,9 @@ class DB: # Create new session with same user and credential return await self.create_session( - session_data["user_id"], - session_data["credential_id"], - info=session_data["info"], + user_id=UUID(bytes=session.user_id), + credential_id=session.credential_id, + info=session.info, ) @@ -438,7 +432,7 @@ async def create_session_by_credential_id( ) -async def get_session(token: str) -> dict | None: +async def get_session(token: str) -> SessionModel | None: """Get session by token string.""" async with connect() as db: return await db.get_session(token) diff --git a/passkey/fastapi/reset.py b/passkey/fastapi/reset.py index ff4efbb..354d82f 100644 --- a/passkey/fastapi/reset.py +++ b/passkey/fastapi/reset.py @@ -7,6 +7,8 @@ This module provides endpoints for authenticated users to: - Add new passkeys to existing accounts via tokens """ +from uuid import UUID + from fastapi import FastAPI, Path, Request from fastapi.responses import RedirectResponse @@ -61,20 +63,20 @@ def register_reset_routes(app: FastAPI): ): try: # Get session token to validate it exists and get user_id - session_data = await sql.get_session(passphrase) - if not session_data: + session = await sql.get_session(passphrase) + if not session: # Token doesn't exist, redirect to home return RedirectResponse(url="/", status_code=303) # Check if this is a device addition session (credential_id is None) - if session_data["credential_id"] is not None: + if session.credential_id is not None: # Not a device addition session, redirect to home return RedirectResponse(url="/", status_code=303) # Create a device addition session token for the user client_info = get_client_info(request) session_token = await sql.create_session( - session_data["user_id"], None, None, client_info + UUID(bytes=session.user_id), None, None, client_info ) # Create response and set session cookie @@ -92,12 +94,12 @@ async def use_reset_token(token: str) -> dict: """Delete a device addition token after successful use.""" try: # Get session token first to validate it exists and is not expired - session_data = await sql.get_session(token) - if not session_data: + session = await sql.get_session(token) + if not session: return {"error": "Invalid or expired device addition token"} # Check if this is a device addition session (credential_id is None) - if session_data["credential_id"] is not None: + if session.credential_id is not None: return {"error": "Invalid device addition token"} # Delete the token (it's now used) diff --git a/passkey/fastapi/ws.py b/passkey/fastapi/ws.py index 7d5bc99..1e5c662 100644 --- a/passkey/fastapi/ws.py +++ b/passkey/fastapi/ws.py @@ -9,7 +9,7 @@ This module contains all WebSocket endpoints for: """ import logging -from datetime import datetime, timedelta +from datetime import datetime from uuid import UUID import uuid7 @@ -139,17 +139,11 @@ async def websocket_add_device_credential(ws: WebSocket, token: str): await ws.accept() origin = ws.headers.get("origin") try: - reset_token = await sql.get_reset_token(token) + reset_token = await sql.get_session(token) if not reset_token: await ws.send_json({"error": "Invalid or expired device addition token"}) return - # Check if token is expired (24 hours) - expiry_time = reset_token.created_at + timedelta(hours=24) - if datetime.now() > expiry_time: - await ws.send_json({"error": "Device addition token has expired"}) - return - # Get user information user = await sql.get_user_by_id(reset_token.user_id) diff --git a/passkey/util/jwt.py b/passkey/util/jwt.py deleted file mode 100644 index 4bfac25..0000000 --- a/passkey/util/jwt.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -JWT session management for WebAuthn authentication. - -This module provides JWT token generation and validation for managing user sessions -after successful WebAuthn authentication. Tokens contain user ID and credential ID -for session validation. -""" - -import secrets -from datetime import datetime, timedelta -from pathlib import Path -from typing import Optional -from uuid import UUID - -import jwt - -SECRET_FILE = Path("server-secret.bin") - - -def load_or_create_secret() -> bytes: - """Load JWT secret from file or create a new one.""" - if SECRET_FILE.exists(): - return SECRET_FILE.read_bytes() - else: - # Generate a new 16-byte secret - secret = secrets.token_bytes(16) - SECRET_FILE.write_bytes(secret) - return secret - - -class JWTManager: - """Manages JWT tokens for user sessions.""" - - def __init__(self, secret_key: bytes, algorithm: str = "HS256"): - self.secret_key = secret_key - self.algorithm = algorithm - self.token_expiry = timedelta(hours=24) # Tokens expire after 24 hours - - def create_token(self, user_id: UUID, credential_id: bytes) -> str: - """ - Create a JWT token for a user session. - - Args: - user_id: The user's UUID - credential_id: The credential ID used for authentication - - Returns: - JWT token string - """ - now = datetime.now() - payload = { - "user_id": str(user_id), - "credential_id": credential_id.hex(), - "iat": now, - "exp": now + self.token_expiry, - "iss": "passkeyauth", - } - - return jwt.encode(payload, self.secret_key, algorithm=self.algorithm) - - def create_token_without_credential(self, user_id: UUID) -> str: - """ - Create a JWT token for device addition (without credential ID). - - Args: - user_id: The user's UUID - - Returns: - JWT token string for device addition - """ - now = datetime.now() - payload = { - "user_id": str(user_id), - "credential_id": None, # No credential for device addition - "device_addition": True, # Flag to indicate this is for device addition - "iat": now, - "exp": now + timedelta(hours=2), # Shorter expiry for device addition - "iss": "passkeyauth", - } - - return jwt.encode(payload, self.secret_key, algorithm=self.algorithm) - - def validate_token(self, token: str) -> Optional[dict]: - """ - Validate a JWT token and return the payload. - - Args: - token: JWT token string - - Returns: - Dictionary with user_id and credential_id, or None if invalid - """ - try: - payload = jwt.decode( - token, - self.secret_key, - algorithms=[self.algorithm], - issuer="passkeyauth", - ) - - result = { - "user_id": UUID(payload["user_id"]), - "issued_at": payload["iat"], - "expires_at": payload["exp"], - } - - # Handle credential_id for regular tokens vs device addition tokens - if payload.get("credential_id") is not None: - result["credential_id"] = bytes.fromhex(payload["credential_id"]) - else: - result["credential_id"] = None - - # Add device addition flag if present - if payload.get("device_addition"): - result["device_addition"] = True - - return result - except jwt.ExpiredSignatureError: - return None - except jwt.InvalidTokenError: - return None - - def refresh_token(self, token: str) -> Optional[str]: - """ - Refresh a JWT token if it's still valid. - - Args: - token: Current JWT token - - Returns: - New JWT token string, or None if the current token is invalid - """ - payload = self.validate_token(token) - if payload is None: - return None - - return self.create_token(payload["user_id"], payload["credential_id"]) - - -# Global JWT manager instance -_jwt_manager: JWTManager | None = None - - -def get_jwt_manager() -> JWTManager: - """Get the global JWT manager instance.""" - global _jwt_manager - if _jwt_manager is None: - secret = load_or_create_secret() - _jwt_manager = JWTManager(secret) - return _jwt_manager # type: ignore - - -def create_session_token(user_id: UUID, credential_id: bytes) -> str: - """Create a session token for a user.""" - return get_jwt_manager().create_token(user_id, credential_id) - - -def create_device_addition_token(user_id: UUID) -> str: - """Create a token for device addition.""" - return get_jwt_manager().create_token_without_credential(user_id) - - -def validate_session_token(token: str) -> Optional[dict]: - """Validate a session token.""" - return get_jwt_manager().validate_token(token) - - -def refresh_session_token(token: str) -> Optional[str]: - """Refresh a session token.""" - return get_jwt_manager().refresh_token(token)