Finish refactoring, working.

This commit is contained in:
Leo Vasanko 2025-07-06 12:06:22 -06:00
parent 4f8b5f837c
commit 48c5d8a831
2 changed files with 25 additions and 23 deletions

View File

@ -8,7 +8,6 @@ for managing users and credentials in a WebAuthn authentication system.
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Optional
from uuid import UUID from uuid import UUID
import aiosqlite import aiosqlite
@ -86,8 +85,8 @@ class User:
user_id: UUID user_id: UUID
user_name: str user_name: str
created_at: Optional[datetime] = None created_at: datetime | None = None
last_seen: Optional[datetime] = None last_seen: datetime | None = None
@asynccontextmanager @asynccontextmanager
@ -124,15 +123,14 @@ class DB:
) )
raise ValueError("User not found") raise ValueError("User not found")
async def create_user(self, user: User) -> User: async def create_user(self, user: User) -> None:
"""Create a new user and return the User dataclass.""" """Create a new user and return the User dataclass."""
await self.conn.execute( await self.conn.execute(
SQL_CREATE_USER, SQL_CREATE_USER,
(user.user_id.bytes, user.user_name, user.created_at, user.last_seen), (user.user_id.bytes, user.user_name, user.created_at, user.last_seen),
) )
return user
async def store_credential(self, credential: StoredCredential) -> None: async def create_credential(self, credential: StoredCredential) -> None:
"""Store a credential for a user.""" """Store a credential for a user."""
await self.conn.execute( await self.conn.execute(
SQL_STORE_CREDENTIAL, SQL_STORE_CREDENTIAL,
@ -165,7 +163,7 @@ class DB:
last_used=row[6], last_used=row[6],
last_verified=row[7], last_verified=row[7],
) )
raise ValueError("Credential not found") raise ValueError("Credential not registered")
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]: async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
"""Get all credential IDs for a user.""" """Get all credential IDs for a user."""
@ -188,9 +186,8 @@ class DB:
async def login(self, user_id: bytes, credential: StoredCredential) -> None: async def login(self, user_id: bytes, credential: StoredCredential) -> None:
"""Update the last_seen timestamp for a user and the credential record used for logging in.""" """Update the last_seen timestamp for a user and the credential record used for logging in."""
# Update credential await self.conn.execute("BEGIN")
await self.update_credential(credential) await self.update_credential(credential)
# Update user's last_seen timestamp
await self.conn.execute( await self.conn.execute(
"UPDATE users SET last_seen = ? WHERE user_id = ?", "UPDATE users SET last_seen = ? WHERE user_id = ?",
(credential.last_used, user_id), (credential.last_used, user_id),

View File

@ -11,14 +11,13 @@ This module provides a simple WebAuthn implementation that:
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from uuid import UUID from uuid import UUID, uuid4
import uuid7
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from .db import User, db from .db import User, connect
from .passkey import Passkey from .passkey import Passkey
STATIC_DIR = Path(__file__).parent.parent / "static" STATIC_DIR = Path(__file__).parent.parent / "static"
@ -32,7 +31,8 @@ passkey = Passkey(
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
await db.init_database() async with connect() as db:
await db.init_db()
yield yield
@ -46,17 +46,19 @@ async def websocket_register_new(ws: WebSocket):
try: try:
# Data for the new user account # Data for the new user account
form = await ws.receive_json() form = await ws.receive_json()
user_id = uuid7.create() user_id = uuid4()
user_name = form["user_name"] user_name = form["user_name"]
# WebAuthn registration # WebAuthn registration
credential = await register_chat(ws, user_id, user_name) credential = await register_chat(ws, user_id, user_name)
credential.created_at = now
# Store the user in the database # Store the user in the database
await db.create_user(User(user_id, user_name, now)) async with connect() as db:
await db.store_credential(credential) await db.conn.execute("BEGIN")
await ws.send_json({"status": "success", "user_id": user_id.hex()}) await db.create_user(User(user_id, user_name))
await db.create_credential(credential)
await ws.send_json({"status": "success", "user_id": str(user_id)})
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
@ -80,12 +82,15 @@ async def websocket_authenticate(ws: WebSocket):
await ws.send_json(options) await ws.send_json(options)
# Wait for the client to use his authenticator to authenticate # Wait for the client to use his authenticator to authenticate
credential = passkey.auth_parse(await ws.receive_json()) credential = passkey.auth_parse(await ws.receive_json())
# Fetch from the database by credential ID async with connect() as db:
stored_cred = await db.get_credential_by_id(credential.raw_id) # Fetch from the database by credential ID
# Verify the credential matches the stored data stored_cred = await db.get_credential_by_id(credential.raw_id)
await passkey.auth_verify(credential, challenge, stored_cred) # Verify the credential matches the stored data
await db.update_credential(stored_cred) await passkey.auth_verify(credential, challenge, stored_cred)
await db.update_credential(stored_cred)
await ws.send_json({"status": "success"}) await ws.send_json({"status": "success"})
except ValueError as e:
await ws.send_json({"error": str(e)})
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass