157 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			157 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| from functools import wraps
 | |
| from uuid import UUID
 | |
| 
 | |
| from fastapi import Cookie, FastAPI, WebSocket, WebSocketDisconnect
 | |
| from webauthn.helpers.exceptions import InvalidAuthenticationResponse
 | |
| 
 | |
| from ..authsession import create_session, get_reset, get_session
 | |
| from ..globals import db, passkey
 | |
| from ..util import hostutil, passphrase
 | |
| from ..util.tokens import create_token, session_key
 | |
| from .session import infodict
 | |
| 
 | |
| 
 | |
# WebSocket error handling decorator
def websocket_error_handler(func):
    """Wrap a WebSocket endpoint so that failures are reported to the client.

    The returned wrapper accepts the connection and then awaits ``func`` with
    the same arguments. Outcomes:

    - normal completion: the value returned by ``func`` is returned.
    - ``WebSocketDisconnect``: swallowed silently.
    - ``ValueError`` / ``InvalidAuthenticationResponse``: the exception text is
      sent to the client as ``{"detail": <message>}``.
    - any other ``Exception``: logged with its traceback and reported to the
      client as ``{"detail": "Internal Server Error"}``.

    On every error path the wrapper returns ``None``.
    """

    async def report(ws, detail):
        # Best effort only: the failure being reported may itself mean the
        # socket is unusable (accept() failed, peer already gone). A second
        # exception raised here would escape the decorator and mask the
        # original problem, so it is logged at debug level and dropped.
        try:
            await ws.send_json({"detail": detail})
        except (RuntimeError, WebSocketDisconnect):
            logging.debug(
                "Could not deliver error detail to WebSocket client",
                exc_info=True,
            )

    @wraps(func)
    async def wrapper(ws: WebSocket, *args, **kwargs):
        try:
            await ws.accept()
            return await func(ws, *args, **kwargs)
        except WebSocketDisconnect:
            # Client went away; nothing to report and nobody to report it to.
            pass
        except (ValueError, InvalidAuthenticationResponse) as e:
            # Expected, user-facing failures: pass the message through as-is.
            await report(ws, str(e))
        except Exception:
            # Unexpected failure: keep details in the server log only.
            logging.exception("Internal Server Error")
            await report(ws, "Internal Server Error")

    return wrapper
 | |
| 
 | |
| 
 | |
# FastAPI sub-application that collects this module's WebSocket endpoints;
# the handlers below register themselves on it via @app.websocket(...).
# NOTE(review): where this sub-app gets mounted is not visible from here.
app = FastAPI()
 | |
| 
 | |
| 
 | |
async def register_chat(
    ws: WebSocket,
    user_uuid: UUID,
    user_name: str,
    credential_ids: list[bytes] | None = None,
    origin: str | None = None,
):
    """Run one registration round-trip with the client over ``ws``.

    Registration options are generated for the given user and sent to the
    client as JSON; the client's JSON reply is then verified against the
    challenge that was produced together with those options.

    Returns whatever ``passkey.instance.reg_verify`` returns for the reply.
    """
    # Step 1: build the options and remember the challenge that goes with them.
    generated = passkey.instance.reg_generate_options(
        user_id=user_uuid,
        user_name=user_name,
        credential_ids=credential_ids,
        origin=origin,
    )
    reg_options, reg_challenge = generated

    # Step 2: hand the options to the client and wait for its answer.
    await ws.send_json(reg_options)
    client_reply = await ws.receive_json()

    # Step 3: verify the answer against the remembered challenge.
    verified = passkey.instance.reg_verify(
        client_reply, reg_challenge, user_uuid, origin=origin
    )
    return verified
 | |
| 
 | |
| 
 | |
@app.websocket("/register")
@websocket_error_handler
async def websocket_register_add(
    ws: WebSocket,
    reset: str | None = None,
    name: str | None = None,
    auth=Cookie(None, alias="__Host-auth"),
):
    """Register a new credential for an existing user.

    Supports either:
    - Normal session via auth cookie
    - Reset token supplied as ?reset=... (auth cookie ignored)

    ``name``, when it contains non-whitespace text, overrides the user's
    stored display name for this registration.

    On success a new session is created together with the credential and the
    client receives a JSON message with ``user_uuid``, ``credential_uuid``,
    ``session_token`` and ``message``.

    Raises:
        ValueError: on a missing Origin or Host header, a malformed reset
            token, or when neither a reset token nor an auth cookie is given.
            The surrounding error handler reports these to the client.
    """
    # A missing Origin header is a client error, not a server fault: raise
    # ValueError (as for Host below) rather than letting a KeyError surface as
    # "Internal Server Error".
    origin = ws.headers.get("origin")
    if not origin:
        raise ValueError("Missing origin header")
    host = hostutil.normalize_host(ws.headers.get("host"))
    if host is None:
        raise ValueError("Missing host header")

    # Resolve the acting user from the reset token or from the session cookie.
    if reset is not None:
        if not passphrase.is_well_formed(reset):
            raise ValueError("Invalid reset token")
        s = await get_reset(reset)
    else:
        if not auth:
            raise ValueError("Authentication Required")
        s = await get_session(auth, host=host)
    user_uuid = s.user_uuid

    # Get user information and determine effective user_name for this registration
    user = await db.instance.get_user_by_uuid(user_uuid)
    requested_name = name.strip() if name is not None else ""
    user_name = requested_name or user.display_name
    # The user's existing credentials, passed on to option generation.
    existing_credentials = await db.instance.get_credentials_by_user_uuid(user_uuid)

    # WebAuthn registration
    credential = await register_chat(
        ws, user_uuid, user_name, existing_credentials, origin
    )

    # Create a new session and store everything in database
    token = create_token()
    metadata = infodict(ws, "authenticated")
    await db.instance.create_credential_session(  # type: ignore[attr-defined]
        user_uuid=user_uuid,
        credential=credential,
        # Only the reset flow passes a reset key along.
        reset_key=(s.key if reset is not None else None),
        session_key=session_key(token),
        display_name=user_name,
        host=host,
        ip=metadata.get("ip"),
        user_agent=metadata.get("user_agent"),
    )

    # Internal invariant on the freshly created token, not input validation.
    assert isinstance(token, str) and len(token) == 16
    await ws.send_json(
        {
            "user_uuid": str(user.uuid),
            "credential_uuid": str(credential.uuid),
            "session_token": token,
            "message": "New credential added successfully",
        }
    )
 | |
| 
 | |
| 
 | |
@app.websocket("/authenticate")
@websocket_error_handler
async def websocket_authenticate(ws: WebSocket):
    """Authenticate a user with an existing credential and open a session.

    Authentication options are sent to the client, its JSON response is
    parsed, and the matching stored credential is fetched and verified
    against the challenge. On success the login is recorded, a session is
    created, and the client receives a JSON message with ``user_uuid`` and
    ``session_token``.

    Raises:
        ValueError: on a missing Origin or Host header. The surrounding error
            handler reports these to the client.
    """
    # A missing Origin header is a client error, not a server fault: raise
    # ValueError (as for Host below) rather than letting a KeyError surface as
    # "Internal Server Error".
    origin = ws.headers.get("origin")
    if not origin:
        raise ValueError("Missing origin header")
    host = hostutil.normalize_host(ws.headers.get("host"))
    if host is None:
        raise ValueError("Missing host header")

    options, challenge = passkey.instance.auth_generate_options()
    await ws.send_json(options)
    # Wait for the client to use their authenticator to authenticate
    credential = passkey.instance.auth_parse(await ws.receive_json())
    # Fetch from the database by credential ID
    stored_cred = await db.instance.get_credential_by_id(credential.raw_id)
    # Verify the credential matches the stored data
    passkey.instance.auth_verify(credential, challenge, stored_cred, origin=origin)
    # Update both credential and user's last_seen timestamp
    await db.instance.login(stored_cred.user_uuid, stored_cred)

    # Create a session token for the authenticated user
    assert stored_cred.uuid is not None
    metadata = infodict(ws, "auth")
    token = await create_session(
        user_uuid=stored_cred.user_uuid,
        credential_uuid=stored_cred.uuid,
        host=host,
        # Missing metadata falls back to empty strings.
        ip=metadata.get("ip") or "",
        user_agent=metadata.get("user_agent") or "",
    )

    await ws.send_json(
        {
            "user_uuid": str(stored_cred.user_uuid),
            "session_token": token,
        }
    )
 | 
