diff --git a/passkeyauth/db.py b/passkeyauth/db.py index 1a0b0aa..bb0939d 100644 --- a/passkeyauth/db.py +++ b/passkeyauth/db.py @@ -8,7 +8,6 @@ for managing users and credentials in a WebAuthn authentication system. from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime -from typing import Optional from uuid import UUID import aiosqlite @@ -86,8 +85,8 @@ class User: user_id: UUID user_name: str - created_at: Optional[datetime] = None - last_seen: Optional[datetime] = None + created_at: datetime | None = None + last_seen: datetime | None = None @asynccontextmanager @@ -124,15 +123,14 @@ class DB: ) raise ValueError("User not found") - async def create_user(self, user: User) -> User: + async def create_user(self, user: User) -> None: """Create a new user and return the User dataclass.""" await self.conn.execute( SQL_CREATE_USER, (user.user_id.bytes, user.user_name, user.created_at, user.last_seen), ) - return user - async def store_credential(self, credential: StoredCredential) -> None: + async def create_credential(self, credential: StoredCredential) -> None: """Store a credential for a user.""" await self.conn.execute( SQL_STORE_CREDENTIAL, @@ -165,7 +163,7 @@ class DB: last_used=row[6], last_verified=row[7], ) - raise ValueError("Credential not found") + raise ValueError("Credential not registered") async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]: """Get all credential IDs for a user.""" @@ -188,9 +186,8 @@ class DB: async def login(self, user_id: bytes, credential: StoredCredential) -> None: """Update the last_seen timestamp for a user and the credential record used for logging in.""" - # Update credential + await self.conn.execute("BEGIN") await self.update_credential(credential) - # Update user's last_seen timestamp await self.conn.execute( "UPDATE users SET last_seen = ? WHERE user_id = ?", (credential.last_used, user_id), diff --git a/passkeyauth/main.py b/passkeyauth/main.py index 1f6c721..49e2c95 100644 --- a/passkeyauth/main.py +++ b/passkeyauth/main.py @@ -11,14 +11,13 @@ This module provides a simple WebAuthn implementation that: from contextlib import asynccontextmanager from pathlib import Path -from uuid import UUID +from uuid import UUID, uuid4 -import uuid7 from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles -from .db import User, db +from .db import User, connect from .passkey import Passkey STATIC_DIR = Path(__file__).parent.parent / "static" @@ -32,7 +31,8 @@ passkey = Passkey( @asynccontextmanager async def lifespan(app: FastAPI): - await db.init_database() + async with connect() as db: + await db.init_db() yield @@ -46,17 +46,19 @@ async def websocket_register_new(ws: WebSocket): try: # Data for the new user account form = await ws.receive_json() - user_id = uuid7.create() + user_id = uuid4() user_name = form["user_name"] # WebAuthn registration credential = await register_chat(ws, user_id, user_name) - credential.created_at = now # Store the user in the database - await db.create_user(User(user_id, user_name, now)) - await db.store_credential(credential) - await ws.send_json({"status": "success", "user_id": user_id.hex()}) + async with connect() as db: + await db.conn.execute("BEGIN") + await db.create_user(User(user_id, user_name)) + await db.create_credential(credential) + + await ws.send_json({"status": "success", "user_id": str(user_id)}) except WebSocketDisconnect: pass @@ -80,12 +82,15 @@ async def websocket_authenticate(ws: WebSocket): await ws.send_json(options) # Wait for the client to use his authenticator to authenticate credential = passkey.auth_parse(await ws.receive_json()) - # Fetch from the database by credential ID - stored_cred = await db.get_credential_by_id(credential.raw_id) - # Verify the credential matches the stored data - await passkey.auth_verify(credential, challenge, stored_cred) - await db.update_credential(stored_cred) + async with connect() as db: + # Fetch from the database by credential ID + stored_cred = await db.get_credential_by_id(credential.raw_id) + # Verify the credential matches the stored data + await passkey.auth_verify(credential, challenge, stored_cred) + await db.update_credential(stored_cred) await ws.send_json({"status": "success"}) + except ValueError as e: + await ws.send_json({"error": str(e)}) except WebSocketDisconnect: pass