From 1b7fa16cc0563fb8208882dbb1d01d052efa6d8b Mon Sep 17 00:00:00 2001 From: Leo Vasanko Date: Thu, 3 Jul 2025 18:46:05 -0600 Subject: [PATCH] Refactoring done, bugs gone. --- passkeyauth/db.py | 282 ++++++++++++++++++++++------------------- passkeyauth/main.py | 52 +++++--- passkeyauth/passkey.py | 36 +++--- pyproject.toml | 2 + static/app.js | 17 ++- 5 files changed, 213 insertions(+), 176 deletions(-) diff --git a/passkeyauth/db.py b/passkeyauth/db.py index 0fabd5e..263dbef 100644 --- a/passkeyauth/db.py +++ b/passkeyauth/db.py @@ -1,149 +1,167 @@ -import sqlite3 +""" +Async database implementation for WebAuthn passkey authentication. + +This module provides an async database layer using dataclasses and aiosqlite +for managing users and credentials in a WebAuthn authentication system. +""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + +import aiosqlite DB_PATH = "webauthn.db" - -def init_database(): - """Initialize the SQLite database with required tables""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - - # Create users table - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT UNIQUE NOT NULL, - user_id BLOB NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """ +# SQL Statements +SQL_CREATE_USERS = """ + CREATE TABLE IF NOT EXISTS users ( + user_id BINARY(16) PRIMARY KEY NOT NULL, + user_name TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) +""" - # Create credentials table - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS credentials ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - credential_id BLOB NOT NULL, - public_key BLOB NOT NULL, - sign_count INTEGER DEFAULT 0, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users (id), - UNIQUE(credential_id) - ) - """ +SQL_CREATE_CREDENTIALS = """ + CREATE TABLE IF NOT EXISTS credentials ( + credential_id BINARY(64) PRIMARY KEY NOT NULL, + user_id BINARY(16) NOT NULL, + aaguid BINARY(16) NOT NULL, + public_key BLOB NOT NULL, + sign_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_used TIMESTAMP NULL, + FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE ) +""" - conn.commit() - conn.close() +SQL_GET_USER_BY_USER_ID = """ + SELECT * FROM users WHERE user_id = ? +""" + +SQL_CREATE_USER = """ + INSERT INTO users (user_id, user_name) VALUES (?, ?) +""" + +SQL_STORE_CREDENTIAL = """ + INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count) + VALUES (?, ?, ?, ?, ?) +""" + +SQL_GET_CREDENTIAL_BY_ID = """ + SELECT credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used + FROM credentials + WHERE credential_id = ? +""" + +SQL_GET_USER_CREDENTIALS = """ + SELECT c.credential_id + FROM credentials c + JOIN users u ON c.user_id = u.user_id + WHERE u.user_name = ? +""" + +SQL_UPDATE_CREDENTIAL_SIGN_COUNT = """ + UPDATE credentials + SET sign_count = ?, last_used = CURRENT_TIMESTAMP + WHERE credential_id = ? +""" -def get_user_by_username(username: str) -> dict | None: - """Get user record by username""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - "SELECT id, username, user_id FROM users WHERE username = ?", (username,) - ) - row = cursor.fetchone() - conn.close() +@dataclass +class User: + """User data model.""" - if row: - return {"id": row[0], "username": row[1], "user_id": row[2]} - return None + user_id: bytes = b"" + user_name: str = "" + created_at: Optional[datetime] = None -def get_user_by_user_id(user_id: bytes) -> dict | None: - """Get user record by WebAuthn user ID""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - "SELECT id, username, user_id FROM users WHERE user_id = ?", (user_id,) - ) - row = cursor.fetchone() - conn.close() +@dataclass +class Credential: + """Credential data model.""" - if row: - return {"id": row[0], "username": row[1], "user_id": row[2]} - return None + credential_id: bytes = b"" + user_id: bytes = b"" + aaguid: bytes = b"" + public_key: bytes = b"" + sign_count: int = 0 + created_at: Optional[datetime] = None + last_used: Optional[datetime] = None -def create_user(username: str, user_id: bytes) -> int: - """Create a new user and return the user ID""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - "INSERT INTO users (username, user_id) VALUES (?, ?)", (username, user_id) - ) - user_db_id = cursor.lastrowid - conn.commit() - conn.close() - if user_db_id is None: - raise RuntimeError("Failed to create user") - return user_db_id +class Database: + """Async database handler for WebAuthn operations.""" + + def __init__(self, db_path: str = DB_PATH): + self.db_path = db_path + + async def init_database(self): + """Initialize the SQLite database with required tables.""" + async with aiosqlite.connect(self.db_path) as conn: + await conn.execute(SQL_CREATE_USERS) + await conn.execute(SQL_CREATE_CREDENTIALS) + await conn.commit() + + async def get_user_by_user_id(self, user_id: bytes) -> User: + """Get user record by WebAuthn user ID.""" + async with aiosqlite.connect(self.db_path) as conn: + async with conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor: + row = await cursor.fetchone() + if row: + return User(user_id=row[0], user_name=row[1], created_at=row[2]) + raise ValueError("User not found") + + async def create_user(self, user_id: bytes, user_name: str) -> User: + """Create a new user and return the User dataclass.""" + async with aiosqlite.connect(self.db_path) as conn: + await conn.execute(SQL_CREATE_USER, (user_id, user_name)) + await conn.commit() + return User(user_id=user_id, user_name=user_name) + + async def store_credential(self, credential: Credential) -> None: + """Store a credential for a user.""" + async with aiosqlite.connect(self.db_path) as conn: + await conn.execute( + SQL_STORE_CREDENTIAL, + ( + credential.credential_id, + credential.user_id, + credential.aaguid, + credential.public_key, + credential.sign_count, + ), + ) + await conn.commit() + + async def get_credential_by_id(self, credential_id: bytes) -> Credential: + """Get credential by credential ID.""" + async with aiosqlite.connect(self.db_path) as conn: + async with conn.execute( + SQL_GET_CREDENTIAL_BY_ID, (credential_id,) + ) as cursor: + row = await cursor.fetchone() + if row: + return Credential( + credential_id=row[0], + user_id=row[1], + aaguid=row[2], + public_key=row[3], + sign_count=row[4], + created_at=row[5], + last_used=row[6], + ) + raise ValueError("Credential not found") + + async def update_credential(self, credential: Credential) -> None: + """Update the sign count for a credential.""" + async with aiosqlite.connect(self.db_path) as conn: + await conn.execute( + SQL_UPDATE_CREDENTIAL_SIGN_COUNT, + (credential.sign_count, credential.credential_id), + ) + await conn.commit() -def store_credential(user_db_id: int, credential_id: bytes, public_key: bytes) -> None: - """Store a credential for a user""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - "INSERT INTO credentials (user_id, credential_id, public_key) VALUES (?, ?, ?)", - (user_db_id, credential_id, public_key), - ) - conn.commit() - conn.close() - - -def get_credential_by_id(credential_id: bytes) -> dict | None: - """Get credential by credential ID""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - """ - SELECT c.public_key, c.sign_count, u.username - FROM credentials c - JOIN users u ON c.user_id = u.id - WHERE c.credential_id = ? - """, - (credential_id,), - ) - row = cursor.fetchone() - conn.close() - - if row: - return {"public_key": row[0], "sign_count": row[1], "username": row[2]} - return None - - -def get_user_credentials(username: str) -> list[bytes]: - """Get all credential IDs for a user""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - """ - SELECT c.credential_id - FROM credentials c - JOIN users u ON c.user_id = u.id - WHERE u.username = ? - """, - (username,), - ) - rows = cursor.fetchall() - conn.close() - - return [row[0] for row in rows] - - -def update_credential_sign_count(credential_id: bytes, sign_count: int) -> None: - """Update the sign count for a credential""" - conn = sqlite3.connect(DB_PATH) - cursor = conn.cursor() - cursor.execute( - "UPDATE credentials SET sign_count = ? WHERE credential_id = ?", - (sign_count, credential_id), - ) - conn.commit() - conn.close() +# Global database instance +db = Database() diff --git a/passkeyauth/main.py b/passkeyauth/main.py index 295436a..a3e7b63 100644 --- a/passkeyauth/main.py +++ b/passkeyauth/main.py @@ -5,19 +5,19 @@ This module provides a simple WebAuthn implementation that: - Uses WebSocket for real-time communication - Supports Resident Keys (discoverable credentials) for passwordless authentication - Maintains challenges locally per connection -- Uses SQLite database for persistent storage of users and credentials -- Enables true passwordless authentication where users don't need to enter a username +- Uses async SQLite database for persistent storage of users and credentials +- Enables true passwordless authentication where users don't need to enter a user_name """ from pathlib import Path -import db import uuid7 from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles -from passkeyauth.passkey import Passkey +from .db import Credential, db +from .passkey import Passkey STATIC_DIR = Path(__file__).parent.parent / "static" @@ -30,6 +30,12 @@ passkey = Passkey( app = FastAPI(title="Passkey Auth") +@app.on_event("startup") +async def startup_event(): + """Initialize database on startup.""" + await db.init_database() + + @app.websocket("/ws/new_user_registration") async def websocket_register_new(ws: WebSocket): """Register a new user and with a new passkey credential.""" @@ -38,40 +44,53 @@ async def websocket_register_new(ws: WebSocket): form = await ws.receive_json() user_id = uuid7.create().bytes user_name = form["user_name"] - await register_chat(ws, user_id, username) + + # Generate registration options and handle registration + credential, verified = await register_chat(ws, user_id, user_name) + # Store the user in the database - await db.create_user(user_name, user_id) + await db.create_user(user_id, user_name) + await db.store_credential( + Credential( + credential_id=credential.raw_id, + user_id=user_id, + aaguid=b"", # verified.aaguid, + public_key=verified.credential_public_key, + sign_count=verified.sign_count, + ) + ) await ws.send_json({"status": "success", "user_id": user_id.hex()}) except WebSocketDisconnect: pass -async def register_chat(ws: WebSocket, user_id: bytes, username: str): +async def register_chat(ws: WebSocket, user_id: bytes, user_name: str): """Generate registration options and send them to the client.""" options, challenge = passkey.reg_generate_options( user_id=user_id, - username=username, + user_name=user_name, ) - await ws.send_text(options) + await ws.send_json(options) # Wait for the client to use his authenticator to register credential = passkey.reg_credential(await ws.receive_json()) - passkey.reg_verify(credential, challenge) + verified_registration = passkey.reg_verify(credential, challenge) + return credential, verified_registration @app.websocket("/ws/authenticate") async def websocket_authenticate(ws: WebSocket): await ws.accept() try: - options = passkey.auth_generate_options() + options, challenge = await passkey.auth_generate_options() await ws.send_json(options) # Wait for the client to use his authenticator to authenticate credential = passkey.auth_credential(await ws.receive_json()) # Fetch from the database by credential ID - stored_cred = await db.fetch_credential(credential.raw_id) - # Verify the credential matches the stored data, that is also updated - passkey.auth_verify(credential, stored_cred) - # Update the credential in the database + stored_cred = await db.get_credential_by_id(credential.raw_id) + # Verify the credential matches the stored data + _ = await passkey.auth_verify(credential, challenge, stored_cred) await db.update_credential(stored_cred) + await ws.send_json({"status": "success"}) except WebSocketDisconnect: pass @@ -99,8 +118,5 @@ def main(): ) -# Initialize database on startup -db.init_database() - if __name__ == "__main__": main() diff --git a/passkeyauth/passkey.py b/passkeyauth/passkey.py index 7d58ccd..97b6598 100644 --- a/passkeyauth/passkey.py +++ b/passkeyauth/passkey.py @@ -7,6 +7,8 @@ This module provides a unified interface for WebAuthn operations including: - Credential validation """ +import json + from webauthn import ( generate_authentication_options, generate_registration_options, @@ -53,32 +55,32 @@ class Passkey: self.origin = origin self.supported_pub_key_algs = supported_pub_key_algs or [ COSEAlgorithmIdentifier.EDDSA, - # COSEAlgorithmIdentifier.ECDSA_SHA_256, - # COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, + COSEAlgorithmIdentifier.ECDSA_SHA_256, + COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, ] ### Registration Methods ### def reg_generate_options( - self, user_id: bytes, username: str, display_name="", **regopts - ) -> tuple[str, bytes]: + self, user_id: bytes, user_name: str, display_name="", **regopts + ) -> tuple[dict, bytes]: """ Generate registration options for WebAuthn registration. Args: user_id: The user ID as bytes - username: The username - display_name: The display name (defaults to username if empty) + user_name: The username + display_name: The display name (defaults to user_name if empty) Returns: - JSON string containing registration options + JSON dict containing options to be sent to client, challenge bytes to store """ options = generate_registration_options( rp_id=self.rp_id, rp_name=self.rp_name, user_id=user_id, - user_name=username, - user_display_name=display_name or username, + user_name=user_name, + user_display_name=display_name or user_name, authenticator_selection=AuthenticatorSelectionCriteria( resident_key=ResidentKeyRequirement.REQUIRED, user_verification=UserVerificationRequirement.PREFERRED, @@ -86,7 +88,7 @@ class Passkey: supported_pub_key_algs=self.supported_pub_key_algs, **regopts, ) - return options_to_json(options), options.challenge + return json.loads(options_to_json(options)), options.challenge @staticmethod def reg_credential(credential: dict | str) -> RegistrationCredential: @@ -119,14 +121,14 @@ class Passkey: async def auth_generate_options( self, user_verification_required=False, **kwopts - ) -> str: + ) -> tuple[dict, bytes]: """ Generate authentication options for WebAuthn authentication. Args: user_verification_required: The user will have to re-enter PIN or use biometrics for this operation. Useful when accessing security settings etc. Returns: - JSON string containing authentication options + Tuple of (JSON to be sent to client, challenge bytes to store) """ options = generate_authentication_options( rp_id=self.rp_id, @@ -137,7 +139,7 @@ class Passkey: ), **kwopts, ) - return options_to_json(options) + return json.loads(options_to_json(options)), options.challenge @staticmethod def auth_credential(credential: dict | str) -> AuthenticationCredential: @@ -148,7 +150,7 @@ class Passkey: self, credential: AuthenticationCredential, expected_challenge: bytes, - stored_cred: dict, + stored_cred, ): """ Verify authentication response against locally stored credential data. @@ -159,8 +161,8 @@ class Passkey: expected_challenge=expected_challenge, expected_origin=self.origin, expected_rp_id=self.rp_id, - credential_public_key=stored_cred["public_key"], - credential_current_sign_count=stored_cred["sign_count"], + credential_public_key=stored_cred.public_key, + credential_current_sign_count=stored_cred.sign_count, ) - stored_cred["sign_count"] = verification.new_sign_count + stored_cred.sign_count = verification.new_sign_count return verification diff --git a/pyproject.toml b/pyproject.toml index 5f636d4..74b1866 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ dependencies = [ "websockets>=12.0", "webauthn>=1.11.1", "base64url>=1.0.0", + "aiosqlite>=0.19.0", + "uuid7-standard>=1.0.0", ] requires-python = ">=3.10" diff --git a/static/app.js b/static/app.js index 1490df3..7066267 100644 --- a/static/app.js +++ b/static/app.js @@ -1,9 +1,9 @@ const { startRegistration, startAuthentication } = SimpleWebAuthnBrowser -async function register(username) { +async function register(user_name) { + const ws = await aWebSocket('/ws/new_user_registration') + ws.send(JSON.stringify({user_name})) // Registration chat - const ws = await aWebSocket('/ws/register') - ws.send(username) const optionsJSON = JSON.parse(await ws.recv()) if (optionsJSON.error) throw new Error(optionsJSON.error) ws.send(JSON.stringify(await startRegistration({optionsJSON}))) @@ -14,10 +14,9 @@ async function register(username) { async function authenticate() { // Authentication chat const ws = await aWebSocket('/ws/authenticate') - ws.send('') // Send empty string to trigger authentication const optionsJSON = JSON.parse(await ws.recv()) if (optionsJSON.error) throw new Error(optionsJSON.error) - ws.send(JSON.stringify(await startAuthentication({optionsJSON}))) + await ws.send(JSON.stringify(await startAuthentication({optionsJSON}))) const result = JSON.parse(await ws.recv()) if (result.error) throw new Error(`Server: ${result.error}`) return result @@ -29,9 +28,9 @@ async function authenticate() { regForm.addEventListener('submit', ev => { ev.preventDefault() regSubmitBtn.disabled = true - const username = (new FormData(regForm)).get('username') - register(username).then(() => { - alert(`Registration successful for ${username}!`) + const user_name = (new FormData(regForm)).get('username') + register(user_name).then(() => { + alert(`Registration successful for ${user_name}!`) }).catch(err => { alert(`Registration failed: ${err.message}`) }).finally(() => { @@ -45,7 +44,7 @@ async function authenticate() { ev.preventDefault() authSubmitBtn.disabled = true authenticate().then(result => { - alert(`Authentication successful! Welcome ${result.username}`) + alert(`Authentication successful!`) }).catch(err => { alert(`Authentication failed: ${err.message}`) }).finally(() => {