Refactoring for simplicity, added last login/verification tracking.

This commit is contained in:
Leo Vasanko 2025-07-06 10:34:32 -06:00
parent 0325129190
commit 66384da8ce
3 changed files with 87 additions and 113 deletions

View File

@ -12,6 +12,8 @@ from uuid import UUID
import aiosqlite import aiosqlite
from .passkey import StoredCredential
DB_PATH = "webauthn.db" DB_PATH = "webauthn.db"
# SQL Statements # SQL Statements
@ -74,25 +76,12 @@ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """
class User: class User:
"""User data model.""" """User data model."""
user_id: bytes = b"" user_id: UUID
user_name: str = "" user_name: str
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
last_seen: Optional[datetime] = None last_seen: Optional[datetime] = None
@dataclass
class Credential:
"""Credential data model."""
credential_id: bytes
user_id: bytes
aaguid: UUID
public_key: bytes
sign_count: int
created_at: datetime
last_used: datetime | None = None
class Database: class Database:
"""Async database handler for WebAuthn operations.""" """Async database handler for WebAuthn operations."""
@ -130,7 +119,7 @@ class Database:
await conn.commit() await conn.commit()
return user return user
async def store_credential(self, credential: Credential) -> None: async def store_credential(self, credential: StoredCredential) -> None:
"""Store a credential for a user.""" """Store a credential for a user."""
async with aiosqlite.connect(self.db_path) as conn: async with aiosqlite.connect(self.db_path) as conn:
await conn.execute( await conn.execute(
@ -145,7 +134,7 @@ class Database:
) )
await conn.commit() await conn.commit()
async def get_credential_by_id(self, credential_id: bytes) -> Credential: async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
"""Get credential by credential ID.""" """Get credential by credential ID."""
async with aiosqlite.connect(self.db_path) as conn: async with aiosqlite.connect(self.db_path) as conn:
async with conn.execute( async with conn.execute(
@ -153,10 +142,10 @@ class Database:
) as cursor: ) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
if row: if row:
return Credential( return StoredCredential(
credential_id=row[0], credential_id=row[0],
user_id=row[1], user_id=UUID(bytes=row[1]),
aaguid=UUID(bytes=row[2]), # Convert bytes to UUID aaguid=UUID(bytes=row[2]),
public_key=row[3], public_key=row[3],
sign_count=row[4], sign_count=row[4],
created_at=row[5], created_at=row[5],
@ -171,7 +160,7 @@ class Database:
rows = await cursor.fetchall() rows = await cursor.fetchall()
return [row[0] for row in rows] return [row[0] for row in rows]
async def update_credential(self, credential: Credential) -> None: async def update_credential(self, credential: StoredCredential) -> None:
"""Update the sign count for a credential.""" """Update the sign count for a credential."""
async with aiosqlite.connect(self.db_path) as conn: async with aiosqlite.connect(self.db_path) as conn:
await conn.execute( await conn.execute(
@ -180,13 +169,11 @@ class Database:
) )
await conn.commit() await conn.commit()
async def update_user_last_seen( async def login(self, user_id: bytes, credential: StoredCredential) -> None:
self, user_id: bytes, last_seen: datetime | None = None
) -> None:
"""Update the last_seen timestamp for a user.""" """Update the last_seen timestamp for a user."""
if last_seen is None:
last_seen = datetime.now()
async with aiosqlite.connect(self.db_path) as conn: async with aiosqlite.connect(self.db_path) as conn:
# Do these in a single transaction
self.store_credential(credential)
await conn.execute( await conn.execute(
"UPDATE users SET last_seen = ? WHERE user_id = ?", "UPDATE users SET last_seen = ? WHERE user_id = ?",
(last_seen, user_id), (last_seen, user_id),

View File

@ -10,7 +10,6 @@ This module provides a simple WebAuthn implementation that:
""" """
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path from pathlib import Path
from uuid import UUID from uuid import UUID
@ -19,7 +18,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from .db import Credential, User, db from .db import User, db
from .passkey import Passkey from .passkey import Passkey
STATIC_DIR = Path(__file__).parent.parent / "static" STATIC_DIR = Path(__file__).parent.parent / "static"
@ -45,42 +44,32 @@ async def websocket_register_new(ws: WebSocket):
"""Register a new user and with a new passkey credential.""" """Register a new user and with a new passkey credential."""
await ws.accept() await ws.accept()
try: try:
# Data for the new user account
form = await ws.receive_json() form = await ws.receive_json()
now = datetime.now() user_id = uuid7.create()
user_id = uuid7.create(now).bytes
user_name = form["user_name"] user_name = form["user_name"]
# Generate registration options and handle registration # WebAuthn registration
credential, verified = await register_chat(ws, user_id, user_name) credential = await register_chat(ws, user_id, user_name)
credential.created_at = now
# Store the user in the database # Store the user in the database
await db.create_user(User(user_id, user_name, now)) await db.create_user(User(user_id, user_name, now))
await db.store_credential( await db.store_credential(credential)
Credential(
credential_id=credential.raw_id,
user_id=user_id,
aaguid=UUID(verified.aaguid),
public_key=verified.credential_public_key,
sign_count=verified.sign_count,
created_at=now,
)
)
await ws.send_json({"status": "success", "user_id": user_id.hex()}) await ws.send_json({"status": "success", "user_id": user_id.hex()})
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
async def register_chat(ws: WebSocket, user_id: bytes, user_name: str): async def register_chat(ws: WebSocket, user_id: UUID, user_name: str):
"""Generate registration options and send them to the client.""" """Generate registration options and send them to the client."""
options, challenge = passkey.reg_generate_options( options, challenge = passkey.reg_generate_options(
user_id=user_id, user_id=user_id,
user_name=user_name, user_name=user_name,
) )
await ws.send_json(options) await ws.send_json(options)
# Wait for the client to use his authenticator to register response = await ws.receive_json()
credential = passkey.reg_credential(await ws.receive_json()) return passkey.reg_verify(response, challenge, user_id)
verified_registration = passkey.reg_verify(credential, challenge)
return credential, verified_registration
@app.websocket("/ws/authenticate") @app.websocket("/ws/authenticate")
@ -90,11 +79,11 @@ async def websocket_authenticate(ws: WebSocket):
options, challenge = await passkey.auth_generate_options() options, challenge = await passkey.auth_generate_options()
await ws.send_json(options) await ws.send_json(options)
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_credential(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID # Fetch from the database by credential ID
stored_cred = await db.get_credential_by_id(credential.raw_id) stored_cred = await db.get_credential_by_id(credential.raw_id)
# Verify the credential matches the stored data # Verify the credential matches the stored data
_ = await passkey.auth_verify(credential, challenge, stored_cred) await passkey.auth_verify(credential, challenge, stored_cred)
await db.update_credential(stored_cred) await db.update_credential(stored_cred)
await ws.send_json({"status": "success"}) await ws.send_json({"status": "success"})
except WebSocketDisconnect: except WebSocketDisconnect:

View File

@ -8,7 +8,8 @@ This module provides a unified interface for WebAuthn operations including:
""" """
import json import json
from typing import Protocol from dataclasses import dataclass
from datetime import datetime
from uuid import UUID from uuid import UUID
from webauthn import ( from webauthn import (
@ -17,6 +18,9 @@ from webauthn import (
verify_authentication_response, verify_authentication_response,
verify_registration_response, verify_registration_response,
) )
from webauthn.authentication.verify_authentication_response import (
VerifiedAuthentication,
)
from webauthn.helpers import ( from webauthn.helpers import (
options_to_json, options_to_json,
parse_authentication_credential_json, parse_authentication_credential_json,
@ -27,28 +31,23 @@ from webauthn.helpers.structs import (
AuthenticationCredential, AuthenticationCredential,
AuthenticatorSelectionCriteria, AuthenticatorSelectionCriteria,
PublicKeyCredentialDescriptor, PublicKeyCredentialDescriptor,
RegistrationCredential,
ResidentKeyRequirement, ResidentKeyRequirement,
UserVerificationRequirement, UserVerificationRequirement,
) )
from webauthn.registration.verify_registration_response import VerifiedRegistration
class StoredCredentialRecord(Protocol): @dataclass
""" class StoredCredential:
Protocol for a stored credential record that must have settable attributes: """Credential data stored in the database."""
- id: The credential ID as bytes
- aaguid: The Authenticator Attestation GUID (AAGUID)
- public_key: The public key of the credential
- sign_count: The current sign count for the credential
Note: Can be a dataclass, ORM or any other object that implements these attributes, but not dict. credential_id: bytes
""" user_id: UUID
id: bytes
aaguid: UUID aaguid: UUID
public_key: bytes public_key: bytes
sign_count: int sign_count: int
created_at: datetime
last_used: datetime | None = None
last_verified: datetime | None = None
class Passkey: class Passkey:
@ -65,10 +64,10 @@ class Passkey:
Initialize the WebAuthn handler. Initialize the WebAuthn handler.
Args: Args:
rp_id: The relying party identifier rp_id: Your security domain (e.g. "example.com")
rp_name: The relying party name rp_name: The relying party name (e.g., "My Application" - visible to users)
origin: The origin URL of the application origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included.
supported_pub_key_algs: List of supported COSE algorithms supported_pub_key_algs: List of supported COSE algorithms (default is EDDSA, ECDSA_SHA_256, RSASSA_PKCS1_v1_5_SHA_256).
""" """
self.rp_id = rp_id self.rp_id = rp_id
self.rp_name = rp_name self.rp_name = rp_name
@ -82,7 +81,11 @@ class Passkey:
### Registration Methods ### ### Registration Methods ###
def reg_generate_options( def reg_generate_options(
self, user_id: bytes, user_name: str, **regopts self,
user_id: UUID,
user_name: str,
credential_ids: list[bytes] | None = None,
**regopts,
) -> tuple[dict, bytes]: ) -> tuple[dict, bytes]:
""" """
Generate registration options for WebAuthn registration. Generate registration options for WebAuthn registration.
@ -90,35 +93,36 @@ class Passkey:
Args: Args:
user_id: The user ID as bytes user_id: The user ID as bytes
user_name: The username user_name: The username
display_name: The display name (defaults to user_name if empty) credential_ids: For an already authenticated user, a list of credential IDs
associated with the account. This prevents accidentally adding another
credential on an authenticator that already has one of the listed IDs.
regopts: Additional arguments to generate_registration_options. regopts: Additional arguments to generate_registration_options.
Returns: Returns:
JSON dict containing options to be sent to client, challenge bytes to store JSON dict containing options to be sent to client,
challenge bytes to keep during the registration process.
""" """
options = generate_registration_options( options = generate_registration_options(
rp_id=self.rp_id, rp_id=self.rp_id,
rp_name=self.rp_name, rp_name=self.rp_name,
user_id=user_id, user_id=user_id.bytes,
user_name=user_name, user_name=user_name,
authenticator_selection=AuthenticatorSelectionCriteria( authenticator_selection=AuthenticatorSelectionCriteria(
resident_key=ResidentKeyRequirement.REQUIRED, resident_key=ResidentKeyRequirement.REQUIRED,
user_verification=UserVerificationRequirement.PREFERRED, user_verification=UserVerificationRequirement.PREFERRED,
), ),
exclude_credentials=_convert_credential_ids(credential_ids),
supported_pub_key_algs=self.supported_pub_key_algs, supported_pub_key_algs=self.supported_pub_key_algs,
**regopts, **regopts,
) )
return json.loads(options_to_json(options)), options.challenge return json.loads(options_to_json(options)), options.challenge
@staticmethod
def reg_credential(credential: dict | str) -> RegistrationCredential:
return parse_registration_credential_json(credential)
def reg_verify( def reg_verify(
self, self,
credential: RegistrationCredential, response_json: dict | str,
expected_challenge: bytes, expected_challenge: bytes,
) -> VerifiedRegistration: user_id: UUID,
) -> StoredCredential:
""" """
Verify registration response. Verify registration response.
@ -129,34 +133,21 @@ class Passkey:
Returns: Returns:
Registration verification result Registration verification result
""" """
credential = parse_registration_credential_json(response_json)
registration = verify_registration_response( registration = verify_registration_response(
credential=credential, credential=credential,
expected_challenge=expected_challenge, expected_challenge=expected_challenge,
expected_origin=self.origin, expected_origin=self.origin,
expected_rp_id=self.rp_id, expected_rp_id=self.rp_id,
) )
return registration return StoredCredential(
credential_id=credential.raw_id,
def reg_store_credential( user_id=user_id,
self, aaguid=UUID(registration.aaguid),
stored_credential: StoredCredentialRecord, public_key=registration.credential_public_key,
credential: RegistrationCredential, sign_count=registration.sign_count,
verified: VerifiedRegistration, created_at=datetime.now(),
): )
"""
Write the verified credential data to the stored credential record.
Args:
stored_credential: A database record being created (dataclass, ORM, etc.)
credential: The registration credential response from the client
verified: The verified registration data
This function sets attributes on stored_credential (id, aaguid, public_key, sign_count).
"""
stored_credential.id = credential.raw_id
stored_credential.aaguid = UUID(verified.aaguid)
stored_credential.public_key = verified.credential_public_key
stored_credential.sign_count = verified.sign_count
### Authentication Methods ### ### Authentication Methods ###
@ -164,7 +155,7 @@ class Passkey:
self, self,
*, *,
user_verification_required=False, user_verification_required=False,
allow_credential_ids: list[bytes] | None = None, credential_ids: list[bytes] | None = None,
**authopts, **authopts,
) -> tuple[dict, bytes]: ) -> tuple[dict, bytes]:
""" """
@ -172,7 +163,7 @@ class Passkey:
Args: Args:
user_verification_required: The user will have to re-enter PIN or use biometrics for this operation. Useful when accessing security settings etc. user_verification_required: The user will have to re-enter PIN or use biometrics for this operation. Useful when accessing security settings etc.
allow_credentials: For an already known user, a list of credential IDs associated with the account (less prompts during authentication). credential_ids: For an already known user, a list of credential IDs associated with the account (less prompts during authentication).
authopts: Additional arguments to generate_authentication_options. authopts: Additional arguments to generate_authentication_options.
Returns: Returns:
@ -185,33 +176,27 @@ class Passkey:
if user_verification_required if user_verification_required
else UserVerificationRequirement.PREFERRED else UserVerificationRequirement.PREFERRED
), ),
allow_credentials=( allow_credentials=_convert_credential_ids(credential_ids),
None
if allow_credential_ids is None
else [PublicKeyCredentialDescriptor(id) for id in allow_credential_ids]
),
**authopts, **authopts,
) )
return json.loads(options_to_json(options)), options.challenge return json.loads(options_to_json(options)), options.challenge
@staticmethod def auth_parse(self, response: dict | str) -> AuthenticationCredential:
def auth_credential(credential: dict | str) -> AuthenticationCredential: return parse_authentication_credential_json(response)
"""Convert the authentication credential from JSON to a dataclass instance."""
return parse_authentication_credential_json(credential)
async def auth_verify( async def auth_verify(
self, self,
credential: AuthenticationCredential, credential: AuthenticationCredential,
expected_challenge: bytes, expected_challenge: bytes,
stored_cred: StoredCredentialRecord, stored_cred: StoredCredential,
): ) -> VerifiedAuthentication:
""" """
Verify authentication response against locally stored credential data. Verify authentication response against locally stored credential data.
Args: Args:
credential: The authentication credential response from the client credential: The authentication credential response from the client
expected_challenge: The earlier generated challenge bytes expected_challenge: The earlier generated challenge bytes
stored_cred: The server stored credential record. Must have accessors .public_key and .sign_count, the latter of which is updated by this function! stored_cred: The server stored credential record (modified by this function)
""" """
# Verify the authentication response # Verify the authentication response
verification = verify_authentication_response( verification = verify_authentication_response(
@ -223,4 +208,17 @@ class Passkey:
credential_current_sign_count=stored_cred.sign_count, credential_current_sign_count=stored_cred.sign_count,
) )
stored_cred.sign_count = verification.new_sign_count stored_cred.sign_count = verification.new_sign_count
now = datetime.now()
stored_cred.last_used = now
if verification.user_verified:
stored_cred.last_verified = now
return verification return verification
def _convert_credential_ids(
credential_ids: list[bytes] | None,
) -> list[PublicKeyCredentialDescriptor] | None:
"""A helper to convert a list of credential IDs to PublicKeyCredentialDescriptor objects, or pass through None."""
if credential_ids is None:
return None
return [PublicKeyCredentialDescriptor(id) for id in credential_ids]