Files

143 lines
4.6 KiB
Python

import logging
from functools import wraps
from uuid import UUID
from fastapi import Cookie, FastAPI, WebSocket, WebSocketDisconnect
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from ..authsession import create_session, expires, get_reset, get_session
from ..globals import db, passkey
from ..util import passphrase
from ..util.tokens import create_token, session_key
from .session import infodict
# WebSocket error handling decorator
def websocket_error_handler(func):
    """Accept the WebSocket connection, run ``func``, and report its failures.

    Outcomes:

    - success: the wrapped coroutine's return value is passed through;
    - ``WebSocketDisconnect``: the client is gone, so nothing is sent;
    - ``ValueError`` / ``InvalidAuthenticationResponse``: the exception
      message is sent to the client as ``{"detail": <message>}``;
    - any other ``Exception``: logged with its traceback, and a generic
      ``{"detail": "Internal Server Error"}`` is sent instead.
    """

    @wraps(func)
    async def wrapper(ws: WebSocket, *args, **kwargs):
        try:
            await ws.accept()
            return await func(ws, *args, **kwargs)
        except WebSocketDisconnect:
            # Nobody left to notify.
            return None
        except (ValueError, InvalidAuthenticationResponse) as err:
            # Client-facing failures: echo the message back verbatim.
            payload = {"detail": str(err)}
        except Exception:
            # Unexpected failure: keep details in the log, not on the wire.
            logging.exception("Internal Server Error")
            payload = {"detail": "Internal Server Error"}
        await ws.send_json(payload)

    return wrapper
# FastAPI sub-application hosting the WebSocket endpoints registered below via
# ``@app.websocket``; presumably mounted by the main application elsewhere —
# confirm at the mount site.
app = FastAPI()
async def register_chat(
    ws: WebSocket,
    user_uuid: UUID,
    user_name: str,
    credential_ids: list[bytes] | None = None,
    origin: str | None = None,
):
    """Run one WebAuthn registration round-trip over ``ws``.

    Generates registration options for the given user (forwarding
    ``credential_ids`` and ``origin`` to the option generator), sends them to
    the client, waits for the client's JSON reply, and returns whatever
    ``passkey.instance.reg_verify`` produces for that reply.
    """
    # Generate the options; the challenge is kept for the verification step.
    generated = passkey.instance.reg_generate_options(
        user_id=user_uuid,
        user_name=user_name,
        credential_ids=credential_ids,
        origin=origin,
    )
    options, challenge = generated
    await ws.send_json(options)

    # Block until the client answers with its registration response.
    reply = await ws.receive_json()
    return passkey.instance.reg_verify(reply, challenge, user_uuid, origin=origin)
@app.websocket("/register")
@websocket_error_handler
async def websocket_register_add(
    ws: WebSocket, reset: str | None = None, name: str | None = None, auth=Cookie(None)
):
    """Register a new credential for an existing user.

    The user is identified by either:

    - a normal session via the ``auth`` cookie, or
    - a reset token supplied as ``?reset=...`` (the ``auth`` cookie is then
      ignored).

    ``name``, when given and non-blank after stripping, replaces the user's
    display name for this registration. On success the credential and a new
    session are stored together, and the session token is sent to the client.

    Raises:
        ValueError: if the Origin header is missing, the reset token is
            malformed, or no authentication cookie was supplied. The error
            handler decorator reports the message to the client.
        RuntimeError: if ``create_token()`` yields a token that is not a
            16-character string (checked before anything is persisted).
    """
    # A missing Origin is a client error; report it as such instead of letting
    # a KeyError surface as an internal server error.
    origin = ws.headers.get("origin")
    if not origin:
        raise ValueError("Origin header required")

    if reset is not None:
        if not passphrase.is_well_formed(reset):
            raise ValueError("Invalid reset token")
        s = await get_reset(reset)
    else:
        if not auth:
            raise ValueError("Authentication Required")
        s = await get_session(auth)
    user_uuid = s.user_uuid

    # Effective name for this registration: a non-blank ``name`` overrides the
    # stored display name.
    user = await db.instance.get_user_by_uuid(user_uuid)
    user_name = user.display_name
    if name is not None and (stripped := name.strip()):
        user_name = stripped

    # The user's existing credentials are forwarded to the option generator.
    credential_ids = await db.instance.get_credentials_by_user_uuid(user_uuid)

    # WebAuthn registration
    credential = await register_chat(ws, user_uuid, user_name, credential_ids, origin)

    # Validate the new token *before* persisting anything. This is an explicit
    # check rather than ``assert``, which would be stripped under ``python -O``.
    token = create_token()
    if not isinstance(token, str) or len(token) != 16:
        raise RuntimeError("create_token() returned a malformed session token")

    # Store the credential and the new session together.
    await db.instance.create_credential_session(  # type: ignore[attr-defined]
        user_uuid=user_uuid,
        credential=credential,
        reset_key=(s.key if reset is not None else None),
        session_key=session_key(token),
        session_expires=expires(),
        session_info=infodict(ws, "authenticated"),
        display_name=user_name,
    )
    await ws.send_json(
        {
            "user_uuid": str(user.uuid),
            "credential_uuid": str(credential.uuid),
            "session_token": token,
            "message": "New credential added successfully",
        }
    )
@app.websocket("/authenticate")
@websocket_error_handler
async def websocket_authenticate(ws: WebSocket):
    """Authenticate a user with a stored passkey and issue a session token.

    Sends authentication options, waits for the client's response, looks up
    the stored credential by its raw ID, and verifies the response against it
    for the request's Origin. It then records the login and replies with the
    user's UUID and a new session token.

    Raises:
        ValueError: if the Origin header is missing (reported to the client
            by the error handler decorator).
        RuntimeError: if the stored credential has no UUID. This is checked
            before the login is recorded.
    """
    # A missing Origin is a client error; report it as such instead of letting
    # a KeyError surface as an internal server error.
    origin = ws.headers.get("origin")
    if not origin:
        raise ValueError("Origin header required")

    options, challenge = passkey.instance.auth_generate_options()
    await ws.send_json(options)
    # Wait for the client to use their authenticator to authenticate.
    credential = passkey.instance.auth_parse(await ws.receive_json())
    # Fetch the stored credential by its ID and verify the response against it.
    stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
    passkey.instance.auth_verify(credential, challenge, stored_cred, origin=origin)

    # Explicit check (not ``assert``, which ``-O`` strips), done before the
    # login is recorded so a session-less login is never stored.
    if stored_cred.uuid is None:
        raise RuntimeError("Stored credential has no UUID")

    # Update both the credential's and the user's last_seen timestamps.
    await db.instance.login(stored_cred.user_uuid, stored_cred)
    token = await create_session(
        user_uuid=stored_cred.user_uuid,
        info=infodict(ws, "auth"),
        credential_uuid=stored_cred.uuid,
    )
    await ws.send_json(
        {
            "user_uuid": str(stored_cred.user_uuid),
            "session_token": token,
        }
    )