Use origin from requests, rather than hardcode it. This is still constrained to rp_id and its subdomains, so it should be fine for security, also WebSockets make sure the origin doesn't change between stages of the chat.

This commit is contained in:
Leo Vasanko 2025-07-13 12:55:01 -06:00
parent 9711453553
commit 7665044032
2 changed files with 24 additions and 9 deletions

View File

@ -45,7 +45,6 @@ STATIC_DIR = Path(__file__).parent.parent / "static"
passkey = Passkey( passkey = Passkey(
rp_id="localhost", rp_id="localhost",
rp_name="Passkey Auth", rp_name="Passkey Auth",
origin="http://localhost:3000",
) )
@ -62,11 +61,12 @@ app = FastAPI(title="Passkey Auth", lifespan=lifespan)
async def websocket_register_new(ws: WebSocket, user_name: str): async def websocket_register_new(ws: WebSocket, user_name: str):
"""Register a new user and with a new passkey credential.""" """Register a new user and with a new passkey credential."""
await ws.accept() await ws.accept()
origin = ws.headers.get("origin")
try: try:
user_id = uuid4() user_id = uuid4()
# WebAuthn registration # WebAuthn registration
credential = await register_chat(ws, user_id, user_name) credential = await register_chat(ws, user_id, user_name, origin=origin)
# Store the user and credential in the database # Store the user and credential in the database
await db.create_user_and_credential( await db.create_user_and_credential(
@ -97,6 +97,7 @@ async def websocket_register_new(ws: WebSocket, user_name: str):
async def websocket_register_add(ws: WebSocket): async def websocket_register_add(ws: WebSocket):
"""Register a new credential for an existing user.""" """Register a new credential for an existing user."""
await ws.accept() await ws.accept()
origin = ws.headers.get("origin")
try: try:
# Authenticate user via cookie # Authenticate user via cookie
cookie_header = ws.headers.get("cookie", "") cookie_header = ws.headers.get("cookie", "")
@ -112,7 +113,9 @@ async def websocket_register_add(ws: WebSocket):
challenge_ids = await db.get_user_credentials(user_id) challenge_ids = await db.get_user_credentials(user_id)
# WebAuthn registration # WebAuthn registration
credential = await register_chat(ws, user_id, user_name, challenge_ids) credential = await register_chat(
ws, user_id, user_name, challenge_ids, origin=origin
)
# Store the new credential in the database # Store the new credential in the database
await db.create_credential_for_user(credential) await db.create_credential_for_user(credential)
@ -137,6 +140,7 @@ async def websocket_register_add(ws: WebSocket):
async def websocket_add_device_credential(ws: WebSocket, token: str): async def websocket_add_device_credential(ws: WebSocket, token: str):
"""Add a new credential for an existing user via device addition token.""" """Add a new credential for an existing user via device addition token."""
await ws.accept() await ws.accept()
origin = ws.headers.get("origin")
try: try:
reset_token = await db.get_reset_token(token) reset_token = await db.get_reset_token(token)
if not reset_token: if not reset_token:
@ -153,11 +157,13 @@ async def websocket_add_device_credential(ws: WebSocket, token: str):
# Get user information # Get user information
user = await db.get_user_by_id(reset_token.user_id) user = await db.get_user_by_id(reset_token.user_id)
challenge_ids = await db.get_user_credentials(reset_token.user_id)
# WebAuthn registration # WebAuthn registration
# Fetch challenge IDs for the user
challenge_ids = await db.get_user_credentials(reset_token.user_id)
credential = await register_chat( credential = await register_chat(
ws, reset_token.user_id, user.user_name, challenge_ids ws, reset_token.user_id, user.user_name, challenge_ids, origin=origin
) )
# Store the new credential in the database # Store the new credential in the database
@ -188,12 +194,14 @@ async def register_chat(
user_id: UUID, user_id: UUID,
user_name: str, user_name: str,
credential_ids: list[bytes] | None = None, credential_ids: list[bytes] | None = None,
origin: str | None = None,
): ):
"""Generate registration options and send them to the client.""" """Generate registration options and send them to the client."""
options, challenge = passkey.reg_generate_options( options, challenge = passkey.reg_generate_options(
user_id=user_id, user_id=user_id,
user_name=user_name, user_name=user_name,
credential_ids=credential_ids, credential_ids=credential_ids,
origin=origin,
) )
await ws.send_json(options) await ws.send_json(options)
response = await ws.receive_json() response = await ws.receive_json()
@ -203,15 +211,16 @@ async def register_chat(
@app.websocket("/auth/ws/authenticate") @app.websocket("/auth/ws/authenticate")
async def websocket_authenticate(ws: WebSocket): async def websocket_authenticate(ws: WebSocket):
await ws.accept() await ws.accept()
origin = ws.headers.get("origin")
try: try:
options, challenge = passkey.auth_generate_options() options, challenge = passkey.auth_generate_options(origin=origin)
await ws.send_json(options) await ws.send_json(options)
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID # Fetch from the database by credential ID
stored_cred = await db.get_credential_by_id(credential.raw_id) stored_cred = await db.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data # Verify the credential matches the stored data
passkey.auth_verify(credential, challenge, stored_cred) passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
# Update both credential and user's last_seen timestamp # Update both credential and user's last_seen timestamp
await db.login_user(stored_cred.user_id, stored_cred) await db.login_user(stored_cred.user_id, stored_cred)

View File

@ -88,6 +88,7 @@ class Passkey:
user_id: UUID, user_id: UUID,
user_name: str, user_name: str,
credential_ids: list[bytes] | None = None, credential_ids: list[bytes] | None = None,
origin: str | None = None,
**regopts, **regopts,
) -> tuple[dict, bytes]: ) -> tuple[dict, bytes]:
""" """
@ -99,6 +100,7 @@ class Passkey:
credential_ids: For an already authenticated user, a list of credential IDs credential_ids: For an already authenticated user, a list of credential IDs
associated with the account. This prevents accidentally adding another associated with the account. This prevents accidentally adding another
credential on an authenticator that already has one of the listed IDs. credential on an authenticator that already has one of the listed IDs.
origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included.
regopts: Additional arguments to generate_registration_options. regopts: Additional arguments to generate_registration_options.
Returns: Returns:
@ -126,6 +128,7 @@ class Passkey:
response_json: dict | str, response_json: dict | str,
expected_challenge: bytes, expected_challenge: bytes,
user_id: UUID, user_id: UUID,
origin: str | None = None,
) -> StoredCredential: ) -> StoredCredential:
""" """
Verify registration response. Verify registration response.
@ -138,10 +141,11 @@ class Passkey:
Registration verification result Registration verification result
""" """
credential = parse_registration_credential_json(response_json) credential = parse_registration_credential_json(response_json)
expected_origin = origin or self.origin
registration = verify_registration_response( registration = verify_registration_response(
credential=credential, credential=credential,
expected_challenge=expected_challenge, expected_challenge=expected_challenge,
expected_origin=self.origin, expected_origin=expected_origin,
expected_rp_id=self.rp_id, expected_rp_id=self.rp_id,
) )
return StoredCredential( return StoredCredential(
@ -193,6 +197,7 @@ class Passkey:
credential: AuthenticationCredential, credential: AuthenticationCredential,
expected_challenge: bytes, expected_challenge: bytes,
stored_cred: StoredCredential, stored_cred: StoredCredential,
origin: str | None = None,
) -> VerifiedAuthentication: ) -> VerifiedAuthentication:
""" """
Verify authentication response against locally stored credential data. Verify authentication response against locally stored credential data.
@ -202,11 +207,12 @@ class Passkey:
expected_challenge: The earlier generated challenge bytes expected_challenge: The earlier generated challenge bytes
stored_cred: The server stored credential record (modified by this function) stored_cred: The server stored credential record (modified by this function)
""" """
expected_origin = origin or self.origin
# Verify the authentication response # Verify the authentication response
verification = verify_authentication_response( verification = verify_authentication_response(
credential=credential, credential=credential,
expected_challenge=expected_challenge, expected_challenge=expected_challenge,
expected_origin=self.origin, expected_origin=expected_origin,
expected_rp_id=self.rp_id, expected_rp_id=self.rp_id,
credential_public_key=stored_cred.public_key, credential_public_key=stored_cred.public_key,
credential_current_sign_count=stored_cred.sign_count, credential_current_sign_count=stored_cred.sign_count,