Error handling cleanup for WS too.

This commit is contained in:
Leo Vasanko 2025-08-06 10:53:13 -06:00
parent c9ae53ef79
commit ba5f2d8bd9

View File

@ -1,15 +1,6 @@
"""
WebSocket handlers for passkey authentication operations.
This module contains all WebSocket endpoints for:
- User registration
- Adding credentials to existing users
- Device credential addition via token
- Authentication
"""
import logging import logging
from datetime import datetime from datetime import datetime
from functools import wraps
from uuid import UUID from uuid import UUID
import uuid7 import uuid7
@ -23,6 +14,25 @@ from ..util import passphrase
from ..util.tokens import create_token, session_key from ..util.tokens import create_token, session_key
from .session import infodict from .session import infodict
# WebSocket error handling decorator
def websocket_error_handler(func):
@wraps(func)
async def wrapper(ws: WebSocket, *args, **kwargs):
try:
await ws.accept()
return await func(ws, *args, **kwargs)
except WebSocketDisconnect:
pass
except (ValueError, InvalidAuthenticationResponse) as e:
await ws.send_json({"detail": str(e)})
except Exception:
logging.exception("Internal Server Error")
await ws.send_json({"detail": "Internal Server Error"})
return wrapper
# Create a FastAPI subapp for WebSocket endpoints # Create a FastAPI subapp for WebSocket endpoints
app = FastAPI() app = FastAPI()
@ -53,131 +63,104 @@ async def register_chat(
@app.websocket("/register") @app.websocket("/register")
@websocket_error_handler
async def websocket_register_new( async def websocket_register_new(
ws: WebSocket, user_name: str = Query(""), auth=Cookie(None) ws: WebSocket, user_name: str = Query(""), auth=Cookie(None)
): ):
"""Register a new user and with a new passkey credential.""" """Register a new user and with a new passkey credential."""
await ws.accept() origin = ws.headers["origin"]
origin = ws.headers.get("origin") user_uuid = uuid7.create()
try: # WebAuthn registration
user_uuid = uuid7.create() credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# WebAuthn registration
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
# Store the user and credential in the database # Store the user and credential in the database
await db.instance.create_user_and_credential( await db.instance.create_user_and_credential(
User(user_uuid, user_name, created_at=datetime.now()), User(user_uuid, user_name, created_at=datetime.now()),
credential, credential,
) )
# Create a session token for the new user # Create a session token for the new user
token = create_token() token = create_token()
await db.instance.create_session( await db.instance.create_session(
user_uuid=user_uuid, user_uuid=user_uuid,
key=session_key(token), key=session_key(token),
expires=datetime.now() + EXPIRES, expires=datetime.now() + EXPIRES,
info=infodict(ws, "authenticated"), info=infodict(ws, "authenticated"),
credential_uuid=credential.uuid, credential_uuid=credential.uuid,
) )
await ws.send_json( await ws.send_json(
{ {
"user_uuid": str(user_uuid), "user_uuid": str(user_uuid),
"session_token": token, "session_token": token,
} }
) )
except ValueError as e:
await ws.send_json({"detail": str(e)})
except WebSocketDisconnect:
pass
except Exception:
logging.exception("Internal Server Error")
await ws.send_json({"detail": "Internal Server Error"})
@app.websocket("/add_credential") @app.websocket("/add_credential")
@websocket_error_handler
async def websocket_register_add(ws: WebSocket, auth=Cookie(None)): async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
"""Register a new credential for an existing user.""" """Register a new credential for an existing user."""
await ws.accept()
origin = ws.headers["origin"] origin = ws.headers["origin"]
try: # Try to get either a regular session or a reset session
# Try to get either a regular session or a reset session reset = passphrase.is_well_formed(auth)
reset = passphrase.is_well_formed(auth) s = await (get_reset if reset else get_session)(auth)
s = await (get_reset if reset else get_session)(auth) user_uuid = s.user_uuid
user_uuid = s.user_uuid
# Get user information to get the user_name # Get user information to get the user_name
user = await db.instance.get_user_by_uuid(user_uuid) user = await db.instance.get_user_by_uuid(user_uuid)
user_name = user.display_name user_name = user.display_name
challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid) challenge_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)
# WebAuthn registration # WebAuthn registration
credential = await register_chat( credential = await register_chat(ws, user_uuid, user_name, challenge_ids, origin)
ws, user_uuid, user_name, challenge_ids, origin if reset:
# Replace reset session with a new session
await db.instance.delete_session(s.key)
token = await create_session(
user_uuid, credential.uuid, infodict(ws, "authenticated")
) )
if reset: else:
# Replace reset session with a new session token = auth
await db.instance.delete_session(s.key) assert isinstance(token, str) and len(token) == 16
token = await create_session( # Store the new credential in the database
user_uuid, credential.uuid, infodict(ws, "authenticated") await db.instance.create_credential(credential)
)
else:
token = auth
assert isinstance(token, str) and len(token) == 16
# Store the new credential in the database
await db.instance.create_credential(credential)
await ws.send_json( await ws.send_json(
{ {
"user_uuid": str(user.uuid), "user_uuid": str(user.uuid),
"credential_uuid": str(credential.uuid), "credential_uuid": str(credential.uuid),
"session_token": token, "session_token": token,
"message": "New credential added successfully", "message": "New credential added successfully",
} }
) )
except ValueError as e:
await ws.send_json({"detail": str(e)})
except WebSocketDisconnect:
pass
except Exception:
logging.exception("Internal Server Error")
await ws.send_json({"detail": "Internal Server Error"})
@app.websocket("/authenticate") @app.websocket("/authenticate")
@websocket_error_handler
async def websocket_authenticate(ws: WebSocket): async def websocket_authenticate(ws: WebSocket):
await ws.accept()
origin = ws.headers["origin"] origin = ws.headers["origin"]
try: options, challenge = passkey.auth_generate_options()
options, challenge = passkey.auth_generate_options() await ws.send_json(options)
await ws.send_json(options) # Wait for the client to use his authenticator to authenticate
# Wait for the client to use his authenticator to authenticate credential = passkey.auth_parse(await ws.receive_json())
credential = passkey.auth_parse(await ws.receive_json()) # Fetch from the database by credential ID
# Fetch from the database by credential ID stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
stored_cred = await db.instance.get_credential_by_id(credential.raw_id) # Verify the credential matches the stored data
# Verify the credential matches the stored data passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
passkey.auth_verify(credential, challenge, stored_cred, origin=origin) # Update both credential and user's last_seen timestamp
# Update both credential and user's last_seen timestamp await db.instance.login(stored_cred.user_uuid, stored_cred)
await db.instance.login(stored_cred.user_uuid, stored_cred)
# Create a session token for the authenticated user # Create a session token for the authenticated user
assert stored_cred.uuid is not None assert stored_cred.uuid is not None
token = await create_session( token = await create_session(
user_uuid=stored_cred.user_uuid, user_uuid=stored_cred.user_uuid,
info=infodict(ws, "auth"), info=infodict(ws, "auth"),
credential_uuid=stored_cred.uuid, credential_uuid=stored_cred.uuid,
) )
await ws.send_json( await ws.send_json(
{ {
"user_uuid": str(stored_cred.user_uuid), "user_uuid": str(stored_cred.user_uuid),
"session_token": token, "session_token": token,
} }
) )
except (ValueError, InvalidAuthenticationResponse) as e:
logging.exception("ValueError")
await ws.send_json({"detail": str(e)})
except WebSocketDisconnect:
pass
except Exception:
logging.exception("Internal Server Error")
await ws.send_json({"detail": "Internal Server Error"})