Major cleanup, refactoring, device registrations.
This commit is contained in:
@@ -16,8 +16,8 @@ from .jwt_manager import refresh_session_token, validate_session_token
|
||||
from .session_manager import (
|
||||
clear_session_cookie,
|
||||
get_current_user,
|
||||
get_session_token_from_auth_header_or_body,
|
||||
get_session_token_from_request,
|
||||
get_session_token_from_bearer,
|
||||
get_session_token_from_cookie,
|
||||
set_session_cookie,
|
||||
)
|
||||
|
||||
@@ -52,7 +52,7 @@ async def get_user_credentials(request: Request) -> dict:
|
||||
|
||||
# Get current session credential ID
|
||||
current_credential_id = None
|
||||
session_token = get_session_token_from_request(request)
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if session_token:
|
||||
token_data = validate_session_token(session_token)
|
||||
if token_data:
|
||||
@@ -65,34 +65,30 @@ async def get_user_credentials(request: Request) -> dict:
|
||||
user_aaguids = set()
|
||||
|
||||
for cred_id in credential_ids:
|
||||
try:
|
||||
stored_cred = await db.get_credential_by_id(cred_id)
|
||||
stored_cred = await db.get_credential_by_id(cred_id)
|
||||
|
||||
# Convert AAGUID to string format
|
||||
aaguid_str = str(stored_cred.aaguid)
|
||||
user_aaguids.add(aaguid_str)
|
||||
# Convert AAGUID to string format
|
||||
aaguid_str = str(stored_cred.aaguid)
|
||||
user_aaguids.add(aaguid_str)
|
||||
|
||||
# Check if this is the current session credential
|
||||
is_current_session = current_credential_id == stored_cred.credential_id
|
||||
# Check if this is the current session credential
|
||||
is_current_session = current_credential_id == stored_cred.credential_id
|
||||
|
||||
credentials.append(
|
||||
{
|
||||
"credential_id": stored_cred.credential_id.hex(),
|
||||
"aaguid": aaguid_str,
|
||||
"created_at": stored_cred.created_at.isoformat(),
|
||||
"last_used": stored_cred.last_used.isoformat()
|
||||
if stored_cred.last_used
|
||||
else None,
|
||||
"last_verified": stored_cred.last_verified.isoformat()
|
||||
if stored_cred.last_verified
|
||||
else None,
|
||||
"sign_count": stored_cred.sign_count,
|
||||
"is_current_session": is_current_session,
|
||||
}
|
||||
)
|
||||
except ValueError:
|
||||
# Skip invalid credentials
|
||||
continue
|
||||
credentials.append(
|
||||
{
|
||||
"credential_id": stored_cred.credential_id.hex(),
|
||||
"aaguid": aaguid_str,
|
||||
"created_at": stored_cred.created_at.isoformat(),
|
||||
"last_used": stored_cred.last_used.isoformat()
|
||||
if stored_cred.last_used
|
||||
else None,
|
||||
"last_verified": stored_cred.last_verified.isoformat()
|
||||
if stored_cred.last_verified
|
||||
else None,
|
||||
"sign_count": stored_cred.sign_count,
|
||||
"is_current_session": is_current_session,
|
||||
}
|
||||
)
|
||||
|
||||
# Get AAGUID information for only the AAGUIDs that the user has
|
||||
aaguid_manager = get_aaguid_manager()
|
||||
@@ -113,7 +109,7 @@ async def get_user_credentials(request: Request) -> dict:
|
||||
async def refresh_token(request: Request, response: Response) -> dict:
|
||||
"""Refresh the session token."""
|
||||
try:
|
||||
session_token = get_session_token_from_request(request)
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if not session_token:
|
||||
return {"error": "No session token found"}
|
||||
|
||||
@@ -134,7 +130,7 @@ async def refresh_token(request: Request, response: Response) -> dict:
|
||||
async def validate_token(request: Request) -> dict:
|
||||
"""Validate a session token and return user info."""
|
||||
try:
|
||||
session_token = get_session_token_from_request(request)
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if not session_token:
|
||||
return {"error": "No session token found"}
|
||||
|
||||
@@ -165,7 +161,7 @@ async def logout(response: Response) -> dict:
|
||||
async def set_session(request: Request, response: Response) -> dict:
|
||||
"""Set session cookie using JWT token from request body or Authorization header."""
|
||||
try:
|
||||
session_token = await get_session_token_from_auth_header_or_body(request)
|
||||
session_token = await get_session_token_from_bearer(request)
|
||||
|
||||
if not session_token:
|
||||
return {"error": "No session token provided"}
|
||||
@@ -219,7 +215,7 @@ async def delete_credential(request: Request) -> dict:
|
||||
return {"error": "Credential not found"}
|
||||
|
||||
# Check if this is the current session credential
|
||||
session_token = get_session_token_from_request(request)
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if session_token:
|
||||
token_data = validate_session_token(session_token)
|
||||
if token_data and token_data.get("credential_id") == credential_id_bytes:
|
||||
|
||||
@@ -9,14 +9,19 @@ This module provides a simple WebAuthn implementation that:
|
||||
- Enables true passwordless authentication where users don't need to enter a user_name
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import (
|
||||
Path as FastAPIPath,
|
||||
)
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
||||
|
||||
from . import db
|
||||
from .api_handlers import (
|
||||
@@ -40,7 +45,7 @@ STATIC_DIR = Path(__file__).parent.parent / "static"
|
||||
passkey = Passkey(
|
||||
rp_id="localhost",
|
||||
rp_name="Passkey Auth",
|
||||
origin="http://localhost:8000",
|
||||
origin="http://localhost:3000",
|
||||
)
|
||||
|
||||
|
||||
@@ -53,15 +58,12 @@ async def lifespan(app: FastAPI):
|
||||
app = FastAPI(title="Passkey Auth", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.websocket("/ws/new_user_registration")
|
||||
async def websocket_register_new(ws: WebSocket):
|
||||
@app.websocket("/auth/ws/register_new")
|
||||
async def websocket_register_new(ws: WebSocket, user_name: str):
|
||||
"""Register a new user and with a new passkey credential."""
|
||||
await ws.accept()
|
||||
try:
|
||||
# Data for the new user account
|
||||
form = await ws.receive_json()
|
||||
user_id = uuid4()
|
||||
user_name = form["user_name"]
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(ws, user_id, user_name)
|
||||
@@ -86,9 +88,12 @@ async def websocket_register_new(ws: WebSocket):
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
@app.websocket("/ws/add_credential")
|
||||
@app.websocket("/auth/ws/add_credential")
|
||||
async def websocket_register_add(ws: WebSocket):
|
||||
"""Register a new credential for an existing user."""
|
||||
await ws.accept()
|
||||
@@ -108,7 +113,6 @@ async def websocket_register_add(ws: WebSocket):
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(ws, user_id, user_name, challenge_ids)
|
||||
print(f"New credential for user {user_id}: {credential}")
|
||||
# Store the new credential in the database
|
||||
await db.create_credential_for_user(credential)
|
||||
|
||||
@@ -124,24 +128,16 @@ async def websocket_register_add(ws: WebSocket):
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
await ws.send_json({"error": f"Server error: {str(e)}"})
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
@app.websocket("/ws/add_device_credential")
|
||||
async def websocket_add_device_credential(ws: WebSocket):
|
||||
@app.websocket("/auth/ws/add_device_credential")
|
||||
async def websocket_add_device_credential(ws: WebSocket, token: str):
|
||||
"""Add a new credential for an existing user via device addition token."""
|
||||
await ws.accept()
|
||||
try:
|
||||
# Get device addition token from client
|
||||
message = await ws.receive_json()
|
||||
token = message.get("token")
|
||||
|
||||
if not token:
|
||||
await ws.send_json({"error": "Device addition token is required"})
|
||||
return
|
||||
|
||||
# Validate device addition token
|
||||
reset_token = await db.get_reset_token(token)
|
||||
if not reset_token:
|
||||
await ws.send_json({"error": "Invalid or expired device addition token"})
|
||||
@@ -182,8 +178,9 @@ async def websocket_add_device_credential(ws: WebSocket):
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
await ws.send_json({"error": f"Server error: {str(e)}"})
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
async def register_chat(
|
||||
@@ -200,11 +197,10 @@ async def register_chat(
|
||||
)
|
||||
await ws.send_json(options)
|
||||
response = await ws.receive_json()
|
||||
print(response)
|
||||
return passkey.reg_verify(response, challenge, user_id)
|
||||
|
||||
|
||||
@app.websocket("/ws/authenticate")
|
||||
@app.websocket("/auth/ws/authenticate")
|
||||
async def websocket_authenticate(ws: WebSocket):
|
||||
await ws.accept()
|
||||
try:
|
||||
@@ -231,114 +227,109 @@ async def websocket_authenticate(ws: WebSocket):
|
||||
"session_token": session_token,
|
||||
}
|
||||
)
|
||||
except ValueError as e:
|
||||
except (ValueError, InvalidAuthenticationResponse) as e:
|
||||
logging.exception("ValueError")
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
@app.get("/api/user-info")
|
||||
@app.get("/auth/user-info")
|
||||
async def api_get_user_info(request: Request):
|
||||
"""Get user information from session cookie."""
|
||||
return await get_user_info(request)
|
||||
|
||||
|
||||
@app.get("/api/user-credentials")
|
||||
@app.get("/auth/user-credentials")
|
||||
async def api_get_user_credentials(request: Request):
|
||||
"""Get all credentials for a user using session cookie."""
|
||||
return await get_user_credentials(request)
|
||||
|
||||
|
||||
@app.post("/api/refresh-token")
|
||||
@app.post("/auth/refresh-token")
|
||||
async def api_refresh_token(request: Request, response: Response):
|
||||
"""Refresh the session token."""
|
||||
return await refresh_token(request, response)
|
||||
|
||||
|
||||
@app.get("/api/validate-token")
|
||||
@app.get("/auth/validate-token")
|
||||
async def api_validate_token(request: Request):
|
||||
"""Validate a session token and return user info."""
|
||||
return await validate_token(request)
|
||||
|
||||
|
||||
@app.post("/api/logout")
|
||||
@app.post("/auth/logout")
|
||||
async def api_logout(response: Response):
|
||||
"""Log out the current user by clearing the session cookie."""
|
||||
return await logout(response)
|
||||
|
||||
|
||||
@app.post("/api/set-session")
|
||||
@app.post("/auth/set-session")
|
||||
async def api_set_session(request: Request, response: Response):
|
||||
"""Set session cookie using JWT token from request body or Authorization header."""
|
||||
return await set_session(request, response)
|
||||
|
||||
|
||||
@app.post("/api/delete-credential")
|
||||
@app.post("/auth/delete-credential")
|
||||
async def api_delete_credential(request: Request):
|
||||
"""Delete a specific credential for the current user."""
|
||||
return await delete_credential(request)
|
||||
|
||||
|
||||
@app.post("/api/create-device-link")
|
||||
@app.post("/auth/create-device-link")
|
||||
async def api_create_device_link(request: Request):
|
||||
"""Create a device addition link for the authenticated user."""
|
||||
return await create_device_addition_link(request)
|
||||
|
||||
|
||||
@app.post("/api/validate-device-token")
|
||||
@app.post("/auth/validate-device-token")
|
||||
async def api_validate_device_token(request: Request):
|
||||
"""Validate a device addition token."""
|
||||
return await validate_device_addition_token(request)
|
||||
|
||||
|
||||
@app.get("/auth/{passphrase}")
|
||||
async def reset_authentication(
|
||||
passphrase: str = FastAPIPath(pattern=r"^\w+(\.\w+){2,}$"),
|
||||
):
|
||||
response = RedirectResponse(url="/", status_code=303)
|
||||
response.set_cookie(
|
||||
key="auth-token",
|
||||
value=passphrase,
|
||||
httponly=False,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
max_age=2,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/auth/user-info-by-passphrase")
|
||||
async def api_get_user_info_by_passphrase(token: str):
|
||||
"""Get user information using the passphrase."""
|
||||
reset_token = await db.get_reset_token(token)
|
||||
if not reset_token:
|
||||
return Response(content="Invalid or expired passphrase", status_code=403)
|
||||
|
||||
user = await db.get_user_by_id(reset_token.user_id)
|
||||
if not user:
|
||||
return Response(content="User not found", status_code=404)
|
||||
|
||||
return {"user_name": user.user_name}
|
||||
|
||||
|
||||
# Serve static files
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_index():
|
||||
"""Redirect to login page"""
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url="/auth/login", status_code=302)
|
||||
|
||||
|
||||
@app.get("/auth/login")
|
||||
async def get_login_page():
|
||||
"""Serve the login page"""
|
||||
return FileResponse(STATIC_DIR / "login.html")
|
||||
|
||||
|
||||
@app.get("/auth/register")
|
||||
async def get_register_page():
|
||||
"""Serve the register page"""
|
||||
return FileResponse(STATIC_DIR / "register.html")
|
||||
|
||||
|
||||
@app.get("/auth/dashboard")
|
||||
async def get_dashboard_page():
|
||||
"""Redirect to profile (dashboard is now profile)"""
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
return RedirectResponse(url="/auth/profile", status_code=302)
|
||||
|
||||
|
||||
@app.get("/auth/profile")
|
||||
async def get_profile_page():
|
||||
"""Serve the profile page"""
|
||||
return FileResponse(STATIC_DIR / "profile.html")
|
||||
|
||||
|
||||
@app.get("/auth/reset")
|
||||
async def get_reset_page_without_token():
|
||||
"""Serve the reset page without a token"""
|
||||
return FileResponse(STATIC_DIR / "reset.html")
|
||||
|
||||
|
||||
@app.get("/reset/{token}")
|
||||
async def get_reset_page(token: str):
|
||||
"""Serve the reset page with the token in URL"""
|
||||
return FileResponse(STATIC_DIR / "reset.html")
|
||||
# Catch-all route for SPA - serve index.html for all non-API routes
|
||||
@app.get("/{path:path}")
|
||||
async def spa_handler(path: str):
|
||||
"""Serve the Vue SPA for all routes (except API and static)"""
|
||||
return FileResponse(STATIC_DIR / "index.html")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -355,4 +346,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
main()
|
||||
|
||||
@@ -60,7 +60,7 @@ class Passkey:
|
||||
self,
|
||||
rp_id: str,
|
||||
rp_name: str,
|
||||
origin: str,
|
||||
origin: str | None = None,
|
||||
supported_pub_key_algs: list[COSEAlgorithmIdentifier] | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -74,7 +74,7 @@ class Passkey:
|
||||
"""
|
||||
self.rp_id = rp_id
|
||||
self.rp_name = rp_name
|
||||
self.origin = origin
|
||||
self.origin = origin or f"https://{rp_id}"
|
||||
self.supported_pub_key_algs = supported_pub_key_algs or [
|
||||
COSEAlgorithmIdentifier.EDDSA,
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256,
|
||||
|
||||
@@ -25,19 +25,18 @@ async def create_device_addition_link(request: Request) -> dict:
|
||||
return {"error": "Authentication required"}
|
||||
|
||||
# Generate a human-readable token
|
||||
token = generate(n=4, sep="-") # e.g., "able-ocean-forest-dawn"
|
||||
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
|
||||
|
||||
# Create reset token in database
|
||||
await db.create_reset_token(user.user_id, token)
|
||||
|
||||
# Generate the device addition link with pretty URL
|
||||
addition_link = f"http://localhost:8000/reset/{token}"
|
||||
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Device addition link generated successfully",
|
||||
"addition_link": addition_link,
|
||||
"token": token,
|
||||
"expires_in_hours": 24,
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ This module provides session management functionality including:
|
||||
- Session validation and token handling
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request, Response
|
||||
@@ -15,11 +14,11 @@ from fastapi import Request, Response
|
||||
from .db import User, get_user_by_id
|
||||
from .jwt_manager import validate_session_token
|
||||
|
||||
COOKIE_NAME = "session_token"
|
||||
COOKIE_NAME = "auth"
|
||||
COOKIE_MAX_AGE = 86400 # 24 hours
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> Optional[User]:
|
||||
async def get_current_user(request: Request) -> User | None:
|
||||
"""Get the current user from the session cookie."""
|
||||
session_token = request.cookies.get(COOKIE_NAME)
|
||||
if not session_token:
|
||||
@@ -43,7 +42,7 @@ def set_session_cookie(response: Response, session_token: str) -> None:
|
||||
value=session_token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True,
|
||||
secure=False, # Set to True in production with HTTPS
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
@@ -53,36 +52,29 @@ def clear_session_cookie(response: Response) -> None:
|
||||
response.delete_cookie(key=COOKIE_NAME)
|
||||
|
||||
|
||||
def get_session_token_from_request(request: Request) -> Optional[str]:
|
||||
def get_session_token_from_cookie(request: Request) -> str | None:
|
||||
"""Extract session token from request cookies."""
|
||||
return request.cookies.get(COOKIE_NAME)
|
||||
|
||||
|
||||
async def validate_session_from_request(request: Request) -> Optional[dict]:
|
||||
async def validate_session_from_request(request: Request) -> dict | None:
|
||||
"""Validate session token from request and return token data."""
|
||||
session_token = get_session_token_from_request(request)
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if not session_token:
|
||||
return None
|
||||
|
||||
return validate_session_token(session_token)
|
||||
|
||||
|
||||
async def get_session_token_from_auth_header_or_body(request: Request) -> Optional[str]:
|
||||
async def get_session_token_from_bearer(request: Request) -> str | None:
|
||||
"""Extract session token from Authorization header or request body."""
|
||||
# Try to get token from Authorization header first
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Try to get from request body
|
||||
try:
|
||||
body = await request.json()
|
||||
return body.get("session_token")
|
||||
except Exception:
|
||||
return None
|
||||
return auth_header.removeprefix("Bearer ")
|
||||
|
||||
|
||||
async def get_user_from_cookie_string(cookie_header: str) -> Optional[UUID]:
|
||||
async def get_user_from_cookie_string(cookie_header: str) -> UUID | None:
|
||||
"""Parse cookie header and return user ID if valid session exists."""
|
||||
if not cookie_header:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user