Compare commits

..

No commits in common. "0cfa622bf1df35c012776dc9c3a2d2a93a533db5" and "dc0b0f461322494677fd16bae2afba256d37ef4f" have entirely different histories.

6 changed files with 206 additions and 26 deletions

View File

@ -298,21 +298,27 @@ class DB:
return await self.create_session(user_id, db_credential_id, token, info) return await self.create_session(user_id, db_credential_id, token, info)
async def get_session(self, token: str) -> SessionModel | None: async def get_session(self, token: str) -> dict | None:
"""Get session by token string.""" """Get session by token string."""
stmt = select(SessionModel).where(SessionModel.token == token) stmt = select(SessionModel).where(SessionModel.token == token)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
session = result.scalar_one_or_none() session_model = result.scalar_one_or_none()
if session: if session_model:
# Check if session is expired (24 hours) # Check if session is expired (24 hours)
expiry_time = session.created_at + timedelta(hours=24) expiry_time = session_model.created_at + timedelta(hours=24)
if datetime.now() > expiry_time: if datetime.now() > expiry_time:
# Clean up expired session # Clean up expired session
await self.delete_session(token) await self.delete_session(token)
return None return None
return session return {
"token": session_model.token,
"user_id": UUID(bytes=session_model.user_id),
"credential_id": session_model.credential_id,
"created_at": session_model.created_at,
"info": session_model.info or {},
}
return None return None
async def delete_session(self, token: str) -> None: async def delete_session(self, token: str) -> None:
@ -328,8 +334,8 @@ class DB:
async def refresh_session(self, token: str) -> str | None: async def refresh_session(self, token: str) -> str | None:
"""Refresh a session by updating its created_at timestamp.""" """Refresh a session by updating its created_at timestamp."""
session = await self.get_session(token) session_data = await self.get_session(token)
if not session: if not session_data:
return None return None
# Delete old session # Delete old session
@ -337,9 +343,9 @@ class DB:
# Create new session with same user and credential # Create new session with same user and credential
return await self.create_session( return await self.create_session(
user_id=UUID(bytes=session.user_id), session_data["user_id"],
credential_id=session.credential_id, session_data["credential_id"],
info=session.info, info=session_data["info"],
) )
@ -432,7 +438,7 @@ async def create_session_by_credential_id(
) )
async def get_session(token: str) -> SessionModel | None: async def get_session(token: str) -> dict | None:
"""Get session by token string.""" """Get session by token string."""
async with connect() as db: async with connect() as db:
return await db.get_session(token) return await db.get_session(token)

View File

@ -25,12 +25,12 @@ from fastapi.responses import (
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from ..db import sql from ..db import sql
from .api import ( from .api_handlers import (
register_api_routes, register_api_routes,
validate_token, validate_token,
) )
from .reset import register_reset_routes from .reset_handlers import register_reset_routes
from .ws import ws_app from .ws_handlers import ws_app
STATIC_DIR = Path(__file__).parent.parent / "frontend-build" STATIC_DIR = Path(__file__).parent.parent / "frontend-build"

View File

@ -7,8 +7,6 @@ This module provides endpoints for authenticated users to:
- Add new passkeys to existing accounts via tokens - Add new passkeys to existing accounts via tokens
""" """
from uuid import UUID
from fastapi import FastAPI, Path, Request from fastapi import FastAPI, Path, Request
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
@ -63,20 +61,20 @@ def register_reset_routes(app: FastAPI):
): ):
try: try:
# Get session token to validate it exists and get user_id # Get session token to validate it exists and get user_id
session = await sql.get_session(passphrase) session_data = await sql.get_session(passphrase)
if not session: if not session_data:
# Token doesn't exist, redirect to home # Token doesn't exist, redirect to home
return RedirectResponse(url="/", status_code=303) return RedirectResponse(url="/", status_code=303)
# Check if this is a device addition session (credential_id is None) # Check if this is a device addition session (credential_id is None)
if session.credential_id is not None: if session_data["credential_id"] is not None:
# Not a device addition session, redirect to home # Not a device addition session, redirect to home
return RedirectResponse(url="/", status_code=303) return RedirectResponse(url="/", status_code=303)
# Create a device addition session token for the user # Create a device addition session token for the user
client_info = get_client_info(request) client_info = get_client_info(request)
session_token = await sql.create_session( session_token = await sql.create_session(
UUID(bytes=session.user_id), None, None, client_info session_data["user_id"], None, None, client_info
) )
# Create response and set session cookie # Create response and set session cookie
@ -90,16 +88,16 @@ def register_reset_routes(app: FastAPI):
return RedirectResponse(url="/", status_code=303) return RedirectResponse(url="/", status_code=303)
async def use_reset_token(token: str) -> dict: async def use_device_addition_token(token: str) -> dict:
"""Delete a device addition token after successful use.""" """Delete a device addition token after successful use."""
try: try:
# Get session token first to validate it exists and is not expired # Get session token first to validate it exists and is not expired
session = await sql.get_session(token) session_data = await sql.get_session(token)
if not session: if not session_data:
return {"error": "Invalid or expired device addition token"} return {"error": "Invalid or expired device addition token"}
# Check if this is a device addition session (credential_id is None) # Check if this is a device addition session (credential_id is None)
if session.credential_id is not None: if session_data["credential_id"] is not None:
return {"error": "Invalid device addition token"} return {"error": "Invalid device addition token"}
# Delete the token (it's now used) # Delete the token (it's now used)

View File

@ -9,7 +9,7 @@ This module contains all WebSocket endpoints for:
""" """
import logging import logging
from datetime import datetime from datetime import datetime, timedelta
from uuid import UUID from uuid import UUID
import uuid7 import uuid7
@ -139,11 +139,17 @@ async def websocket_add_device_credential(ws: WebSocket, token: str):
await ws.accept() await ws.accept()
origin = ws.headers.get("origin") origin = ws.headers.get("origin")
try: try:
reset_token = await sql.get_session(token) reset_token = await sql.get_reset_token(token)
if not reset_token: if not reset_token:
await ws.send_json({"error": "Invalid or expired device addition token"}) await ws.send_json({"error": "Invalid or expired device addition token"})
return return
# Check if token is expired (24 hours)
expiry_time = reset_token.created_at + timedelta(hours=24)
if datetime.now() > expiry_time:
await ws.send_json({"error": "Device addition token has expired"})
return
# Get user information # Get user information
user = await sql.get_user_by_id(reset_token.user_id) user = await sql.get_user_by_id(reset_token.user_id)

170
passkey/util/jwt.py Normal file
View File

@ -0,0 +1,170 @@
"""
JWT session management for WebAuthn authentication.
This module provides JWT token generation and validation for managing user sessions
after successful WebAuthn authentication. Tokens contain user ID and credential ID
for session validation.
"""
import secrets
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from uuid import UUID
import jwt
SECRET_FILE = Path("server-secret.bin")
def load_or_create_secret() -> bytes:
"""Load JWT secret from file or create a new one."""
if SECRET_FILE.exists():
return SECRET_FILE.read_bytes()
else:
# Generate a new 16-byte secret
secret = secrets.token_bytes(16)
SECRET_FILE.write_bytes(secret)
return secret
class JWTManager:
"""Manages JWT tokens for user sessions."""
def __init__(self, secret_key: bytes, algorithm: str = "HS256"):
self.secret_key = secret_key
self.algorithm = algorithm
self.token_expiry = timedelta(hours=24) # Tokens expire after 24 hours
def create_token(self, user_id: UUID, credential_id: bytes) -> str:
"""
Create a JWT token for a user session.
Args:
user_id: The user's UUID
credential_id: The credential ID used for authentication
Returns:
JWT token string
"""
now = datetime.now()
payload = {
"user_id": str(user_id),
"credential_id": credential_id.hex(),
"iat": now,
"exp": now + self.token_expiry,
"iss": "passkeyauth",
}
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
def create_token_without_credential(self, user_id: UUID) -> str:
"""
Create a JWT token for device addition (without credential ID).
Args:
user_id: The user's UUID
Returns:
JWT token string for device addition
"""
now = datetime.now()
payload = {
"user_id": str(user_id),
"credential_id": None, # No credential for device addition
"device_addition": True, # Flag to indicate this is for device addition
"iat": now,
"exp": now + timedelta(hours=2), # Shorter expiry for device addition
"iss": "passkeyauth",
}
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
def validate_token(self, token: str) -> Optional[dict]:
"""
Validate a JWT token and return the payload.
Args:
token: JWT token string
Returns:
Dictionary with user_id and credential_id, or None if invalid
"""
try:
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
issuer="passkeyauth",
)
result = {
"user_id": UUID(payload["user_id"]),
"issued_at": payload["iat"],
"expires_at": payload["exp"],
}
# Handle credential_id for regular tokens vs device addition tokens
if payload.get("credential_id") is not None:
result["credential_id"] = bytes.fromhex(payload["credential_id"])
else:
result["credential_id"] = None
# Add device addition flag if present
if payload.get("device_addition"):
result["device_addition"] = True
return result
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None
def refresh_token(self, token: str) -> Optional[str]:
"""
Refresh a JWT token if it's still valid.
Args:
token: Current JWT token
Returns:
New JWT token string, or None if the current token is invalid
"""
payload = self.validate_token(token)
if payload is None:
return None
return self.create_token(payload["user_id"], payload["credential_id"])
# Global JWT manager instance
_jwt_manager: JWTManager | None = None
def get_jwt_manager() -> JWTManager:
"""Get the global JWT manager instance."""
global _jwt_manager
if _jwt_manager is None:
secret = load_or_create_secret()
_jwt_manager = JWTManager(secret)
return _jwt_manager # type: ignore
def create_session_token(user_id: UUID, credential_id: bytes) -> str:
"""Create a session token for a user."""
return get_jwt_manager().create_token(user_id, credential_id)
def create_device_addition_token(user_id: UUID) -> str:
"""Create a token for device addition."""
return get_jwt_manager().create_token_without_credential(user_id)
def validate_session_token(token: str) -> Optional[dict]:
"""Validate a session token."""
return get_jwt_manager().validate_token(token)
def refresh_session_token(token: str) -> Optional[str]:
"""Refresh a session token."""
return get_jwt_manager().refresh_token(token)