Use origin from requests, rather than hardcode it. This is still constrained to rp_id and its subdomains, so it should be fine for security, also WebSockets make sure the origin doesn't change between stages of the chat.
This commit is contained in:
parent
9711453553
commit
7665044032
@ -45,7 +45,6 @@ STATIC_DIR = Path(__file__).parent.parent / "static"
|
||||
passkey = Passkey(
|
||||
rp_id="localhost",
|
||||
rp_name="Passkey Auth",
|
||||
origin="http://localhost:3000",
|
||||
)
|
||||
|
||||
|
||||
@ -62,11 +61,12 @@ app = FastAPI(title="Passkey Auth", lifespan=lifespan)
|
||||
async def websocket_register_new(ws: WebSocket, user_name: str):
|
||||
"""Register a new user and with a new passkey credential."""
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
user_id = uuid4()
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(ws, user_id, user_name)
|
||||
credential = await register_chat(ws, user_id, user_name, origin=origin)
|
||||
|
||||
# Store the user and credential in the database
|
||||
await db.create_user_and_credential(
|
||||
@ -97,6 +97,7 @@ async def websocket_register_new(ws: WebSocket, user_name: str):
|
||||
async def websocket_register_add(ws: WebSocket):
|
||||
"""Register a new credential for an existing user."""
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
# Authenticate user via cookie
|
||||
cookie_header = ws.headers.get("cookie", "")
|
||||
@ -112,7 +113,9 @@ async def websocket_register_add(ws: WebSocket):
|
||||
challenge_ids = await db.get_user_credentials(user_id)
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(ws, user_id, user_name, challenge_ids)
|
||||
credential = await register_chat(
|
||||
ws, user_id, user_name, challenge_ids, origin=origin
|
||||
)
|
||||
# Store the new credential in the database
|
||||
await db.create_credential_for_user(credential)
|
||||
|
||||
@ -137,6 +140,7 @@ async def websocket_register_add(ws: WebSocket):
|
||||
async def websocket_add_device_credential(ws: WebSocket, token: str):
|
||||
"""Add a new credential for an existing user via device addition token."""
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
reset_token = await db.get_reset_token(token)
|
||||
if not reset_token:
|
||||
@ -153,11 +157,13 @@ async def websocket_add_device_credential(ws: WebSocket, token: str):
|
||||
|
||||
# Get user information
|
||||
user = await db.get_user_by_id(reset_token.user_id)
|
||||
challenge_ids = await db.get_user_credentials(reset_token.user_id)
|
||||
|
||||
# WebAuthn registration
|
||||
# Fetch challenge IDs for the user
|
||||
challenge_ids = await db.get_user_credentials(reset_token.user_id)
|
||||
|
||||
credential = await register_chat(
|
||||
ws, reset_token.user_id, user.user_name, challenge_ids
|
||||
ws, reset_token.user_id, user.user_name, challenge_ids, origin=origin
|
||||
)
|
||||
|
||||
# Store the new credential in the database
|
||||
@ -188,12 +194,14 @@ async def register_chat(
|
||||
user_id: UUID,
|
||||
user_name: str,
|
||||
credential_ids: list[bytes] | None = None,
|
||||
origin: str | None = None,
|
||||
):
|
||||
"""Generate registration options and send them to the client."""
|
||||
options, challenge = passkey.reg_generate_options(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
credential_ids=credential_ids,
|
||||
origin=origin,
|
||||
)
|
||||
await ws.send_json(options)
|
||||
response = await ws.receive_json()
|
||||
@ -203,15 +211,16 @@ async def register_chat(
|
||||
@app.websocket("/auth/ws/authenticate")
|
||||
async def websocket_authenticate(ws: WebSocket):
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
options, challenge = passkey.auth_generate_options()
|
||||
options, challenge = passkey.auth_generate_options(origin=origin)
|
||||
await ws.send_json(options)
|
||||
# Wait for the client to use his authenticator to authenticate
|
||||
credential = passkey.auth_parse(await ws.receive_json())
|
||||
# Fetch from the database by credential ID
|
||||
stored_cred = await db.get_credential_by_id(credential.raw_id)
|
||||
# Verify the credential matches the stored data
|
||||
passkey.auth_verify(credential, challenge, stored_cred)
|
||||
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
||||
# Update both credential and user's last_seen timestamp
|
||||
await db.login_user(stored_cred.user_id, stored_cred)
|
||||
|
||||
|
@ -88,6 +88,7 @@ class Passkey:
|
||||
user_id: UUID,
|
||||
user_name: str,
|
||||
credential_ids: list[bytes] | None = None,
|
||||
origin: str | None = None,
|
||||
**regopts,
|
||||
) -> tuple[dict, bytes]:
|
||||
"""
|
||||
@ -99,6 +100,7 @@ class Passkey:
|
||||
credential_ids: For an already authenticated user, a list of credential IDs
|
||||
associated with the account. This prevents accidentally adding another
|
||||
credential on an authenticator that already has one of the listed IDs.
|
||||
origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included.
|
||||
regopts: Additional arguments to generate_registration_options.
|
||||
|
||||
Returns:
|
||||
@ -126,6 +128,7 @@ class Passkey:
|
||||
response_json: dict | str,
|
||||
expected_challenge: bytes,
|
||||
user_id: UUID,
|
||||
origin: str | None = None,
|
||||
) -> StoredCredential:
|
||||
"""
|
||||
Verify registration response.
|
||||
@ -138,10 +141,11 @@ class Passkey:
|
||||
Registration verification result
|
||||
"""
|
||||
credential = parse_registration_credential_json(response_json)
|
||||
expected_origin = origin or self.origin
|
||||
registration = verify_registration_response(
|
||||
credential=credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_origin=self.origin,
|
||||
expected_origin=expected_origin,
|
||||
expected_rp_id=self.rp_id,
|
||||
)
|
||||
return StoredCredential(
|
||||
@ -193,6 +197,7 @@ class Passkey:
|
||||
credential: AuthenticationCredential,
|
||||
expected_challenge: bytes,
|
||||
stored_cred: StoredCredential,
|
||||
origin: str | None = None,
|
||||
) -> VerifiedAuthentication:
|
||||
"""
|
||||
Verify authentication response against locally stored credential data.
|
||||
@ -202,11 +207,12 @@ class Passkey:
|
||||
expected_challenge: The earlier generated challenge bytes
|
||||
stored_cred: The server stored credential record (modified by this function)
|
||||
"""
|
||||
expected_origin = origin or self.origin
|
||||
# Verify the authentication response
|
||||
verification = verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_origin=self.origin,
|
||||
expected_origin=expected_origin,
|
||||
expected_rp_id=self.rp_id,
|
||||
credential_public_key=stored_cred.public_key,
|
||||
credential_current_sign_count=stored_cred.sign_count,
|
||||
|
Loading…
x
Reference in New Issue
Block a user