150 lines
4.0 KiB
Python
150 lines
4.0 KiB
Python
import sqlite3
|
|
|
|
# Path of the SQLite database file that the functions below pass to
# sqlite3.connect(). It is a relative path, so it is resolved against the
# process's current working directory.
DB_PATH = "webauthn.db"
|
|
|
|
|
|
def init_database():
    """Initialize the SQLite database with the required tables.

    Opens the database at ``DB_PATH`` and creates the ``users`` and
    ``credentials`` tables if they do not exist yet, then commits.
    Both statements use ``CREATE TABLE IF NOT EXISTS``, so calling this
    function again on an already-initialized database is harmless.

    Raises:
        sqlite3.Error: If the database cannot be opened or a statement fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()

        # Create users table
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                user_id BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )

        # Create credentials table
        # NOTE: SQLite only enforces the FOREIGN KEY clause on connections
        # that run "PRAGMA foreign_keys = ON"; nothing here enables it, so
        # the constraint is declarative only.
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS credentials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                credential_id BLOB NOT NULL,
                public_key BLOB NOT NULL,
                sign_count INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (user_id) REFERENCES users (id),
                UNIQUE(credential_id)
            )
            """
        )

        conn.commit()
    finally:
        # Always release the handle, even when a statement above raised.
        conn.close()
|
|
|
|
|
|
def get_user_by_username(username: str) -> dict | None:
    """Get a user record by username.

    Args:
        username: Value matched against the ``users.username`` column.

    Returns:
        A dict with the keys ``"id"``, ``"username"`` and ``"user_id"``
        for the matching row, or ``None`` when no row matches.

    Raises:
        sqlite3.Error: If the database cannot be opened or the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, username, user_id FROM users WHERE username = ?", (username,)
        )
        row = cursor.fetchone()
    finally:
        # Close the connection even if the query raised.
        conn.close()

    if row is None:
        return None
    return {"id": row[0], "username": row[1], "user_id": row[2]}
|
|
|
|
|
|
def get_user_by_user_id(user_id: bytes) -> dict | None:
    """Get a user record by WebAuthn user ID.

    Args:
        user_id: Value matched against the ``users.user_id`` BLOB column
            (not the integer primary key ``users.id``).

    Returns:
        A dict with the keys ``"id"``, ``"username"`` and ``"user_id"``
        for the first matching row, or ``None`` when no row matches.

    Raises:
        sqlite3.Error: If the database cannot be opened or the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, username, user_id FROM users WHERE user_id = ?", (user_id,)
        )
        row = cursor.fetchone()
    finally:
        # Close the connection even if the query raised.
        conn.close()

    if row is None:
        return None
    return {"id": row[0], "username": row[1], "user_id": row[2]}
|
|
|
|
|
|
def create_user(username: str, user_id: bytes) -> int:
    """Create a new user and return its database row ID.

    Args:
        username: Stored in ``users.username``; the column is UNIQUE.
        user_id: Stored in the ``users.user_id`` BLOB column.

    Returns:
        The integer row ID (``cursor.lastrowid``) of the inserted row.

    Raises:
        sqlite3.IntegrityError: If the insert violates a table constraint,
            e.g. the username already exists.
        sqlite3.Error: For any other database failure.
        RuntimeError: If SQLite reports no row ID for the insert.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO users (username, user_id) VALUES (?, ?)", (username, user_id)
        )
        user_db_id = cursor.lastrowid
        conn.commit()
    finally:
        # Close the connection even when the insert fails (for example on a
        # duplicate username); closing without commit discards the change.
        conn.close()

    if user_db_id is None:
        raise RuntimeError("Failed to create user")
    return user_db_id
|
|
|
|
|
|
def store_credential(user_db_id: int, credential_id: bytes, public_key: bytes) -> None:
    """Store a credential for a user.

    Args:
        user_db_id: Written to ``credentials.user_id``, which references
            ``users.id``.
        credential_id: Written to ``credentials.credential_id``; the column
            is UNIQUE.
        public_key: Written to ``credentials.public_key``.

    Raises:
        sqlite3.IntegrityError: If the insert violates a table constraint,
            e.g. the credential ID is already stored.
        sqlite3.Error: For any other database failure.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO credentials (user_id, credential_id, public_key) VALUES (?, ?, ?)",
            (user_db_id, credential_id, public_key),
        )
        conn.commit()
    finally:
        # Close the connection even when the insert fails; closing without
        # commit discards the pending change.
        conn.close()
|
|
|
|
|
|
def get_credential_by_id(credential_id: bytes) -> dict | None:
    """Get a credential, with its owner's username, by credential ID.

    Args:
        credential_id: Value matched against ``credentials.credential_id``.

    Returns:
        A dict with the keys ``"public_key"``, ``"sign_count"`` and
        ``"username"``, or ``None`` when no credential matches. Because the
        query inner-joins ``users``, a credential whose ``user_id`` has no
        matching user row is also reported as ``None``.

    Raises:
        sqlite3.Error: If the database cannot be opened or the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT c.public_key, c.sign_count, u.username
            FROM credentials c
            JOIN users u ON c.user_id = u.id
            WHERE c.credential_id = ?
            """,
            (credential_id,),
        )
        row = cursor.fetchone()
    finally:
        # Close the connection even if the query raised.
        conn.close()

    if row is None:
        return None
    return {"public_key": row[0], "sign_count": row[1], "username": row[2]}
|
|
|
|
|
|
def get_user_credentials(username: str) -> list[bytes]:
    """Get all credential IDs registered for a user.

    Args:
        username: Value matched against ``users.username``.

    Returns:
        The ``credential_id`` value of every credential row joined to that
        user. The list is empty when the user does not exist or has no
        credentials. The query has no ORDER BY, so the order is whatever
        SQLite returns.

    Raises:
        sqlite3.Error: If the database cannot be opened or the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT c.credential_id
            FROM credentials c
            JOIN users u ON c.user_id = u.id
            WHERE u.username = ?
            """,
            (username,),
        )
        rows = cursor.fetchall()
    finally:
        # Close the connection even if the query raised.
        conn.close()

    return [row[0] for row in rows]
|
|
|
|
|
|
def update_credential_sign_count(credential_id: bytes, sign_count: int) -> None:
    """Update the stored sign count for a credential.

    Args:
        credential_id: Selects the row via ``credentials.credential_id``.
        sign_count: New value written to ``credentials.sign_count``. It is
            stored as given; this function does not compare it with the
            previous value.

    Note:
        If no row has the given credential ID the UPDATE affects nothing
        and the function returns normally; the affected-row count is not
        checked.

    Raises:
        sqlite3.Error: If the database cannot be opened or the update fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE credentials SET sign_count = ? WHERE credential_id = ?",
            (sign_count, credential_id),
        )
        conn.commit()
    finally:
        # Close the connection even when the update fails; closing without
        # commit discards the pending change.
        conn.close()
|