Compare commits
No commits in common. "0cfa622bf1df35c012776dc9c3a2d2a93a533db5" and "dc0b0f461322494677fd16bae2afba256d37ef4f" have entirely different histories.
0cfa622bf1
...
dc0b0f4613
@ -298,21 +298,27 @@ class DB:
|
|||||||
|
|
||||||
return await self.create_session(user_id, db_credential_id, token, info)
|
return await self.create_session(user_id, db_credential_id, token, info)
|
||||||
|
|
||||||
async def get_session(self, token: str) -> SessionModel | None:
|
async def get_session(self, token: str) -> dict | None:
|
||||||
"""Get session by token string."""
|
"""Get session by token string."""
|
||||||
stmt = select(SessionModel).where(SessionModel.token == token)
|
stmt = select(SessionModel).where(SessionModel.token == token)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
session = result.scalar_one_or_none()
|
session_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
if session:
|
if session_model:
|
||||||
# Check if session is expired (24 hours)
|
# Check if session is expired (24 hours)
|
||||||
expiry_time = session.created_at + timedelta(hours=24)
|
expiry_time = session_model.created_at + timedelta(hours=24)
|
||||||
if datetime.now() > expiry_time:
|
if datetime.now() > expiry_time:
|
||||||
# Clean up expired session
|
# Clean up expired session
|
||||||
await self.delete_session(token)
|
await self.delete_session(token)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return session
|
return {
|
||||||
|
"token": session_model.token,
|
||||||
|
"user_id": UUID(bytes=session_model.user_id),
|
||||||
|
"credential_id": session_model.credential_id,
|
||||||
|
"created_at": session_model.created_at,
|
||||||
|
"info": session_model.info or {},
|
||||||
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def delete_session(self, token: str) -> None:
|
async def delete_session(self, token: str) -> None:
|
||||||
@ -328,8 +334,8 @@ class DB:
|
|||||||
|
|
||||||
async def refresh_session(self, token: str) -> str | None:
|
async def refresh_session(self, token: str) -> str | None:
|
||||||
"""Refresh a session by updating its created_at timestamp."""
|
"""Refresh a session by updating its created_at timestamp."""
|
||||||
session = await self.get_session(token)
|
session_data = await self.get_session(token)
|
||||||
if not session:
|
if not session_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Delete old session
|
# Delete old session
|
||||||
@ -337,9 +343,9 @@ class DB:
|
|||||||
|
|
||||||
# Create new session with same user and credential
|
# Create new session with same user and credential
|
||||||
return await self.create_session(
|
return await self.create_session(
|
||||||
user_id=UUID(bytes=session.user_id),
|
session_data["user_id"],
|
||||||
credential_id=session.credential_id,
|
session_data["credential_id"],
|
||||||
info=session.info,
|
info=session_data["info"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -432,7 +438,7 @@ async def create_session_by_credential_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_session(token: str) -> SessionModel | None:
|
async def get_session(token: str) -> dict | None:
|
||||||
"""Get session by token string."""
|
"""Get session by token string."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.get_session(token)
|
return await db.get_session(token)
|
||||||
|
@ -25,12 +25,12 @@ from fastapi.responses import (
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from ..db import sql
|
from ..db import sql
|
||||||
from .api import (
|
from .api_handlers import (
|
||||||
register_api_routes,
|
register_api_routes,
|
||||||
validate_token,
|
validate_token,
|
||||||
)
|
)
|
||||||
from .reset import register_reset_routes
|
from .reset_handlers import register_reset_routes
|
||||||
from .ws import ws_app
|
from .ws_handlers import ws_app
|
||||||
|
|
||||||
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
||||||
|
|
||||||
|
@ -7,8 +7,6 @@ This module provides endpoints for authenticated users to:
|
|||||||
- Add new passkeys to existing accounts via tokens
|
- Add new passkeys to existing accounts via tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Path, Request
|
from fastapi import FastAPI, Path, Request
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
@ -63,20 +61,20 @@ def register_reset_routes(app: FastAPI):
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Get session token to validate it exists and get user_id
|
# Get session token to validate it exists and get user_id
|
||||||
session = await sql.get_session(passphrase)
|
session_data = await sql.get_session(passphrase)
|
||||||
if not session:
|
if not session_data:
|
||||||
# Token doesn't exist, redirect to home
|
# Token doesn't exist, redirect to home
|
||||||
return RedirectResponse(url="/", status_code=303)
|
return RedirectResponse(url="/", status_code=303)
|
||||||
|
|
||||||
# Check if this is a device addition session (credential_id is None)
|
# Check if this is a device addition session (credential_id is None)
|
||||||
if session.credential_id is not None:
|
if session_data["credential_id"] is not None:
|
||||||
# Not a device addition session, redirect to home
|
# Not a device addition session, redirect to home
|
||||||
return RedirectResponse(url="/", status_code=303)
|
return RedirectResponse(url="/", status_code=303)
|
||||||
|
|
||||||
# Create a device addition session token for the user
|
# Create a device addition session token for the user
|
||||||
client_info = get_client_info(request)
|
client_info = get_client_info(request)
|
||||||
session_token = await sql.create_session(
|
session_token = await sql.create_session(
|
||||||
UUID(bytes=session.user_id), None, None, client_info
|
session_data["user_id"], None, None, client_info
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create response and set session cookie
|
# Create response and set session cookie
|
||||||
@ -90,16 +88,16 @@ def register_reset_routes(app: FastAPI):
|
|||||||
return RedirectResponse(url="/", status_code=303)
|
return RedirectResponse(url="/", status_code=303)
|
||||||
|
|
||||||
|
|
||||||
async def use_reset_token(token: str) -> dict:
|
async def use_device_addition_token(token: str) -> dict:
|
||||||
"""Delete a device addition token after successful use."""
|
"""Delete a device addition token after successful use."""
|
||||||
try:
|
try:
|
||||||
# Get session token first to validate it exists and is not expired
|
# Get session token first to validate it exists and is not expired
|
||||||
session = await sql.get_session(token)
|
session_data = await sql.get_session(token)
|
||||||
if not session:
|
if not session_data:
|
||||||
return {"error": "Invalid or expired device addition token"}
|
return {"error": "Invalid or expired device addition token"}
|
||||||
|
|
||||||
# Check if this is a device addition session (credential_id is None)
|
# Check if this is a device addition session (credential_id is None)
|
||||||
if session.credential_id is not None:
|
if session_data["credential_id"] is not None:
|
||||||
return {"error": "Invalid device addition token"}
|
return {"error": "Invalid device addition token"}
|
||||||
|
|
||||||
# Delete the token (it's now used)
|
# Delete the token (it's now used)
|
@ -9,7 +9,7 @@ This module contains all WebSocket endpoints for:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import uuid7
|
import uuid7
|
||||||
@ -139,11 +139,17 @@ async def websocket_add_device_credential(ws: WebSocket, token: str):
|
|||||||
await ws.accept()
|
await ws.accept()
|
||||||
origin = ws.headers.get("origin")
|
origin = ws.headers.get("origin")
|
||||||
try:
|
try:
|
||||||
reset_token = await sql.get_session(token)
|
reset_token = await sql.get_reset_token(token)
|
||||||
if not reset_token:
|
if not reset_token:
|
||||||
await ws.send_json({"error": "Invalid or expired device addition token"})
|
await ws.send_json({"error": "Invalid or expired device addition token"})
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check if token is expired (24 hours)
|
||||||
|
expiry_time = reset_token.created_at + timedelta(hours=24)
|
||||||
|
if datetime.now() > expiry_time:
|
||||||
|
await ws.send_json({"error": "Device addition token has expired"})
|
||||||
|
return
|
||||||
|
|
||||||
# Get user information
|
# Get user information
|
||||||
user = await sql.get_user_by_id(reset_token.user_id)
|
user = await sql.get_user_by_id(reset_token.user_id)
|
||||||
|
|
170
passkey/util/jwt.py
Normal file
170
passkey/util/jwt.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
JWT session management for WebAuthn authentication.
|
||||||
|
|
||||||
|
This module provides JWT token generation and validation for managing user sessions
|
||||||
|
after successful WebAuthn authentication. Tokens contain user ID and credential ID
|
||||||
|
for session validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
SECRET_FILE = Path("server-secret.bin")
|
||||||
|
|
||||||
|
|
||||||
|
def load_or_create_secret() -> bytes:
|
||||||
|
"""Load JWT secret from file or create a new one."""
|
||||||
|
if SECRET_FILE.exists():
|
||||||
|
return SECRET_FILE.read_bytes()
|
||||||
|
else:
|
||||||
|
# Generate a new 16-byte secret
|
||||||
|
secret = secrets.token_bytes(16)
|
||||||
|
SECRET_FILE.write_bytes(secret)
|
||||||
|
return secret
|
||||||
|
|
||||||
|
|
||||||
|
class JWTManager:
|
||||||
|
"""Manages JWT tokens for user sessions."""
|
||||||
|
|
||||||
|
def __init__(self, secret_key: bytes, algorithm: str = "HS256"):
|
||||||
|
self.secret_key = secret_key
|
||||||
|
self.algorithm = algorithm
|
||||||
|
self.token_expiry = timedelta(hours=24) # Tokens expire after 24 hours
|
||||||
|
|
||||||
|
def create_token(self, user_id: UUID, credential_id: bytes) -> str:
|
||||||
|
"""
|
||||||
|
Create a JWT token for a user session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID
|
||||||
|
credential_id: The credential ID used for authentication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT token string
|
||||||
|
"""
|
||||||
|
now = datetime.now()
|
||||||
|
payload = {
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"credential_id": credential_id.hex(),
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + self.token_expiry,
|
||||||
|
"iss": "passkeyauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||||
|
|
||||||
|
def create_token_without_credential(self, user_id: UUID) -> str:
|
||||||
|
"""
|
||||||
|
Create a JWT token for device addition (without credential ID).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT token string for device addition
|
||||||
|
"""
|
||||||
|
now = datetime.now()
|
||||||
|
payload = {
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"credential_id": None, # No credential for device addition
|
||||||
|
"device_addition": True, # Flag to indicate this is for device addition
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(hours=2), # Shorter expiry for device addition
|
||||||
|
"iss": "passkeyauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||||
|
|
||||||
|
def validate_token(self, token: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Validate a JWT token and return the payload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with user_id and credential_id, or None if invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
self.secret_key,
|
||||||
|
algorithms=[self.algorithm],
|
||||||
|
issuer="passkeyauth",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"user_id": UUID(payload["user_id"]),
|
||||||
|
"issued_at": payload["iat"],
|
||||||
|
"expires_at": payload["exp"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle credential_id for regular tokens vs device addition tokens
|
||||||
|
if payload.get("credential_id") is not None:
|
||||||
|
result["credential_id"] = bytes.fromhex(payload["credential_id"])
|
||||||
|
else:
|
||||||
|
result["credential_id"] = None
|
||||||
|
|
||||||
|
# Add device addition flag if present
|
||||||
|
if payload.get("device_addition"):
|
||||||
|
result["device_addition"] = True
|
||||||
|
|
||||||
|
return result
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
return None
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def refresh_token(self, token: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Refresh a JWT token if it's still valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Current JWT token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New JWT token string, or None if the current token is invalid
|
||||||
|
"""
|
||||||
|
payload = self.validate_token(token)
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.create_token(payload["user_id"], payload["credential_id"])
|
||||||
|
|
||||||
|
|
||||||
|
# Global JWT manager instance
|
||||||
|
_jwt_manager: JWTManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_jwt_manager() -> JWTManager:
|
||||||
|
"""Get the global JWT manager instance."""
|
||||||
|
global _jwt_manager
|
||||||
|
if _jwt_manager is None:
|
||||||
|
secret = load_or_create_secret()
|
||||||
|
_jwt_manager = JWTManager(secret)
|
||||||
|
return _jwt_manager # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def create_session_token(user_id: UUID, credential_id: bytes) -> str:
|
||||||
|
"""Create a session token for a user."""
|
||||||
|
return get_jwt_manager().create_token(user_id, credential_id)
|
||||||
|
|
||||||
|
|
||||||
|
def create_device_addition_token(user_id: UUID) -> str:
|
||||||
|
"""Create a token for device addition."""
|
||||||
|
return get_jwt_manager().create_token_without_credential(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_session_token(token: str) -> Optional[dict]:
|
||||||
|
"""Validate a session token."""
|
||||||
|
return get_jwt_manager().validate_token(token)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_session_token(token: str) -> Optional[str]:
|
||||||
|
"""Refresh a session token."""
|
||||||
|
return get_jwt_manager().refresh_token(token)
|
Loading…
x
Reference in New Issue
Block a user