diff --git a/passkeyauth/db.py b/passkeyauth/db.py index eb8dd42..99e4a63 100644 --- a/passkeyauth/db.py +++ b/passkeyauth/db.py @@ -12,6 +12,8 @@ from uuid import UUID import aiosqlite +from .passkey import StoredCredential + DB_PATH = "webauthn.db" # SQL Statements @@ -74,25 +76,12 @@ SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """ class User: """User data model.""" - user_id: bytes = b"" - user_name: str = "" + user_id: UUID + user_name: str created_at: Optional[datetime] = None last_seen: Optional[datetime] = None -@dataclass -class Credential: - """Credential data model.""" - - credential_id: bytes - user_id: bytes - aaguid: UUID - public_key: bytes - sign_count: int - created_at: datetime - last_used: datetime | None = None - - class Database: """Async database handler for WebAuthn operations.""" @@ -130,7 +119,7 @@ class Database: await conn.commit() return user - async def store_credential(self, credential: Credential) -> None: + async def store_credential(self, credential: StoredCredential) -> None: """Store a credential for a user.""" async with aiosqlite.connect(self.db_path) as conn: await conn.execute( @@ -145,7 +134,7 @@ class Database: ) await conn.commit() - async def get_credential_by_id(self, credential_id: bytes) -> Credential: + async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential: """Get credential by credential ID.""" async with aiosqlite.connect(self.db_path) as conn: async with conn.execute( @@ -153,10 +142,10 @@ class Database: ) as cursor: row = await cursor.fetchone() if row: - return Credential( + return StoredCredential( credential_id=row[0], - user_id=row[1], - aaguid=UUID(bytes=row[2]), # Convert bytes to UUID + user_id=UUID(bytes=row[1]), + aaguid=UUID(bytes=row[2]), public_key=row[3], sign_count=row[4], created_at=row[5], @@ -171,7 +160,7 @@ class Database: rows = await cursor.fetchall() return [row[0] for row in rows] - async def update_credential(self, credential: Credential) -> None: + async def update_credential(self, credential: StoredCredential) -> None: """Update the sign count for a credential.""" async with aiosqlite.connect(self.db_path) as conn: await conn.execute( @@ -180,13 +169,11 @@ class Database: ) await conn.commit() - async def update_user_last_seen( - self, user_id: bytes, last_seen: datetime | None = None - ) -> None: + async def login(self, user_id: bytes, credential: StoredCredential) -> None: """Update the last_seen timestamp for a user.""" - if last_seen is None: - last_seen = datetime.now() async with aiosqlite.connect(self.db_path) as conn: + # Do these in a single transaction + self.store_credential(credential) await conn.execute( "UPDATE users SET last_seen = ? WHERE user_id = ?", (last_seen, user_id), diff --git a/passkeyauth/main.py b/passkeyauth/main.py index d0fc77b..1f6c721 100644 --- a/passkeyauth/main.py +++ b/passkeyauth/main.py @@ -10,7 +10,6 @@ This module provides a simple WebAuthn implementation that: """ from contextlib import asynccontextmanager -from datetime import datetime from pathlib import Path from uuid import UUID @@ -19,7 +18,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles -from .db import Credential, User, db +from .db import User, db from .passkey import Passkey STATIC_DIR = Path(__file__).parent.parent / "static" @@ -45,42 +44,32 @@ async def websocket_register_new(ws: WebSocket): """Register a new user and with a new passkey credential.""" await ws.accept() try: + # Data for the new user account form = await ws.receive_json() - now = datetime.now() - user_id = uuid7.create(now).bytes + user_id = uuid7.create() user_name = form["user_name"] - # Generate registration options and handle registration - credential, verified = await register_chat(ws, user_id, user_name) + # WebAuthn registration + credential = await register_chat(ws, user_id, user_name) + credential.created_at = now # Store the user in the database await db.create_user(User(user_id, user_name, now)) - await db.store_credential( - Credential( - credential_id=credential.raw_id, - user_id=user_id, - aaguid=UUID(verified.aaguid), - public_key=verified.credential_public_key, - sign_count=verified.sign_count, - created_at=now, - ) - ) + await db.store_credential(credential) await ws.send_json({"status": "success", "user_id": user_id.hex()}) except WebSocketDisconnect: pass -async def register_chat(ws: WebSocket, user_id: bytes, user_name: str): +async def register_chat(ws: WebSocket, user_id: UUID, user_name: str): """Generate registration options and send them to the client.""" options, challenge = passkey.reg_generate_options( user_id=user_id, user_name=user_name, ) await ws.send_json(options) - # Wait for the client to use his authenticator to register - credential = passkey.reg_credential(await ws.receive_json()) - verified_registration = passkey.reg_verify(credential, challenge) - return credential, verified_registration + response = await ws.receive_json() + return passkey.reg_verify(response, challenge, user_id) @app.websocket("/ws/authenticate") @@ -90,11 +79,11 @@ async def websocket_authenticate(ws: WebSocket): options, challenge = await passkey.auth_generate_options() await ws.send_json(options) # Wait for the client to use his authenticator to authenticate - credential = passkey.auth_credential(await ws.receive_json()) + credential = passkey.auth_parse(await ws.receive_json()) # Fetch from the database by credential ID stored_cred = await db.get_credential_by_id(credential.raw_id) # Verify the credential matches the stored data - _ = await passkey.auth_verify(credential, challenge, stored_cred) + await passkey.auth_verify(credential, challenge, stored_cred) await db.update_credential(stored_cred) await ws.send_json({"status": "success"}) except WebSocketDisconnect: diff --git a/passkeyauth/passkey.py b/passkeyauth/passkey.py index 2aca32d..df55b1f 100644 --- a/passkeyauth/passkey.py +++ b/passkeyauth/passkey.py @@ -8,7 +8,8 @@ This module provides a unified interface for WebAuthn operations including: """ import json -from typing import Protocol +from dataclasses import dataclass +from datetime import datetime from uuid import UUID from webauthn import ( @@ -17,6 +18,9 @@ from webauthn import ( verify_authentication_response, verify_registration_response, ) +from webauthn.authentication.verify_authentication_response import ( + VerifiedAuthentication, +) from webauthn.helpers import ( options_to_json, parse_authentication_credential_json, @@ -27,28 +31,23 @@ from webauthn.helpers.structs import ( AuthenticationCredential, AuthenticatorSelectionCriteria, PublicKeyCredentialDescriptor, - RegistrationCredential, ResidentKeyRequirement, UserVerificationRequirement, ) -from webauthn.registration.verify_registration_response import VerifiedRegistration -class StoredCredentialRecord(Protocol): - """ - Protocol for a stored credential record that must have settable attributes: - - id: The credential ID as bytes - - aaguid: The Authenticator Attestation GUID (AAGUID) - - public_key: The public key of the credential - - sign_count: The current sign count for the credential +@dataclass +class StoredCredential: + """Credential data stored in the database.""" - Note: Can be a dataclass, ORM or any other object that implements these attributes, but not dict. - """ - - id: bytes + credential_id: bytes + user_id: UUID aaguid: UUID public_key: bytes sign_count: int + created_at: datetime + last_used: datetime | None = None + last_verified: datetime | None = None class Passkey: @@ -65,10 +64,10 @@ class Passkey: Initialize the WebAuthn handler. Args: - rp_id: The relying party identifier - rp_name: The relying party name - origin: The origin URL of the application - supported_pub_key_algs: List of supported COSE algorithms + rp_id: Your security domain (e.g. "example.com") + rp_name: The relying party name (e.g., "My Application" - visible to users) + origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included. + supported_pub_key_algs: List of supported COSE algorithms (default is EDDSA, ECDSA_SHA_256, RSASSA_PKCS1_v1_5_SHA_256). """ self.rp_id = rp_id self.rp_name = rp_name @@ -82,7 +81,11 @@ class Passkey: ### Registration Methods ### def reg_generate_options( - self, user_id: bytes, user_name: str, **regopts + self, + user_id: UUID, + user_name: str, + credential_ids: list[bytes] | None = None, + **regopts, ) -> tuple[dict, bytes]: """ Generate registration options for WebAuthn registration. @@ -90,35 +93,36 @@ class Passkey: Args: user_id: The user ID as bytes user_name: The username - display_name: The display name (defaults to user_name if empty) + credential_ids: For an already authenticated user, a list of credential IDs + associated with the account. This prevents accidentally adding another + credential on an authenticator that already has one of the listed IDs. regopts: Additional arguments to generate_registration_options. Returns: - JSON dict containing options to be sent to client, challenge bytes to store + JSON dict containing options to be sent to client, + challenge bytes to keep during the registration process. """ options = generate_registration_options( rp_id=self.rp_id, rp_name=self.rp_name, - user_id=user_id, + user_id=user_id.bytes, user_name=user_name, authenticator_selection=AuthenticatorSelectionCriteria( resident_key=ResidentKeyRequirement.REQUIRED, user_verification=UserVerificationRequirement.PREFERRED, ), + exclude_credentials=_convert_credential_ids(credential_ids), supported_pub_key_algs=self.supported_pub_key_algs, **regopts, ) return json.loads(options_to_json(options)), options.challenge - @staticmethod - def reg_credential(credential: dict | str) -> RegistrationCredential: - return parse_registration_credential_json(credential) - def reg_verify( self, - credential: RegistrationCredential, + response_json: dict | str, expected_challenge: bytes, - ) -> VerifiedRegistration: + user_id: UUID, + ) -> StoredCredential: """ Verify registration response. @@ -129,34 +133,21 @@ class Passkey: Returns: Registration verification result """ + credential = parse_registration_credential_json(response_json) registration = verify_registration_response( credential=credential, expected_challenge=expected_challenge, expected_origin=self.origin, expected_rp_id=self.rp_id, ) - return registration - - def reg_store_credential( - self, - stored_credential: StoredCredentialRecord, - credential: RegistrationCredential, - verified: VerifiedRegistration, - ): - """ - Write the verified credential data to the stored credential record. - - Args: - stored_credential: A database record being created (dataclass, ORM, etc.) - credential: The registration credential response from the client - verified: The verified registration data - - This function sets attributes on stored_credential (id, aaguid, public_key, sign_count). - """ - stored_credential.id = credential.raw_id - stored_credential.aaguid = UUID(verified.aaguid) - stored_credential.public_key = verified.credential_public_key - stored_credential.sign_count = verified.sign_count + return StoredCredential( + credential_id=credential.raw_id, + user_id=user_id, + aaguid=UUID(registration.aaguid), + public_key=registration.credential_public_key, + sign_count=registration.sign_count, + created_at=datetime.now(), + ) ### Authentication Methods ### @@ -164,7 +155,7 @@ class Passkey: self, *, user_verification_required=False, - allow_credential_ids: list[bytes] | None = None, + credential_ids: list[bytes] | None = None, **authopts, ) -> tuple[dict, bytes]: """ @@ -172,7 +163,7 @@ class Passkey: Args: user_verification_required: The user will have to re-enter PIN or use biometrics for this operation. Useful when accessing security settings etc. - allow_credentials: For an already known user, a list of credential IDs associated with the account (less prompts during authentication). + credential_ids: For an already known user, a list of credential IDs associated with the account (less prompts during authentication). authopts: Additional arguments to generate_authentication_options. Returns: @@ -185,33 +176,27 @@ class Passkey: if user_verification_required else UserVerificationRequirement.PREFERRED ), - allow_credentials=( - None - if allow_credential_ids is None - else [PublicKeyCredentialDescriptor(id) for id in allow_credential_ids] - ), + allow_credentials=_convert_credential_ids(credential_ids), **authopts, ) return json.loads(options_to_json(options)), options.challenge - @staticmethod - def auth_credential(credential: dict | str) -> AuthenticationCredential: - """Convert the authentication credential from JSON to a dataclass instance.""" - return parse_authentication_credential_json(credential) + def auth_parse(self, response: dict | str) -> AuthenticationCredential: + return parse_authentication_credential_json(response) async def auth_verify( self, credential: AuthenticationCredential, expected_challenge: bytes, - stored_cred: StoredCredentialRecord, - ): + stored_cred: StoredCredential, + ) -> VerifiedAuthentication: """ Verify authentication response against locally stored credential data. Args: credential: The authentication credential response from the client expected_challenge: The earlier generated challenge bytes - stored_cred: The server stored credential record. Must have accessors .public_key and .sign_count, the latter of which is updated by this function! + stored_cred: The server stored credential record (modified by this function) """ # Verify the authentication response verification = verify_authentication_response( @@ -223,4 +208,17 @@ class Passkey: credential_current_sign_count=stored_cred.sign_count, ) stored_cred.sign_count = verification.new_sign_count + now = datetime.now() + stored_cred.last_used = now + if verification.user_verified: + stored_cred.last_verified = now return verification + + +def _convert_credential_ids( + credential_ids: list[bytes] | None, +) -> list[PublicKeyCredentialDescriptor] | None: + """A helper to convert a list of credential IDs to PublicKeyCredentialDescriptor objects, or pass through None.""" + if credential_ids is None: + return None + return [PublicKeyCredentialDescriptor(id) for id in credential_ids]