2025-07-03 17:02:49 -06:00

150 lines
4.0 KiB
Python

import sqlite3
# Path of the SQLite database file that the functions below connect to.
DB_PATH = "webauthn.db"
def init_database():
    """Initialize the SQLite database with the required tables.

    Creates the ``users`` and ``credentials`` tables in the database at
    ``DB_PATH`` if they do not already exist, so calling this function
    repeatedly is harmless.

    Raises:
        sqlite3.Error: If the database cannot be opened or a statement fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        # Create users table
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                user_id BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        # Create credentials table
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS credentials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                credential_id BLOB NOT NULL,
                public_key BLOB NOT NULL,
                sign_count INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (user_id) REFERENCES users (id),
                UNIQUE(credential_id)
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when a statement raises.
        conn.close()
def get_user_by_username(username: str) -> dict | None:
"""Get user record by username"""
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute(
"SELECT id, username, user_id FROM users WHERE username = ?", (username,)
)
row = cursor.fetchone()
conn.close()
if row:
return {"id": row[0], "username": row[1], "user_id": row[2]}
return None
def get_user_by_user_id(user_id: bytes) -> dict | None:
"""Get user record by WebAuthn user ID"""
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute(
"SELECT id, username, user_id FROM users WHERE user_id = ?", (user_id,)
)
row = cursor.fetchone()
conn.close()
if row:
return {"id": row[0], "username": row[1], "user_id": row[2]}
return None
def create_user(username: str, user_id: bytes) -> int:
    """Create a new user and return the user's database ID.

    Args:
        username: Stored in ``users.username``.
        user_id: Stored in the ``users.user_id`` BLOB column.

    Returns:
        The row ID of the inserted ``users`` row.

    Raises:
        sqlite3.Error: If the database cannot be opened or the insert fails.
        RuntimeError: If the cursor reports no row ID after the insert.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO users (username, user_id) VALUES (?, ?)", (username, user_id)
        )
        user_db_id = cursor.lastrowid
        conn.commit()
    finally:
        # A failed insert must not leave the connection open; closing
        # without a commit also discards the uncommitted transaction.
        conn.close()
    if user_db_id is None:
        raise RuntimeError("Failed to create user")
    return user_db_id
def store_credential(user_db_id: int, credential_id: bytes, public_key: bytes) -> None:
    """Store a credential for a user.

    Args:
        user_db_id: Written to ``credentials.user_id``.
        credential_id: Written to ``credentials.credential_id``.
        public_key: Written to ``credentials.public_key``.

    Raises:
        sqlite3.Error: If the database cannot be opened or the insert fails.
    """
    # NOTE(review): SQLite enforces foreign keys only when
    # ``PRAGMA foreign_keys`` is enabled on the connection, which is not done
    # here -- confirm callers always pass the id of an existing user.
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "INSERT INTO credentials (user_id, credential_id, public_key) VALUES (?, ?, ?)",
            (user_db_id, credential_id, public_key),
        )
        conn.commit()
    finally:
        # Release the connection even when the insert raises.
        conn.close()
def get_credential_by_id(credential_id: bytes) -> dict | None:
"""Get credential by credential ID"""
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute(
"""
SELECT c.public_key, c.sign_count, u.username
FROM credentials c
JOIN users u ON c.user_id = u.id
WHERE c.credential_id = ?
""",
(credential_id,),
)
row = cursor.fetchone()
conn.close()
if row:
return {"public_key": row[0], "sign_count": row[1], "username": row[2]}
return None
def get_user_credentials(username: str) -> list[bytes]:
    """Get all credential IDs for a user.

    Args:
        username: Value matched against ``users.username``.

    Returns:
        The ``credential_id`` of every credential joined to that user; an
        empty list when the user has no credentials or does not exist. The
        query has no ``ORDER BY``, so the order is whatever SQLite yields.

    Raises:
        sqlite3.Error: If the database cannot be opened or the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute(
            """
            SELECT c.credential_id
            FROM credentials c
            JOIN users u ON c.user_id = u.id
            WHERE u.username = ?
            """,
            (username,),
        ).fetchall()
    finally:
        # Close the connection on both the success and the error path.
        conn.close()
    return [row[0] for row in rows]
def update_credential_sign_count(credential_id: bytes, sign_count: int) -> None:
    """Update the sign count for a credential.

    Args:
        credential_id: Selects the row via ``credentials.credential_id``.
        sign_count: New value written to ``credentials.sign_count``.

    If no row matches ``credential_id`` the UPDATE affects nothing and no
    error is raised.

    Raises:
        sqlite3.Error: If the database cannot be opened or the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "UPDATE credentials SET sign_count = ? WHERE credential_id = ?",
            (sign_count, credential_id),
        )
        conn.commit()
    finally:
        # Release the connection even when the update raises.
        conn.close()