Finish refactoring, working.
This commit is contained in:
parent
4f8b5f837c
commit
48c5d8a831
@ -8,7 +8,6 @@ for managing users and credentials in a WebAuthn authentication system.
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
@ -86,8 +85,8 @@ class User:
|
|||||||
|
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
user_name: str
|
user_name: str
|
||||||
created_at: Optional[datetime] = None
|
created_at: datetime | None = None
|
||||||
last_seen: Optional[datetime] = None
|
last_seen: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -124,15 +123,14 @@ class DB:
|
|||||||
)
|
)
|
||||||
raise ValueError("User not found")
|
raise ValueError("User not found")
|
||||||
|
|
||||||
async def create_user(self, user: User) -> User:
|
async def create_user(self, user: User) -> None:
|
||||||
"""Create a new user and return the User dataclass."""
|
"""Create a new user and return the User dataclass."""
|
||||||
await self.conn.execute(
|
await self.conn.execute(
|
||||||
SQL_CREATE_USER,
|
SQL_CREATE_USER,
|
||||||
(user.user_id.bytes, user.user_name, user.created_at, user.last_seen),
|
(user.user_id.bytes, user.user_name, user.created_at, user.last_seen),
|
||||||
)
|
)
|
||||||
return user
|
|
||||||
|
|
||||||
async def store_credential(self, credential: StoredCredential) -> None:
|
async def create_credential(self, credential: StoredCredential) -> None:
|
||||||
"""Store a credential for a user."""
|
"""Store a credential for a user."""
|
||||||
await self.conn.execute(
|
await self.conn.execute(
|
||||||
SQL_STORE_CREDENTIAL,
|
SQL_STORE_CREDENTIAL,
|
||||||
@ -165,7 +163,7 @@ class DB:
|
|||||||
last_used=row[6],
|
last_used=row[6],
|
||||||
last_verified=row[7],
|
last_verified=row[7],
|
||||||
)
|
)
|
||||||
raise ValueError("Credential not found")
|
raise ValueError("Credential not registered")
|
||||||
|
|
||||||
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
|
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
|
||||||
"""Get all credential IDs for a user."""
|
"""Get all credential IDs for a user."""
|
||||||
@ -188,9 +186,8 @@ class DB:
|
|||||||
|
|
||||||
async def login(self, user_id: bytes, credential: StoredCredential) -> None:
|
async def login(self, user_id: bytes, credential: StoredCredential) -> None:
|
||||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||||
# Update credential
|
await self.conn.execute("BEGIN")
|
||||||
await self.update_credential(credential)
|
await self.update_credential(credential)
|
||||||
# Update user's last_seen timestamp
|
|
||||||
await self.conn.execute(
|
await self.conn.execute(
|
||||||
"UPDATE users SET last_seen = ? WHERE user_id = ?",
|
"UPDATE users SET last_seen = ? WHERE user_id = ?",
|
||||||
(credential.last_used, user_id),
|
(credential.last_used, user_id),
|
||||||
|
@ -11,14 +11,13 @@ This module provides a simple WebAuthn implementation that:
|
|||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from uuid import UUID
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import uuid7
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from .db import User, db
|
from .db import User, connect
|
||||||
from .passkey import Passkey
|
from .passkey import Passkey
|
||||||
|
|
||||||
STATIC_DIR = Path(__file__).parent.parent / "static"
|
STATIC_DIR = Path(__file__).parent.parent / "static"
|
||||||
@ -32,7 +31,8 @@ passkey = Passkey(
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
await db.init_database()
|
async with connect() as db:
|
||||||
|
await db.init_db()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@ -46,17 +46,19 @@ async def websocket_register_new(ws: WebSocket):
|
|||||||
try:
|
try:
|
||||||
# Data for the new user account
|
# Data for the new user account
|
||||||
form = await ws.receive_json()
|
form = await ws.receive_json()
|
||||||
user_id = uuid7.create()
|
user_id = uuid4()
|
||||||
user_name = form["user_name"]
|
user_name = form["user_name"]
|
||||||
|
|
||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(ws, user_id, user_name)
|
credential = await register_chat(ws, user_id, user_name)
|
||||||
|
|
||||||
credential.created_at = now
|
|
||||||
# Store the user in the database
|
# Store the user in the database
|
||||||
await db.create_user(User(user_id, user_name, now))
|
async with connect() as db:
|
||||||
await db.store_credential(credential)
|
await db.conn.execute("BEGIN")
|
||||||
await ws.send_json({"status": "success", "user_id": user_id.hex()})
|
await db.create_user(User(user_id, user_name))
|
||||||
|
await db.create_credential(credential)
|
||||||
|
|
||||||
|
await ws.send_json({"status": "success", "user_id": str(user_id)})
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -80,12 +82,15 @@ async def websocket_authenticate(ws: WebSocket):
|
|||||||
await ws.send_json(options)
|
await ws.send_json(options)
|
||||||
# Wait for the client to use his authenticator to authenticate
|
# Wait for the client to use his authenticator to authenticate
|
||||||
credential = passkey.auth_parse(await ws.receive_json())
|
credential = passkey.auth_parse(await ws.receive_json())
|
||||||
# Fetch from the database by credential ID
|
async with connect() as db:
|
||||||
stored_cred = await db.get_credential_by_id(credential.raw_id)
|
# Fetch from the database by credential ID
|
||||||
# Verify the credential matches the stored data
|
stored_cred = await db.get_credential_by_id(credential.raw_id)
|
||||||
await passkey.auth_verify(credential, challenge, stored_cred)
|
# Verify the credential matches the stored data
|
||||||
await db.update_credential(stored_cred)
|
await passkey.auth_verify(credential, challenge, stored_cred)
|
||||||
|
await db.update_credential(stored_cred)
|
||||||
await ws.send_json({"status": "success"})
|
await ws.send_json({"status": "success"})
|
||||||
|
except ValueError as e:
|
||||||
|
await ws.send_json({"error": str(e)})
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user