More refactoring. Prevent registering another key on the same authenticator for the same user.
This commit is contained in:
parent
eb56c000e8
commit
1c9044054a
@ -10,8 +10,8 @@ This module contains all the HTTP API endpoints for:
|
|||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
|
|
||||||
|
from . import db
|
||||||
from .aaguid_manager import get_aaguid_manager
|
from .aaguid_manager import get_aaguid_manager
|
||||||
from .db import connect
|
|
||||||
from .jwt_manager import refresh_session_token, validate_session_token
|
from .jwt_manager import refresh_session_token, validate_session_token
|
||||||
from .session_manager import (
|
from .session_manager import (
|
||||||
clear_session_cookie,
|
clear_session_cookie,
|
||||||
@ -57,9 +57,8 @@ async def get_user_credentials(request: Request) -> dict:
|
|||||||
if token_data:
|
if token_data:
|
||||||
current_credential_id = token_data.get("credential_id")
|
current_credential_id = token_data.get("credential_id")
|
||||||
|
|
||||||
async with connect() as db:
|
|
||||||
# Get all credentials for the user
|
# Get all credentials for the user
|
||||||
credential_ids = await db.get_credentials_by_user_id(user.user_id.bytes)
|
credential_ids = await db.get_user_credentials(user.user_id)
|
||||||
|
|
||||||
credentials = []
|
credentials = []
|
||||||
user_aaguids = set()
|
user_aaguids = set()
|
||||||
@ -73,9 +72,7 @@ async def get_user_credentials(request: Request) -> dict:
|
|||||||
user_aaguids.add(aaguid_str)
|
user_aaguids.add(aaguid_str)
|
||||||
|
|
||||||
# Check if this is the current session credential
|
# Check if this is the current session credential
|
||||||
is_current_session = (
|
is_current_session = current_credential_id == stored_cred.credential_id
|
||||||
current_credential_id == stored_cred.credential_id
|
|
||||||
)
|
|
||||||
|
|
||||||
credentials.append(
|
credentials.append(
|
||||||
{
|
{
|
||||||
@ -212,7 +209,6 @@ async def delete_credential(request: Request) -> dict:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return {"error": "Invalid credential_id format"}
|
return {"error": "Invalid credential_id format"}
|
||||||
|
|
||||||
async with connect() as db:
|
|
||||||
# First, verify the credential belongs to the current user
|
# First, verify the credential belongs to the current user
|
||||||
try:
|
try:
|
||||||
stored_cred = await db.get_credential_by_id(credential_id_bytes)
|
stored_cred = await db.get_credential_by_id(credential_id_bytes)
|
||||||
@ -225,21 +221,16 @@ async def delete_credential(request: Request) -> dict:
|
|||||||
session_token = get_session_token_from_request(request)
|
session_token = get_session_token_from_request(request)
|
||||||
if session_token:
|
if session_token:
|
||||||
token_data = validate_session_token(session_token)
|
token_data = validate_session_token(session_token)
|
||||||
if (
|
if token_data and token_data.get("credential_id") == credential_id_bytes:
|
||||||
token_data
|
|
||||||
and token_data.get("credential_id") == credential_id_bytes
|
|
||||||
):
|
|
||||||
return {"error": "Cannot delete current session credential"}
|
return {"error": "Cannot delete current session credential"}
|
||||||
|
|
||||||
# Get user's remaining credentials count
|
# Get user's remaining credentials count
|
||||||
remaining_credentials = await db.get_credentials_by_user_id(
|
remaining_credentials = await db.get_user_credentials(user.user_id)
|
||||||
user.user_id.bytes
|
|
||||||
)
|
|
||||||
if len(remaining_credentials) <= 1:
|
if len(remaining_credentials) <= 1:
|
||||||
return {"error": "Cannot delete last remaining credential"}
|
return {"error": "Cannot delete last remaining credential"}
|
||||||
|
|
||||||
# Delete the credential
|
# Delete the credential
|
||||||
await db.delete_credential(credential_id_bytes)
|
await db.delete_user_credential(credential_id_bytes)
|
||||||
|
|
||||||
return {"status": "success", "message": "Credential deleted successfully"}
|
return {"status": "success", "message": "Credential deleted successfully"}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Async database implementation for WebAuthn passkey authentication.
|
Async database implementation for WebAuthn passkey authentication.
|
||||||
|
|
||||||
This module provides an async database layer using dataclasses and aiosqlite
|
This module provides an async database layer using SQLAlchemy async mode
|
||||||
for managing users and credentials in a WebAuthn authentication system.
|
for managing users and credentials in a WebAuthn authentication system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -10,66 +10,60 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import aiosqlite
|
from sqlalchemy import (
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
LargeBinary,
|
||||||
|
String,
|
||||||
|
delete,
|
||||||
|
select,
|
||||||
|
update,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.sqlite import BLOB
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from .passkey import StoredCredential
|
from .passkey import StoredCredential
|
||||||
|
|
||||||
DB_PATH = "webauthn.db"
|
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
||||||
|
|
||||||
# SQL Statements
|
|
||||||
SQL_CREATE_USERS = """
|
# SQLAlchemy Models
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
class Base(DeclarativeBase):
|
||||||
user_id BINARY(16) PRIMARY KEY NOT NULL,
|
pass
|
||||||
user_name TEXT NOT NULL,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
last_seen TIMESTAMP NULL
|
class UserModel(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
|
user_name: Mapped[str] = mapped_column(String, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||||
|
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
# Relationship to credentials
|
||||||
|
credentials: Mapped[list["CredentialModel"]] = relationship(
|
||||||
|
"CredentialModel", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_CREATE_CREDENTIALS = """
|
|
||||||
CREATE TABLE IF NOT EXISTS credentials (
|
class CredentialModel(Base):
|
||||||
credential_id BINARY(64) PRIMARY KEY NOT NULL,
|
__tablename__ = "credentials"
|
||||||
user_id BINARY(16) NOT NULL,
|
|
||||||
aaguid BINARY(16) NOT NULL,
|
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
|
||||||
public_key BLOB NOT NULL,
|
user_id: Mapped[bytes] = mapped_column(
|
||||||
sign_count INTEGER NOT NULL,
|
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
last_used TIMESTAMP NULL,
|
|
||||||
last_verified TIMESTAMP NULL,
|
|
||||||
FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE CASCADE
|
|
||||||
)
|
)
|
||||||
"""
|
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
||||||
|
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
||||||
|
sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||||
|
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
|
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
|
|
||||||
SQL_GET_USER_BY_USER_ID = """
|
# Relationship to user
|
||||||
SELECT * FROM users WHERE user_id = ?
|
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_CREATE_USER = """
|
|
||||||
INSERT INTO users (user_id, user_name, created_at, last_seen) VALUES (?, ?, ?, ?)
|
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_STORE_CREDENTIAL = """
|
|
||||||
INSERT INTO credentials (credential_id, user_id, aaguid, public_key, sign_count, created_at, last_used, last_verified)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_GET_CREDENTIAL_BY_ID = """
|
|
||||||
SELECT * FROM credentials WHERE credential_id = ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_GET_USER_CREDENTIALS = """
|
|
||||||
SELECT credential_id FROM credentials WHERE user_id = ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_UPDATE_CREDENTIAL = """
|
|
||||||
UPDATE credentials
|
|
||||||
SET sign_count = ?, created_at = ?, last_used = ?, last_verified = ?
|
|
||||||
WHERE credential_id = ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
SQL_DELETE_CREDENTIAL = """
|
|
||||||
DELETE FROM credentials WHERE credential_id = ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -80,121 +74,181 @@ class User:
|
|||||||
last_seen: datetime | None = None
|
last_seen: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Global engine and session factory
|
||||||
|
engine = create_async_engine(DB_PATH, echo=False)
|
||||||
|
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def connect():
|
async def connect():
|
||||||
conn = await aiosqlite.connect(DB_PATH)
|
"""Context manager for database connections."""
|
||||||
try:
|
async with async_session_factory() as session:
|
||||||
yield DB(conn)
|
yield DB(session)
|
||||||
await conn.commit()
|
await session.commit()
|
||||||
finally:
|
|
||||||
await conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
class DB:
|
class DB:
|
||||||
def __init__(self, conn: aiosqlite.Connection):
|
def __init__(self, session: AsyncSession):
|
||||||
self.conn = conn
|
self.session = session
|
||||||
|
|
||||||
async def init_db(self) -> None:
|
async def init_db(self) -> None:
|
||||||
"""Initialize database tables."""
|
"""Initialize database tables."""
|
||||||
await self.conn.execute(SQL_CREATE_USERS)
|
async with engine.begin() as conn:
|
||||||
await self.conn.execute(SQL_CREATE_CREDENTIALS)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
await self.conn.commit()
|
|
||||||
|
|
||||||
# Database operation functions that work with a connection
|
async def get_user_by_user_id(self, user_id: UUID) -> User:
|
||||||
async def get_user_by_user_id(self, user_id: bytes) -> User:
|
|
||||||
"""Get user record by WebAuthn user ID."""
|
"""Get user record by WebAuthn user ID."""
|
||||||
async with self.conn.execute(SQL_GET_USER_BY_USER_ID, (user_id,)) as cursor:
|
stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
|
||||||
row = await cursor.fetchone()
|
result = await self.session.execute(stmt)
|
||||||
if row:
|
user_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user_model:
|
||||||
return User(
|
return User(
|
||||||
user_id=UUID(bytes=row[0]),
|
user_id=UUID(bytes=user_model.user_id),
|
||||||
user_name=row[1],
|
user_name=user_model.user_name,
|
||||||
created_at=_convert_datetime(row[2]),
|
created_at=user_model.created_at,
|
||||||
last_seen=_convert_datetime(row[3]),
|
last_seen=user_model.last_seen,
|
||||||
)
|
)
|
||||||
raise ValueError("User not found")
|
raise ValueError("User not found")
|
||||||
|
|
||||||
async def create_user(self, user: User) -> None:
|
async def create_user(self, user: User) -> None:
|
||||||
"""Create a new user and return the User dataclass."""
|
"""Create a new user."""
|
||||||
await self.conn.execute(
|
user_model = UserModel(
|
||||||
SQL_CREATE_USER,
|
user_id=user.user_id.bytes,
|
||||||
(
|
user_name=user.user_name,
|
||||||
user.user_id.bytes,
|
created_at=user.created_at or datetime.now(),
|
||||||
user.user_name,
|
last_seen=user.last_seen,
|
||||||
user.created_at or datetime.now(),
|
|
||||||
user.last_seen,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
self.session.add(user_model)
|
||||||
|
await self.session.flush()
|
||||||
|
|
||||||
async def create_credential(self, credential: StoredCredential) -> None:
|
async def create_credential(self, credential: StoredCredential) -> None:
|
||||||
"""Store a credential for a user."""
|
"""Store a credential for a user."""
|
||||||
await self.conn.execute(
|
credential_model = CredentialModel(
|
||||||
SQL_STORE_CREDENTIAL,
|
credential_id=credential.credential_id,
|
||||||
(
|
user_id=credential.user_id.bytes,
|
||||||
credential.credential_id,
|
aaguid=credential.aaguid.bytes,
|
||||||
credential.user_id.bytes,
|
public_key=credential.public_key,
|
||||||
credential.aaguid.bytes,
|
sign_count=credential.sign_count,
|
||||||
credential.public_key,
|
created_at=credential.created_at,
|
||||||
credential.sign_count,
|
last_used=credential.last_used,
|
||||||
credential.created_at,
|
last_verified=credential.last_verified,
|
||||||
credential.last_used,
|
|
||||||
credential.last_verified,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
self.session.add(credential_model)
|
||||||
|
await self.session.flush()
|
||||||
|
|
||||||
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
|
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
|
||||||
"""Get credential by credential ID."""
|
"""Get credential by credential ID."""
|
||||||
async with self.conn.execute(
|
stmt = select(CredentialModel).where(
|
||||||
SQL_GET_CREDENTIAL_BY_ID, (credential_id,)
|
CredentialModel.credential_id == credential_id
|
||||||
) as cursor:
|
)
|
||||||
row = await cursor.fetchone()
|
result = await self.session.execute(stmt)
|
||||||
if row:
|
credential_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if credential_model:
|
||||||
return StoredCredential(
|
return StoredCredential(
|
||||||
credential_id=row[0],
|
credential_id=credential_model.credential_id,
|
||||||
user_id=UUID(bytes=row[1]),
|
user_id=UUID(bytes=credential_model.user_id),
|
||||||
aaguid=UUID(bytes=row[2]),
|
aaguid=UUID(bytes=credential_model.aaguid),
|
||||||
public_key=row[3],
|
public_key=credential_model.public_key,
|
||||||
sign_count=row[4],
|
sign_count=credential_model.sign_count,
|
||||||
created_at=datetime.fromisoformat(row[5]),
|
created_at=credential_model.created_at,
|
||||||
last_used=_convert_datetime(row[6]),
|
last_used=credential_model.last_used,
|
||||||
last_verified=_convert_datetime(row[7]),
|
last_verified=credential_model.last_verified,
|
||||||
)
|
)
|
||||||
raise ValueError("Credential not registered")
|
raise ValueError("Credential not registered")
|
||||||
|
|
||||||
async def get_credentials_by_user_id(self, user_id: bytes) -> list[bytes]:
|
async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
|
||||||
"""Get all credential IDs for a user."""
|
"""Get all credential IDs for a user."""
|
||||||
async with self.conn.execute(SQL_GET_USER_CREDENTIALS, (user_id,)) as cursor:
|
stmt = select(CredentialModel.credential_id).where(
|
||||||
rows = await cursor.fetchall()
|
CredentialModel.user_id == user_id.bytes
|
||||||
return [row[0] for row in rows]
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return [row[0] for row in result.fetchall()]
|
||||||
|
|
||||||
async def update_credential(self, credential: StoredCredential) -> None:
|
async def update_credential(self, credential: StoredCredential) -> None:
|
||||||
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||||
await self.conn.execute(
|
stmt = (
|
||||||
SQL_UPDATE_CREDENTIAL,
|
update(CredentialModel)
|
||||||
(
|
.where(CredentialModel.credential_id == credential.credential_id)
|
||||||
credential.sign_count,
|
.values(
|
||||||
credential.created_at,
|
sign_count=credential.sign_count,
|
||||||
credential.last_used,
|
created_at=credential.created_at,
|
||||||
credential.last_verified,
|
last_used=credential.last_used,
|
||||||
credential.credential_id,
|
last_verified=credential.last_verified,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
async def login(self, user_id: bytes, credential: StoredCredential) -> None:
|
async def login(self, user_id: UUID, credential: StoredCredential) -> None:
|
||||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||||
await self.conn.execute("BEGIN")
|
async with self.session.begin():
|
||||||
|
# Update credential
|
||||||
await self.update_credential(credential)
|
await self.update_credential(credential)
|
||||||
await self.conn.execute(
|
|
||||||
"UPDATE users SET last_seen = ? WHERE user_id = ?",
|
# Update user's last_seen
|
||||||
(credential.last_used, user_id),
|
stmt = (
|
||||||
|
update(UserModel)
|
||||||
|
.where(UserModel.user_id == user_id.bytes)
|
||||||
|
.values(last_seen=credential.last_used)
|
||||||
)
|
)
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
async def delete_credential(self, credential_id: bytes) -> None:
|
async def delete_credential(self, credential_id: bytes) -> None:
|
||||||
"""Delete a credential by its ID."""
|
"""Delete a credential by its ID."""
|
||||||
await self.conn.execute(SQL_DELETE_CREDENTIAL, (credential_id,))
|
stmt = delete(CredentialModel).where(
|
||||||
await self.conn.commit()
|
CredentialModel.credential_id == credential_id
|
||||||
|
)
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
|
||||||
def _convert_datetime(val):
|
# Standalone functions that handle database connections internally
|
||||||
"""Convert string from SQLite to datetime object (pass through None)."""
|
async def init_database() -> None:
|
||||||
return val and datetime.fromisoformat(val)
|
"""Initialize database tables."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.init_db()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
|
||||||
|
"""Create a new user and their first credential in a single transaction."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.session.begin()
|
||||||
|
await db.create_user(user)
|
||||||
|
await db.create_credential(credential)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_id(user_id: UUID) -> User:
|
||||||
|
"""Get user record by WebAuthn user ID."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.get_user_by_user_id(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_credential_for_user(credential: StoredCredential) -> None:
|
||||||
|
"""Store a credential for an existing user."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.create_credential(credential)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
|
||||||
|
"""Get credential by credential ID."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.get_credential_by_id(credential_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_credentials(user_id: UUID) -> list[bytes]:
|
||||||
|
"""Get all credential IDs for a user."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.get_credentials_by_user_id(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
|
||||||
|
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.login(user_id, credential)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_user_credential(credential_id: bytes) -> None:
|
||||||
|
"""Delete a credential by its ID."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.delete_credential(credential_id)
|
||||||
|
@ -18,6 +18,7 @@ from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
|
|||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from . import db
|
||||||
from .api_handlers import (
|
from .api_handlers import (
|
||||||
delete_credential,
|
delete_credential,
|
||||||
get_user_credentials,
|
get_user_credentials,
|
||||||
@ -27,7 +28,7 @@ from .api_handlers import (
|
|||||||
set_session,
|
set_session,
|
||||||
validate_token,
|
validate_token,
|
||||||
)
|
)
|
||||||
from .db import User, connect
|
from .db import User
|
||||||
from .jwt_manager import create_session_token
|
from .jwt_manager import create_session_token
|
||||||
from .passkey import Passkey
|
from .passkey import Passkey
|
||||||
from .session_manager import get_user_from_cookie_string
|
from .session_manager import get_user_from_cookie_string
|
||||||
@ -44,8 +45,7 @@ passkey = Passkey(
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
async with connect() as db:
|
await db.init_database()
|
||||||
await db.init_db()
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@ -65,11 +65,11 @@ async def websocket_register_new(ws: WebSocket):
|
|||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(ws, user_id, user_name)
|
credential = await register_chat(ws, user_id, user_name)
|
||||||
|
|
||||||
# Store the user in the database
|
# Store the user and credential in the database
|
||||||
async with connect() as db:
|
await db.create_user_and_credential(
|
||||||
await db.conn.execute("BEGIN")
|
User(user_id, user_name, created_at=datetime.now()),
|
||||||
await db.create_user(User(user_id, user_name, created_at=datetime.now()))
|
credential,
|
||||||
await db.create_credential(credential)
|
)
|
||||||
|
|
||||||
# Create a session token for the new user
|
# Create a session token for the new user
|
||||||
session_token = create_session_token(user_id, credential.credential_id)
|
session_token = create_session_token(user_id, credential.credential_id)
|
||||||
@ -101,16 +101,15 @@ async def websocket_register_add(ws: WebSocket):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get user information to get the user_name
|
# Get user information to get the user_name
|
||||||
async with connect() as db:
|
user = await db.get_user_by_id(user_id)
|
||||||
user = await db.get_user_by_user_id(user_id.bytes)
|
|
||||||
user_name = user.user_name
|
user_name = user.user_name
|
||||||
|
challenge_ids = await db.get_user_credentials(user_id)
|
||||||
|
|
||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(ws, user_id, user_name)
|
credential = await register_chat(ws, user_id, user_name, challenge_ids)
|
||||||
print(f"New credential for user {user_id}: {credential}")
|
print(f"New credential for user {user_id}: {credential}")
|
||||||
# Store the new credential in the database
|
# Store the new credential in the database
|
||||||
async with connect() as db:
|
await db.create_credential_for_user(credential)
|
||||||
await db.create_credential(credential)
|
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
@ -128,11 +127,17 @@ async def websocket_register_add(ws: WebSocket):
|
|||||||
await ws.send_json({"error": f"Server error: {str(e)}"})
|
await ws.send_json({"error": f"Server error: {str(e)}"})
|
||||||
|
|
||||||
|
|
||||||
async def register_chat(ws: WebSocket, user_id: UUID, user_name: str):
|
async def register_chat(
|
||||||
|
ws: WebSocket,
|
||||||
|
user_id: UUID,
|
||||||
|
user_name: str,
|
||||||
|
credential_ids: list[bytes] | None = None,
|
||||||
|
):
|
||||||
"""Generate registration options and send them to the client."""
|
"""Generate registration options and send them to the client."""
|
||||||
options, challenge = passkey.reg_generate_options(
|
options, challenge = passkey.reg_generate_options(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
|
credential_ids=credential_ids,
|
||||||
)
|
)
|
||||||
await ws.send_json(options)
|
await ws.send_json(options)
|
||||||
response = await ws.receive_json()
|
response = await ws.receive_json()
|
||||||
@ -144,17 +149,16 @@ async def register_chat(ws: WebSocket, user_id: UUID, user_name: str):
|
|||||||
async def websocket_authenticate(ws: WebSocket):
|
async def websocket_authenticate(ws: WebSocket):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
try:
|
try:
|
||||||
options, challenge = await passkey.auth_generate_options()
|
options, challenge = passkey.auth_generate_options()
|
||||||
await ws.send_json(options)
|
await ws.send_json(options)
|
||||||
# Wait for the client to use his authenticator to authenticate
|
# Wait for the client to use his authenticator to authenticate
|
||||||
credential = passkey.auth_parse(await ws.receive_json())
|
credential = passkey.auth_parse(await ws.receive_json())
|
||||||
async with connect() as db:
|
|
||||||
# Fetch from the database by credential ID
|
# Fetch from the database by credential ID
|
||||||
stored_cred = await db.get_credential_by_id(credential.raw_id)
|
stored_cred = await db.get_credential_by_id(credential.raw_id)
|
||||||
# Verify the credential matches the stored data
|
# Verify the credential matches the stored data
|
||||||
await passkey.auth_verify(credential, challenge, stored_cred)
|
passkey.auth_verify(credential, challenge, stored_cred)
|
||||||
# Update both credential and user's last_seen timestamp
|
# Update both credential and user's last_seen timestamp
|
||||||
await db.login(stored_cred.user_id.bytes, stored_cred)
|
await db.login_user(stored_cred.user_id, stored_cred)
|
||||||
|
|
||||||
# Create a session token for the authenticated user
|
# Create a session token for the authenticated user
|
||||||
session_token = create_session_token(
|
session_token = create_session_token(
|
||||||
|
@ -155,7 +155,7 @@ class Passkey:
|
|||||||
|
|
||||||
### Authentication Methods ###
|
### Authentication Methods ###
|
||||||
|
|
||||||
async def auth_generate_options(
|
def auth_generate_options(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
user_verification_required=False,
|
user_verification_required=False,
|
||||||
@ -178,7 +178,7 @@ class Passkey:
|
|||||||
user_verification=(
|
user_verification=(
|
||||||
UserVerificationRequirement.REQUIRED
|
UserVerificationRequirement.REQUIRED
|
||||||
if user_verification_required
|
if user_verification_required
|
||||||
else UserVerificationRequirement.PREFERRED
|
else UserVerificationRequirement.DISCOURAGED
|
||||||
),
|
),
|
||||||
allow_credentials=_convert_credential_ids(credential_ids),
|
allow_credentials=_convert_credential_ids(credential_ids),
|
||||||
**authopts,
|
**authopts,
|
||||||
@ -188,7 +188,7 @@ class Passkey:
|
|||||||
def auth_parse(self, response: dict | str) -> AuthenticationCredential:
|
def auth_parse(self, response: dict | str) -> AuthenticationCredential:
|
||||||
return parse_authentication_credential_json(response)
|
return parse_authentication_credential_json(response)
|
||||||
|
|
||||||
async def auth_verify(
|
def auth_verify(
|
||||||
self,
|
self,
|
||||||
credential: AuthenticationCredential,
|
credential: AuthenticationCredential,
|
||||||
expected_challenge: bytes,
|
expected_challenge: bytes,
|
||||||
|
@ -12,7 +12,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
|
|
||||||
from .db import User, connect
|
from .db import User, get_user_by_id
|
||||||
from .jwt_manager import validate_session_token
|
from .jwt_manager import validate_session_token
|
||||||
|
|
||||||
COOKIE_NAME = "session_token"
|
COOKIE_NAME = "session_token"
|
||||||
@ -29,9 +29,8 @@ async def get_current_user(request: Request) -> Optional[User]:
|
|||||||
if not token_data:
|
if not token_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async with connect() as db:
|
|
||||||
try:
|
try:
|
||||||
user = await db.get_user_by_user_id(token_data["user_id"].bytes)
|
user = await get_user_by_id(token_data["user_id"])
|
||||||
return user
|
return user
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
@ -15,6 +15,7 @@ dependencies = [
|
|||||||
"websockets>=12.0",
|
"websockets>=12.0",
|
||||||
"webauthn>=1.11.1",
|
"webauthn>=1.11.1",
|
||||||
"base64url>=1.0.0",
|
"base64url>=1.0.0",
|
||||||
|
"sqlalchemy[asyncio]>=2.0.0",
|
||||||
"aiosqlite>=0.19.0",
|
"aiosqlite>=0.19.0",
|
||||||
"uuid7-standard>=1.0.0",
|
"uuid7-standard>=1.0.0",
|
||||||
"pyjwt>=2.8.0",
|
"pyjwt>=2.8.0",
|
||||||
|
@ -334,7 +334,7 @@ async function addNewCredential() {
|
|||||||
clearStatus('dashboardStatus')
|
clearStatus('dashboardStatus')
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showStatus('dashboardStatus', `Failed to add new passkey: ${error.message}`, 'error')
|
showStatus('dashboardStatus', 'Registration cancelled', 'error')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user