"""
Minimal FastAPI WebAuthn server with WebSocket support for passkey registration
and authentication.

This module provides a simple WebAuthn implementation that:
- Uses WebSocket for real-time communication
- Supports Resident Keys (discoverable credentials) for passwordless authentication
- Maintains challenges locally per connection
- Uses async SQLite database for persistent storage of users and credentials
- Enables true passwordless authentication where users don't need to enter a user_name
"""

import logging
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path
from uuid import UUID, uuid4

from fastapi import (
    FastAPI,
    Request,
    Response,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi import (
    Path as FastAPIPath,
)
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from webauthn.helpers.exceptions import InvalidAuthenticationResponse

from . import db
from .api_handlers import (
    delete_credential,
    get_user_credentials,
    get_user_info,
    logout,
    refresh_token,
    set_session,
    validate_token,
)
from .db import User
from .jwt_manager import create_session_token
from .passkey import Passkey
from .reset_handlers import create_device_addition_link, validate_device_addition_token
from .session_manager import get_user_from_cookie_string

STATIC_DIR = Path(__file__).parent / "frontend-static"

# How long a device addition token stays valid, counted from its created_at.
DEVICE_TOKEN_LIFETIME = timedelta(hours=24)

passkey = Passkey(
    rp_id="localhost",
    rp_name="Passkey Auth",
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the database before the application starts serving."""
    await db.init_database()
    yield


app = FastAPI(title="Passkey Auth", lifespan=lifespan)


@app.websocket("/auth/ws/register_new")
async def websocket_register_new(ws: WebSocket, user_name: str):
    """Register a new user with a new passkey credential.

    Runs the registration exchange over the WebSocket, stores the new user
    together with the credential and replies with a session token. Failures
    are reported to the client as an ``{"error": ...}`` message.
    """
    await ws.accept()
    origin = ws.headers.get("origin")
    try:
        user_id = uuid4()

        # WebAuthn registration
        credential = await register_chat(ws, user_id, user_name, origin=origin)

        # Store the user and credential in the database
        # NOTE(review): datetime.now() is naive local time -- confirm this is
        # what the db layer expects.
        await db.create_user_and_credential(
            User(user_id, user_name, created_at=datetime.now()),
            credential,
        )

        # Create a session token for the new user
        session_token = create_session_token(user_id, credential.credential_id)

        await ws.send_json(
            {
                "status": "success",
                "user_id": str(user_id),
                "session_token": session_token,
            }
        )
    except ValueError as e:
        await ws.send_json({"error": str(e)})
    except WebSocketDisconnect:
        # Client went away; there is nobody left to report to.
        pass
    except Exception:
        logging.exception("Internal Server Error")
        await ws.send_json({"error": "Internal Server Error"})


@app.websocket("/auth/ws/add_credential")
async def websocket_register_add(ws: WebSocket):
    """Register a new credential for an existing user.

    The user is identified by the session cookie sent with the WebSocket
    handshake; without a valid session an error message is sent and the
    handler returns.
    """
    await ws.accept()
    origin = ws.headers.get("origin")
    try:
        # Authenticate user via cookie
        cookie_header = ws.headers.get("cookie", "")
        user_id = await get_user_from_cookie_string(cookie_header)

        if not user_id:
            await ws.send_json({"error": "Authentication required"})
            return

        # WebAuthn registration + storing the new credential
        credential = await _register_additional_credential(ws, user_id, origin)

        await ws.send_json(
            {
                "status": "success",
                "user_id": str(user_id),
                "credential_id": credential.credential_id.hex(),
                "message": "New credential added successfully",
            }
        )
    except ValueError as e:
        await ws.send_json({"error": str(e)})
    except WebSocketDisconnect:
        pass
    except Exception:
        logging.exception("Internal Server Error")
        await ws.send_json({"error": "Internal Server Error"})


@app.websocket("/auth/ws/add_device_credential")
async def websocket_add_device_credential(ws: WebSocket, token: str):
    """Add a new credential for an existing user via device addition token.

    The token is looked up in the database and rejected when it is unknown or
    older than ``DEVICE_TOKEN_LIFETIME``. After a successful registration the
    token is deleted so it cannot be used again.
    """
    await ws.accept()
    origin = ws.headers.get("origin")
    try:
        reset_token = await db.get_reset_token(token)
        if not reset_token:
            await ws.send_json({"error": "Invalid or expired device addition token"})
            return

        # Check if token is expired
        # NOTE(review): compares naive datetime.now() with created_at --
        # presumably created_at is naive local time too; confirm in db.
        expiry_time = reset_token.created_at + DEVICE_TOKEN_LIFETIME
        if datetime.now() > expiry_time:
            await ws.send_json({"error": "Device addition token has expired"})
            return

        # WebAuthn registration + storing the new credential
        credential = await _register_additional_credential(
            ws, reset_token.user_id, origin
        )

        # Delete the device addition token (it's now used)
        await db.delete_reset_token(token)

        await ws.send_json(
            {
                "status": "success",
                "user_id": str(reset_token.user_id),
                "credential_id": credential.credential_id.hex(),
                "message": "New credential added successfully via device addition token",
            }
        )
    except ValueError as e:
        await ws.send_json({"error": str(e)})
    except WebSocketDisconnect:
        pass
    except Exception:
        logging.exception("Internal Server Error")
        await ws.send_json({"error": "Internal Server Error"})


async def _register_additional_credential(
    ws: WebSocket, user_id: UUID, origin: str | None
):
    """Register one more credential for an existing user and store it.

    Returns the credential produced by :func:`register_chat` after it has
    been saved with ``db.create_credential_for_user``.
    """
    # Get user information to get the user_name
    user = await db.get_user_by_id(user_id)
    user_name = user.user_name

    # IDs of the credentials the user already has; handed to the registration
    # options as ``credential_ids`` (presumably so the authenticator can
    # exclude them -- confirm in Passkey.reg_generate_options).
    existing_credential_ids = await db.get_user_credentials(user_id)

    credential = await register_chat(
        ws, user_id, user_name, existing_credential_ids, origin=origin
    )

    # Store the new credential in the database
    await db.create_credential_for_user(credential)
    return credential


async def register_chat(
    ws: WebSocket,
    user_id: UUID,
    user_name: str,
    credential_ids: list[bytes] | None = None,
    origin: str | None = None,
):
    """Run the registration exchange with the client over ``ws``.

    Generates registration options, sends them to the client, waits for the
    client's JSON response and returns the result of verifying that response
    against the challenge generated here.
    """
    options, challenge = passkey.reg_generate_options(
        user_id=user_id,
        user_name=user_name,
        credential_ids=credential_ids,
        origin=origin,
    )
    await ws.send_json(options)
    response = await ws.receive_json()
    return passkey.reg_verify(response, challenge, user_id, origin=origin)


@app.websocket("/auth/ws/authenticate")
async def websocket_authenticate(ws: WebSocket):
    """Authenticate a user with an existing passkey credential.

    Sends authentication options, verifies the client's response against the
    stored credential, records the login and replies with a session token.
    """
    await ws.accept()
    origin = ws.headers.get("origin")
    try:
        options, challenge = passkey.auth_generate_options()
        await ws.send_json(options)

        # Wait for the client to use their authenticator to authenticate
        credential = passkey.auth_parse(await ws.receive_json())

        # Fetch from the database by credential ID
        stored_cred = await db.get_credential_by_id(credential.raw_id)

        # Verify the credential matches the stored data
        passkey.auth_verify(credential, challenge, stored_cred, origin=origin)

        # Update both credential and user's last_seen timestamp
        await db.login_user(stored_cred.user_id, stored_cred)

        # Create a session token for the authenticated user
        session_token = create_session_token(
            stored_cred.user_id, stored_cred.credential_id
        )

        await ws.send_json(
            {
                "status": "success",
                "user_id": str(stored_cred.user_id),
                "session_token": session_token,
            }
        )
    except (ValueError, InvalidAuthenticationResponse) as e:
        # The log label says "ValueError" but this branch also covers
        # InvalidAuthenticationResponse; the traceback shows the real type.
        logging.exception("ValueError")
        await ws.send_json({"error": str(e)})
    except WebSocketDisconnect:
        pass
    except Exception:
        logging.exception("Internal Server Error")
        await ws.send_json({"error": "Internal Server Error"})


@app.get("/auth/user-info")
async def api_get_user_info(request: Request):
    """Get user information from session cookie."""
    return await get_user_info(request)


@app.get("/auth/user-credentials")
async def api_get_user_credentials(request: Request):
    """Get all credentials for a user using session cookie."""
    return await get_user_credentials(request)


@app.post("/auth/refresh-token")
async def api_refresh_token(request: Request, response: Response):
    """Refresh the session token."""
    return await refresh_token(request, response)


@app.get("/auth/validate-token")
async def api_validate_token(request: Request):
    """Validate a session token and return user info."""
    return await validate_token(request)


@app.get("/auth/forward-auth")
async def forward_authentication(request: Request):
    """A verification endpoint to use with Caddy forward_auth or Nginx auth_request."""
    result = await validate_token(request)
    if result.get("status") != "success":
        # Serve the index.html of the authentication app if not authenticated
        return FileResponse(
            STATIC_DIR / "index.html",
            status_code=401,
            headers={"www-authenticate": "PrivateToken"},
        )

    # If authenticated, return a success response
    return JSONResponse(
        result,
        headers={
            "x-auth-user-id": result["user_id"],
        },
    )


@app.post("/auth/logout")
async def api_logout(response: Response):
    """Log out the current user by clearing the session cookie."""
    return await logout(response)


@app.post("/auth/set-session")
async def api_set_session(request: Request, response: Response):
    """Set session cookie using JWT token from request body or Authorization header."""
    return await set_session(request, response)


@app.post("/auth/delete-credential")
async def api_delete_credential(request: Request):
    """Delete a specific credential for the current user."""
    return await delete_credential(request)


@app.post("/auth/create-device-link")
async def api_create_device_link(request: Request):
    """Create a device addition link for the authenticated user."""
    return await create_device_addition_link(request)


@app.post("/auth/validate-device-token")
async def api_validate_device_token(request: Request):
    """Validate a device addition token."""
    return await validate_device_addition_token(request)


@app.get("/auth/{passphrase}")
async def reset_authentication(
    passphrase: str = FastAPIPath(pattern=r"^\w+(\.\w+){2,}$"),
):
    """Hand the token from the URL to the frontend and redirect to ``/``.

    The path segment must consist of at least three dot-separated word
    groups. It is placed in a two-second ``auth-token`` cookie on the
    redirect response.
    """
    response = RedirectResponse(url="/", status_code=303)
    response.set_cookie(
        key="auth-token",
        value=passphrase,
        # Not HttpOnly -- presumably so the frontend script can read it;
        # TODO confirm.
        httponly=False,
        secure=True,
        samesite="strict",
        max_age=2,
    )
    return response


# Serve static files
app.mount(
    "/auth/assets", StaticFiles(directory=STATIC_DIR / "assets"), name="static assets"
)


@app.get("/auth")
async def redirect_to_index():
    """Serve the main authentication app."""
    return FileResponse(STATIC_DIR / "index.html")


# Catch-all route for SPA - serve index.html for all non-API routes
@app.get("/{path:path}")
async def spa_handler(request: Request, path: str):
    """Serve the Vue SPA for all routes (except API and static)"""
    # Only answer with the app shell when the client accepts HTML; anything
    # else gets a plain 404.
    if "text/html" not in request.headers.get("accept", ""):
        return Response(content="Not Found", status_code=404)
    return FileResponse(STATIC_DIR / "index.html")


def main():
    """Entry point for the application"""
    import uvicorn

    uvicorn.run(
        "passkeyauth.main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    main()