A lot of cleanup, restructuring project directory.
This commit is contained in:
3
passkey/__init__.py
Normal file
3
passkey/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .sansio import Passkey
|
||||
|
||||
__all__ = ["Passkey"]
|
||||
32
passkey/aaguid/__init__.py
Normal file
32
passkey/aaguid/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
AAGUID (Authenticator Attestation GUID) management for WebAuthn credentials.
|
||||
|
||||
This module provides functionality to:
|
||||
- Load AAGUID data from JSON file
|
||||
- Look up authenticator information by AAGUID
|
||||
- Return only relevant AAGUID data for user credentials
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from importlib.resources import files
|
||||
|
||||
__ALL__ = ["AAGUID", "filter"]
|
||||
|
||||
# Path to the AAGUID JSON file
|
||||
AAGUID_FILE = files("passkey") / "aaguid" / "combined_aaguid.json"
|
||||
AAGUID: dict[str, dict] = json.loads(AAGUID_FILE.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def filter(aaguids: Iterable[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Get AAGUID information only for the provided set of AAGUIDs.
|
||||
|
||||
Args:
|
||||
aaguids: Set of AAGUID strings that the user has credentials for
|
||||
|
||||
Returns:
|
||||
Dictionary mapping AAGUID to authenticator information for only
|
||||
the AAGUIDs that the user has and that we have data for
|
||||
"""
|
||||
return {aaguid: AAGUID[aaguid] for aaguid in aaguids if aaguid in AAGUID}
|
||||
1
passkey/aaguid/combined_aaguid.json
Normal file
1
passkey/aaguid/combined_aaguid.json
Normal file
File diff suppressed because one or more lines are too long
388
passkey/db/sql.py
Normal file
388
passkey/db/sql.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Async database implementation for WebAuthn passkey authentication.
|
||||
|
||||
This module provides an async database layer using SQLAlchemy async mode
|
||||
for managing users and credentials in a WebAuthn authentication system.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import (
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
String,
|
||||
delete,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.dialects.sqlite import BLOB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from ..sansio import StoredCredential
|
||||
|
||||
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
||||
|
||||
|
||||
# SQLAlchemy Models
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class UserModel(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||
user_name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
visits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Relationship to credentials
|
||||
credentials: Mapped[list["CredentialModel"]] = relationship(
|
||||
"CredentialModel", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class CredentialModel(Base):
|
||||
__tablename__ = "credentials"
|
||||
|
||||
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
|
||||
user_id: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||
)
|
||||
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
||||
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
||||
sign_count: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
last_verified: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# Relationship to user
|
||||
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
|
||||
|
||||
|
||||
class ResetTokenModel(Base):
|
||||
__tablename__ = "reset_tokens"
|
||||
|
||||
token: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
|
||||
# Relationship to user
|
||||
user: Mapped["UserModel"] = relationship("UserModel")
|
||||
|
||||
|
||||
@dataclass
|
||||
class User:
|
||||
user_id: UUID
|
||||
user_name: str
|
||||
created_at: datetime | None = None
|
||||
last_seen: datetime | None = None
|
||||
visits: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetToken:
|
||||
token: str
|
||||
user_id: UUID
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# Global engine and session factory
|
||||
engine = create_async_engine(DB_PATH, echo=False)
|
||||
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect():
|
||||
"""Context manager for database connections."""
|
||||
async with async_session_factory() as session:
|
||||
yield DB(session)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class DB:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def init_db(self) -> None:
|
||||
"""Initialize database tables."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def get_user_by_user_id(self, user_id: UUID) -> User:
|
||||
"""Get user record by WebAuthn user ID."""
|
||||
stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
|
||||
result = await self.session.execute(stmt)
|
||||
user_model = result.scalar_one_or_none()
|
||||
|
||||
if user_model:
|
||||
return User(
|
||||
user_id=UUID(bytes=user_model.user_id),
|
||||
user_name=user_model.user_name,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
visits=user_model.visits,
|
||||
)
|
||||
raise ValueError("User not found")
|
||||
|
||||
async def create_user(self, user: User) -> None:
|
||||
"""Create a new user."""
|
||||
user_model = UserModel(
|
||||
user_id=user.user_id.bytes,
|
||||
user_name=user.user_name,
|
||||
created_at=user.created_at or datetime.now(),
|
||||
last_seen=user.last_seen,
|
||||
visits=user.visits,
|
||||
)
|
||||
self.session.add(user_model)
|
||||
await self.session.flush()
|
||||
|
||||
async def create_credential(self, credential: StoredCredential) -> None:
|
||||
"""Store a credential for a user."""
|
||||
credential_model = CredentialModel(
|
||||
credential_id=credential.credential_id,
|
||||
user_id=credential.user_id.bytes,
|
||||
aaguid=credential.aaguid.bytes,
|
||||
public_key=credential.public_key,
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
self.session.add(credential_model)
|
||||
await self.session.flush()
|
||||
|
||||
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
|
||||
"""Get credential by credential ID."""
|
||||
stmt = select(CredentialModel).where(
|
||||
CredentialModel.credential_id == credential_id
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
credential_model = result.scalar_one_or_none()
|
||||
|
||||
if credential_model:
|
||||
return StoredCredential(
|
||||
credential_id=credential_model.credential_id,
|
||||
user_id=UUID(bytes=credential_model.user_id),
|
||||
aaguid=UUID(bytes=credential_model.aaguid),
|
||||
public_key=credential_model.public_key,
|
||||
sign_count=credential_model.sign_count,
|
||||
created_at=credential_model.created_at,
|
||||
last_used=credential_model.last_used,
|
||||
last_verified=credential_model.last_verified,
|
||||
)
|
||||
raise ValueError("Credential not registered")
|
||||
|
||||
async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
|
||||
"""Get all credential IDs for a user."""
|
||||
stmt = select(CredentialModel.credential_id).where(
|
||||
CredentialModel.user_id == user_id.bytes
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
async def update_credential(self, credential: StoredCredential) -> None:
|
||||
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||
stmt = (
|
||||
update(CredentialModel)
|
||||
.where(CredentialModel.credential_id == credential.credential_id)
|
||||
.values(
|
||||
sign_count=credential.sign_count,
|
||||
created_at=credential.created_at,
|
||||
last_used=credential.last_used,
|
||||
last_verified=credential.last_verified,
|
||||
)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def login(self, user_id: UUID, credential: StoredCredential) -> None:
|
||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||
async with self.session.begin():
|
||||
# Update credential
|
||||
await self.update_credential(credential)
|
||||
|
||||
# Update user's last_seen and increment visits
|
||||
stmt = (
|
||||
update(UserModel)
|
||||
.where(UserModel.user_id == user_id.bytes)
|
||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def create_new_session(
|
||||
self, user_id: UUID, credential: StoredCredential
|
||||
) -> None:
|
||||
"""Create a new session for a user by incrementing visits and updating last_seen."""
|
||||
async with self.session.begin():
|
||||
# Update credential
|
||||
await self.update_credential(credential)
|
||||
|
||||
# Update user's last_seen and increment visits
|
||||
stmt = (
|
||||
update(UserModel)
|
||||
.where(UserModel.user_id == user_id.bytes)
|
||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def delete_credential(self, credential_id: bytes) -> None:
|
||||
"""Delete a credential by its ID."""
|
||||
stmt = delete(CredentialModel).where(
|
||||
CredentialModel.credential_id == credential_id
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def create_reset_token(self, user_id: UUID, token: str | None = None) -> str:
|
||||
"""Create a new reset token for a user."""
|
||||
if token is None:
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
reset_token_model = ResetTokenModel(
|
||||
token=token,
|
||||
user_id=user_id.bytes,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self.session.add(reset_token_model)
|
||||
await self.session.flush()
|
||||
return token
|
||||
|
||||
async def get_reset_token(self, token: str) -> ResetToken | None:
|
||||
"""Get reset token by token string."""
|
||||
stmt = select(ResetTokenModel).where(ResetTokenModel.token == token)
|
||||
result = await self.session.execute(stmt)
|
||||
token_model = result.scalar_one_or_none()
|
||||
|
||||
if token_model:
|
||||
return ResetToken(
|
||||
token=token_model.token,
|
||||
user_id=UUID(bytes=token_model.user_id),
|
||||
created_at=token_model.created_at,
|
||||
)
|
||||
return None
|
||||
|
||||
async def delete_reset_token(self, token: str) -> None:
|
||||
"""Delete a reset token (used after successful credential addition)."""
|
||||
stmt = delete(ResetTokenModel).where(ResetTokenModel.token == token)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def cleanup_expired_tokens(self) -> None:
|
||||
"""Remove expired reset tokens (older than 24 hours)."""
|
||||
expiry_time = datetime.now() - timedelta(hours=24)
|
||||
stmt = delete(ResetTokenModel).where(ResetTokenModel.created_at < expiry_time)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def get_user_by_username(self, user_name: str) -> User | None:
|
||||
"""Get user by username."""
|
||||
stmt = select(UserModel).where(UserModel.user_name == user_name)
|
||||
result = await self.session.execute(stmt)
|
||||
user_model = result.scalar_one_or_none()
|
||||
|
||||
if user_model:
|
||||
return User(
|
||||
user_id=UUID(bytes=user_model.user_id),
|
||||
user_name=user_model.user_name,
|
||||
created_at=user_model.created_at,
|
||||
last_seen=user_model.last_seen,
|
||||
visits=user_model.visits,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Standalone functions that handle database connections internally
|
||||
async def init_database() -> None:
|
||||
"""Initialize database tables."""
|
||||
async with connect() as db:
|
||||
await db.init_db()
|
||||
|
||||
|
||||
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
|
||||
"""Create a new user and their first credential in a single transaction."""
|
||||
async with connect() as db:
|
||||
await db.session.begin()
|
||||
# Set visits to 1 for the new user since they're creating their first session
|
||||
user.visits = 1
|
||||
await db.create_user(user)
|
||||
await db.create_credential(credential)
|
||||
|
||||
|
||||
async def get_user_by_id(user_id: UUID) -> User:
|
||||
"""Get user record by WebAuthn user ID."""
|
||||
async with connect() as db:
|
||||
return await db.get_user_by_user_id(user_id)
|
||||
|
||||
|
||||
async def create_credential_for_user(credential: StoredCredential) -> None:
|
||||
"""Store a credential for an existing user."""
|
||||
async with connect() as db:
|
||||
await db.create_credential(credential)
|
||||
|
||||
|
||||
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
|
||||
"""Get credential by credential ID."""
|
||||
async with connect() as db:
|
||||
return await db.get_credential_by_id(credential_id)
|
||||
|
||||
|
||||
async def get_user_credentials(user_id: UUID) -> list[bytes]:
|
||||
"""Get all credential IDs for a user."""
|
||||
async with connect() as db:
|
||||
return await db.get_credentials_by_user_id(user_id)
|
||||
|
||||
|
||||
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
|
||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||
async with connect() as db:
|
||||
await db.login(user_id, credential)
|
||||
|
||||
|
||||
async def delete_user_credential(credential_id: bytes) -> None:
|
||||
"""Delete a credential by its ID."""
|
||||
async with connect() as db:
|
||||
await db.delete_credential(credential_id)
|
||||
|
||||
|
||||
async def create_new_session(user_id: UUID, credential: StoredCredential) -> None:
|
||||
"""Create a new session for a user by incrementing visits and updating last_seen."""
|
||||
async with connect() as db:
|
||||
await db.create_new_session(user_id, credential)
|
||||
|
||||
|
||||
async def create_reset_token(user_id: UUID, token: str | None = None) -> str:
|
||||
"""Create a reset token for a user."""
|
||||
async with connect() as db:
|
||||
return await db.create_reset_token(user_id, token)
|
||||
|
||||
|
||||
async def get_reset_token(token: str) -> ResetToken | None:
|
||||
"""Get reset token by token string."""
|
||||
async with connect() as db:
|
||||
return await db.get_reset_token(token)
|
||||
|
||||
|
||||
async def delete_reset_token(token: str) -> None:
|
||||
"""Delete a reset token (used after successful credential addition)."""
|
||||
async with connect() as db:
|
||||
await db.delete_reset_token(token)
|
||||
|
||||
|
||||
async def cleanup_expired_tokens() -> None:
|
||||
"""Remove expired reset tokens (older than 24 hours)."""
|
||||
async with connect() as db:
|
||||
await db.cleanup_expired_tokens()
|
||||
|
||||
|
||||
async def get_user_by_username(user_name: str) -> User | None:
|
||||
"""Get user by username."""
|
||||
async with connect() as db:
|
||||
return await db.get_user_by_username(user_name)
|
||||
234
passkey/fastapi/api_handlers.py
Normal file
234
passkey/fastapi/api_handlers.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
API endpoints for user management and session handling.
|
||||
|
||||
This module contains all the HTTP API endpoints for:
|
||||
- User information retrieval
|
||||
- User credentials management
|
||||
- Session token validation and refresh
|
||||
- Login/logout functionality
|
||||
"""
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from .. import aaguid
|
||||
from ..db import sql
|
||||
from ..util.jwt import refresh_session_token, validate_session_token
|
||||
from .session_manager import (
|
||||
clear_session_cookie,
|
||||
get_current_user,
|
||||
get_session_token_from_bearer,
|
||||
get_session_token_from_cookie,
|
||||
set_session_cookie,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_info(request: Request) -> dict:
|
||||
"""Get user information from session cookie."""
|
||||
try:
|
||||
user = await get_current_user(request)
|
||||
if not user:
|
||||
return {"error": "Not authenticated"}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"user": {
|
||||
"user_id": str(user.user_id),
|
||||
"user_name": user.user_name,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"last_seen": user.last_seen.isoformat() if user.last_seen else None,
|
||||
"visits": user.visits,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to get user info: {str(e)}"}
|
||||
|
||||
|
||||
async def get_user_credentials(request: Request) -> dict:
|
||||
"""Get all credentials for a user using session cookie."""
|
||||
try:
|
||||
user = await get_current_user(request)
|
||||
if not user:
|
||||
return {"error": "Not authenticated"}
|
||||
|
||||
# Get current session credential ID
|
||||
current_credential_id = None
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if session_token:
|
||||
token_data = validate_session_token(session_token)
|
||||
if token_data:
|
||||
current_credential_id = token_data.get("credential_id")
|
||||
|
||||
# Get all credentials for the user
|
||||
credential_ids = await sql.get_user_credentials(user.user_id)
|
||||
|
||||
credentials = []
|
||||
user_aaguids = set()
|
||||
|
||||
for cred_id in credential_ids:
|
||||
stored_cred = await sql.get_credential_by_id(cred_id)
|
||||
|
||||
# Convert AAGUID to string format
|
||||
aaguid_str = str(stored_cred.aaguid)
|
||||
user_aaguids.add(aaguid_str)
|
||||
|
||||
# Check if this is the current session credential
|
||||
is_current_session = current_credential_id == stored_cred.credential_id
|
||||
|
||||
credentials.append(
|
||||
{
|
||||
"credential_id": stored_cred.credential_id.hex(),
|
||||
"aaguid": aaguid_str,
|
||||
"created_at": stored_cred.created_at.isoformat(),
|
||||
"last_used": stored_cred.last_used.isoformat()
|
||||
if stored_cred.last_used
|
||||
else None,
|
||||
"last_verified": stored_cred.last_verified.isoformat()
|
||||
if stored_cred.last_verified
|
||||
else None,
|
||||
"sign_count": stored_cred.sign_count,
|
||||
"is_current_session": is_current_session,
|
||||
}
|
||||
)
|
||||
|
||||
# Get AAGUID information for only the AAGUIDs that the user has
|
||||
aaguid_info = aaguid.filter(user_aaguids)
|
||||
|
||||
# Sort credentials by creation date (earliest first, most recently created last)
|
||||
credentials.sort(key=lambda cred: cred["created_at"])
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"credentials": credentials,
|
||||
"aaguid_info": aaguid_info,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to get credentials: {str(e)}"}
|
||||
|
||||
|
||||
async def refresh_token(request: Request, response: Response) -> dict:
|
||||
"""Refresh the session token."""
|
||||
try:
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if not session_token:
|
||||
return {"error": "No session token found"}
|
||||
|
||||
# Validate and refresh the token
|
||||
new_token = refresh_session_token(session_token)
|
||||
|
||||
if new_token:
|
||||
set_session_cookie(response, new_token)
|
||||
return {"status": "success", "refreshed": True}
|
||||
else:
|
||||
clear_session_cookie(response)
|
||||
return {"error": "Invalid or expired session token"}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to refresh token: {str(e)}"}
|
||||
|
||||
|
||||
async def validate_token(request: Request) -> dict:
|
||||
"""Validate a session token and return user info."""
|
||||
try:
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if not session_token:
|
||||
return {"error": "No session token found"}
|
||||
|
||||
# Validate the session token
|
||||
token_data = validate_session_token(session_token)
|
||||
if not token_data:
|
||||
return {"error": "Invalid or expired session token"}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"valid": True,
|
||||
"user_id": str(token_data["user_id"]),
|
||||
"credential_id": token_data["credential_id"].hex(),
|
||||
"issued_at": token_data["issued_at"],
|
||||
"expires_at": token_data["expires_at"],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to validate token: {str(e)}"}
|
||||
|
||||
|
||||
async def logout(response: Response) -> dict:
|
||||
"""Log out the current user by clearing the session cookie."""
|
||||
clear_session_cookie(response)
|
||||
return {"status": "success", "message": "Logged out successfully"}
|
||||
|
||||
|
||||
async def set_session(request: Request, response: Response) -> dict:
|
||||
"""Set session cookie using JWT token from request body or Authorization header."""
|
||||
try:
|
||||
session_token = await get_session_token_from_bearer(request)
|
||||
|
||||
if not session_token:
|
||||
return {"error": "No session token provided"}
|
||||
|
||||
# Validate the session token
|
||||
token_data = validate_session_token(session_token)
|
||||
if not token_data:
|
||||
return {"error": "Invalid or expired session token"}
|
||||
|
||||
# Set the HTTP-only cookie
|
||||
set_session_cookie(response, session_token)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Session cookie set successfully",
|
||||
"user_id": str(token_data["user_id"]),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to set session: {str(e)}"}
|
||||
|
||||
|
||||
async def delete_credential(request: Request) -> dict:
|
||||
"""Delete a specific credential for the current user."""
|
||||
try:
|
||||
user = await get_current_user(request)
|
||||
if not user:
|
||||
return {"error": "Not authenticated"}
|
||||
|
||||
# Get the credential ID from the request body
|
||||
try:
|
||||
body = await request.json()
|
||||
credential_id = body.get("credential_id")
|
||||
if not credential_id:
|
||||
return {"error": "credential_id is required"}
|
||||
except Exception:
|
||||
return {"error": "Invalid request body"}
|
||||
|
||||
# Convert credential_id from hex string to bytes
|
||||
try:
|
||||
credential_id_bytes = bytes.fromhex(credential_id)
|
||||
except ValueError:
|
||||
return {"error": "Invalid credential_id format"}
|
||||
|
||||
# First, verify the credential belongs to the current user
|
||||
try:
|
||||
stored_cred = await sql.get_credential_by_id(credential_id_bytes)
|
||||
if stored_cred.user_id != user.user_id:
|
||||
return {"error": "Credential not found or access denied"}
|
||||
except ValueError:
|
||||
return {"error": "Credential not found"}
|
||||
|
||||
# Check if this is the current session credential
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if session_token:
|
||||
token_data = validate_session_token(session_token)
|
||||
if token_data and token_data.get("credential_id") == credential_id_bytes:
|
||||
return {"error": "Cannot delete current session credential"}
|
||||
|
||||
# Get user's remaining credentials count
|
||||
remaining_credentials = await sql.get_user_credentials(user.user_id)
|
||||
if len(remaining_credentials) <= 1:
|
||||
return {"error": "Cannot delete last remaining credential"}
|
||||
|
||||
# Delete the credential
|
||||
await sql.delete_user_credential(credential_id_bytes)
|
||||
|
||||
return {"status": "success", "message": "Credential deleted successfully"}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to delete credential: {str(e)}"}
|
||||
186
passkey/fastapi/main.py
Normal file
186
passkey/fastapi/main.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Minimal FastAPI WebAuthn server with WebSocket support for passkey registration and authentication.
|
||||
|
||||
This module provides a simple WebAuthn implementation that:
|
||||
- Uses WebSocket for real-time communication
|
||||
- Supports Resident Keys (discoverable credentials) for passwordless authentication
|
||||
- Maintains challenges locally per connection
|
||||
- Uses async SQLite database for persistent storage of users and credentials
|
||||
- Enables true passwordless authentication where users don't need to enter a user_name
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
Response,
|
||||
)
|
||||
from fastapi import (
|
||||
Path as FastAPIPath,
|
||||
)
|
||||
from fastapi.responses import (
|
||||
FileResponse,
|
||||
RedirectResponse,
|
||||
)
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from ..db import sql
|
||||
from .api_handlers import (
|
||||
delete_credential,
|
||||
get_user_credentials,
|
||||
get_user_info,
|
||||
logout,
|
||||
refresh_token,
|
||||
set_session,
|
||||
validate_token,
|
||||
)
|
||||
from .reset_handlers import create_device_addition_link, validate_device_addition_token
|
||||
from .ws_handlers import ws_app
|
||||
|
||||
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await sql.init_database()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Passkey Auth", lifespan=lifespan)
|
||||
|
||||
# Mount the WebSocket subapp
|
||||
app.mount("/auth/ws", ws_app)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.get("/auth/user-info")
|
||||
async def api_get_user_info(request: Request):
|
||||
"""Get user information from session cookie."""
|
||||
return await get_user_info(request)
|
||||
|
||||
|
||||
@app.get("/auth/user-credentials")
|
||||
async def api_get_user_credentials(request: Request):
|
||||
"""Get all credentials for a user using session cookie."""
|
||||
return await get_user_credentials(request)
|
||||
|
||||
|
||||
@app.post("/auth/refresh-token")
|
||||
async def api_refresh_token(request: Request, response: Response):
|
||||
"""Refresh the session token."""
|
||||
return await refresh_token(request, response)
|
||||
|
||||
|
||||
@app.get("/auth/validate-token")
|
||||
async def api_validate_token(request: Request):
|
||||
"""Validate a session token and return user info."""
|
||||
return await validate_token(request)
|
||||
|
||||
|
||||
@app.get("/auth/forward-auth")
|
||||
async def forward_authentication(request: Request):
|
||||
"""A verification endpoint to use with Caddy forward_auth or Nginx auth_request."""
|
||||
result = await validate_token(request)
|
||||
if result.get("status") != "success":
|
||||
# Serve the index.html of the authentication app if not authenticated
|
||||
return FileResponse(
|
||||
STATIC_DIR / "index.html",
|
||||
status_code=401,
|
||||
headers={"www-authenticate": "PrivateToken"},
|
||||
)
|
||||
|
||||
# If authenticated, return a success response
|
||||
return Response(
|
||||
status_code=204,
|
||||
headers={"x-auth-user-id": result["user_id"]},
|
||||
)
|
||||
|
||||
|
||||
@app.post("/auth/logout")
|
||||
async def api_logout(response: Response):
|
||||
"""Log out the current user by clearing the session cookie."""
|
||||
return await logout(response)
|
||||
|
||||
|
||||
@app.post("/auth/set-session")
|
||||
async def api_set_session(request: Request, response: Response):
|
||||
"""Set session cookie using JWT token from request body or Authorization header."""
|
||||
return await set_session(request, response)
|
||||
|
||||
|
||||
@app.post("/auth/delete-credential")
|
||||
async def api_delete_credential(request: Request):
|
||||
"""Delete a specific credential for the current user."""
|
||||
return await delete_credential(request)
|
||||
|
||||
|
||||
@app.post("/auth/create-device-link")
|
||||
async def api_create_device_link(request: Request):
|
||||
"""Create a device addition link for the authenticated user."""
|
||||
return await create_device_addition_link(request)
|
||||
|
||||
|
||||
@app.post("/auth/validate-device-token")
|
||||
async def api_validate_device_token(request: Request):
|
||||
"""Validate a device addition token."""
|
||||
return await validate_device_addition_token(request)
|
||||
|
||||
|
||||
@app.get("/auth/{passphrase}")
|
||||
async def reset_authentication(
|
||||
passphrase: str = FastAPIPath(pattern=r"^\w+(\.\w+){2,}$"),
|
||||
):
|
||||
response = RedirectResponse(url="/", status_code=303)
|
||||
response.set_cookie(
|
||||
key="auth-token",
|
||||
value=passphrase,
|
||||
httponly=False,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
max_age=2,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# Serve static files
|
||||
app.mount(
|
||||
"/auth/assets", StaticFiles(directory=STATIC_DIR / "assets"), name="static assets"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/auth")
|
||||
async def redirect_to_index():
|
||||
"""Serve the main authentication app."""
|
||||
return FileResponse(STATIC_DIR / "index.html")
|
||||
|
||||
|
||||
# Catch-all route for SPA - serve index.html for all non-API routes
|
||||
@app.get("/{path:path}")
|
||||
async def spa_handler(request: Request, path: str):
|
||||
"""Serve the Vue SPA for all routes (except API and static)"""
|
||||
if "text/html" not in request.headers.get("accept", ""):
|
||||
return Response(content="Not Found", status_code=404)
|
||||
return FileResponse(STATIC_DIR / "index.html")
|
||||
|
||||
|
||||
def main():
|
||||
"""Entry point for the application"""
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"passkey.fastapi.main:app",
|
||||
host="localhost",
|
||||
port=4401,
|
||||
reload=True,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
main()
|
||||
103
passkey/fastapi/reset_handlers.py
Normal file
103
passkey/fastapi/reset_handlers.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Device addition API handlers for WebAuthn authentication.
|
||||
|
||||
This module provides endpoints for authenticated users to:
|
||||
- Generate device addition links with human-readable tokens
|
||||
- Validate device addition tokens
|
||||
- Add new passkeys to existing accounts via tokens
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from ..db import sql
|
||||
from ..util.passphrase import generate
|
||||
from .session_manager import get_current_user
|
||||
|
||||
|
||||
async def create_device_addition_link(request: Request) -> dict:
|
||||
"""Create a device addition link for the authenticated user."""
|
||||
try:
|
||||
# Require authentication
|
||||
user = await get_current_user(request)
|
||||
if not user:
|
||||
return {"error": "Authentication required"}
|
||||
|
||||
# Generate a human-readable token
|
||||
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
|
||||
|
||||
# Create reset token in database
|
||||
await sql.create_reset_token(user.user_id, token)
|
||||
|
||||
# Generate the device addition link with pretty URL
|
||||
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Device addition link generated successfully",
|
||||
"addition_link": addition_link,
|
||||
"expires_in_hours": 24,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to create device addition link: {str(e)}"}
|
||||
|
||||
|
||||
async def validate_device_addition_token(request: Request) -> dict:
|
||||
"""Validate a device addition token and return user info."""
|
||||
try:
|
||||
body = await request.json()
|
||||
token = body.get("token")
|
||||
|
||||
if not token:
|
||||
return {"error": "Device addition token is required"}
|
||||
|
||||
# Get reset token
|
||||
reset_token = await sql.get_reset_token(token)
|
||||
if not reset_token:
|
||||
return {"error": "Invalid or expired device addition token"}
|
||||
|
||||
# Check if token is expired (24 hours)
|
||||
expiry_time = reset_token.created_at + timedelta(hours=24)
|
||||
if datetime.now() > expiry_time:
|
||||
return {"error": "Device addition token has expired"}
|
||||
|
||||
# Get user info
|
||||
user = await sql.get_user_by_id(reset_token.user_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"valid": True,
|
||||
"user_id": str(user.user_id),
|
||||
"user_name": user.user_name,
|
||||
"token": token,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to validate device addition token: {str(e)}"}
|
||||
|
||||
|
||||
async def use_device_addition_token(token: str) -> dict:
|
||||
"""Delete a device addition token after successful use."""
|
||||
try:
|
||||
# Get reset token first to validate it exists and is not expired
|
||||
reset_token = await sql.get_reset_token(token)
|
||||
if not reset_token:
|
||||
return {"error": "Invalid or expired device addition token"}
|
||||
|
||||
# Check if token is expired (24 hours)
|
||||
expiry_time = reset_token.created_at + timedelta(hours=24)
|
||||
if datetime.now() > expiry_time:
|
||||
return {"error": "Device addition token has expired"}
|
||||
|
||||
# Delete the token (it's now used)
|
||||
await sql.delete_reset_token(token)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Device addition token used successfully",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to use device addition token: {str(e)}"}
|
||||
98
passkey/fastapi/session_manager.py
Normal file
98
passkey/fastapi/session_manager.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Session management for WebAuthn authentication.
|
||||
|
||||
This module provides session management functionality including:
|
||||
- Getting current user from session cookies
|
||||
- Setting and clearing HTTP-only cookies
|
||||
- Session validation and token handling
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from ..db.sql import User, get_user_by_id
|
||||
from ..util.jwt import validate_session_token
|
||||
|
||||
COOKIE_NAME = "auth"
|
||||
COOKIE_MAX_AGE = 86400 # 24 hours
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> User | None:
|
||||
"""Get the current user from the session cookie."""
|
||||
session_token = request.cookies.get(COOKIE_NAME)
|
||||
if not session_token:
|
||||
return None
|
||||
|
||||
token_data = validate_session_token(session_token)
|
||||
if not token_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
user = await get_user_by_id(token_data["user_id"])
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def set_session_cookie(response: Response, session_token: str) -> None:
|
||||
"""Set the session token as an HTTP-only cookie."""
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=session_token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
|
||||
def clear_session_cookie(response: Response) -> None:
|
||||
"""Clear the session cookie."""
|
||||
response.delete_cookie(key=COOKIE_NAME)
|
||||
|
||||
|
||||
def get_session_token_from_cookie(request: Request) -> str | None:
|
||||
"""Extract session token from request cookies."""
|
||||
return request.cookies.get(COOKIE_NAME)
|
||||
|
||||
|
||||
async def validate_session_from_request(request: Request) -> dict | None:
|
||||
"""Validate session token from request and return token data."""
|
||||
session_token = get_session_token_from_cookie(request)
|
||||
if not session_token:
|
||||
return None
|
||||
|
||||
return validate_session_token(session_token)
|
||||
|
||||
|
||||
async def get_session_token_from_bearer(request: Request) -> str | None:
|
||||
"""Extract session token from Authorization header or request body."""
|
||||
# Try to get token from Authorization header first
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.removeprefix("Bearer ")
|
||||
|
||||
|
||||
async def get_user_from_cookie_string(cookie_header: str) -> UUID | None:
|
||||
"""Parse cookie header and return user ID if valid session exists."""
|
||||
if not cookie_header:
|
||||
return None
|
||||
|
||||
# Parse cookies from header (simple implementation)
|
||||
cookies = {}
|
||||
for cookie in cookie_header.split(";"):
|
||||
cookie = cookie.strip()
|
||||
if "=" in cookie:
|
||||
name, value = cookie.split("=", 1)
|
||||
cookies[name] = value
|
||||
|
||||
session_token = cookies.get(COOKIE_NAME)
|
||||
if not session_token:
|
||||
return None
|
||||
|
||||
token_data = validate_session_token(session_token)
|
||||
if not token_data:
|
||||
return None
|
||||
|
||||
return token_data["user_id"]
|
||||
219
passkey/fastapi/ws_handlers.py
Normal file
219
passkey/fastapi/ws_handlers.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
WebSocket handlers for passkey authentication operations.
|
||||
|
||||
This module contains all WebSocket endpoints for:
|
||||
- User registration
|
||||
- Adding credentials to existing users
|
||||
- Device credential addition via token
|
||||
- Authentication
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
import uuid7
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
||||
|
||||
from ..db import sql
|
||||
from ..db.sql import User
|
||||
from ..sansio import Passkey
|
||||
from ..util.jwt import create_session_token
|
||||
from .session_manager import get_user_from_cookie_string
|
||||
|
||||
# Create a FastAPI subapp for WebSocket endpoints
|
||||
ws_app = FastAPI()
|
||||
|
||||
# Initialize the passkey instance
|
||||
passkey = Passkey(
|
||||
rp_id="localhost",
|
||||
rp_name="Passkey Auth",
|
||||
)
|
||||
|
||||
|
||||
async def register_chat(
|
||||
ws: WebSocket,
|
||||
user_id: UUID,
|
||||
user_name: str,
|
||||
credential_ids: list[bytes] | None = None,
|
||||
origin: str | None = None,
|
||||
):
|
||||
"""Generate registration options and send them to the client."""
|
||||
options, challenge = passkey.reg_generate_options(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
credential_ids=credential_ids,
|
||||
origin=origin,
|
||||
)
|
||||
await ws.send_json(options)
|
||||
response = await ws.receive_json()
|
||||
return passkey.reg_verify(response, challenge, user_id, origin=origin)
|
||||
|
||||
|
||||
@ws_app.websocket("/register_new")
|
||||
async def websocket_register_new(ws: WebSocket, user_name: str):
|
||||
"""Register a new user and with a new passkey credential."""
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
user_id = uuid7.create()
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(ws, user_id, user_name, origin=origin)
|
||||
|
||||
# Store the user and credential in the database
|
||||
await sql.create_user_and_credential(
|
||||
User(user_id, user_name, created_at=datetime.now()),
|
||||
credential,
|
||||
)
|
||||
|
||||
# Create a session token for the new user
|
||||
session_token = create_session_token(user_id, credential.credential_id)
|
||||
|
||||
await ws.send_json(
|
||||
{
|
||||
"status": "success",
|
||||
"user_id": str(user_id),
|
||||
"session_token": session_token,
|
||||
}
|
||||
)
|
||||
except ValueError as e:
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
@ws_app.websocket("/add_credential")
|
||||
async def websocket_register_add(ws: WebSocket):
|
||||
"""Register a new credential for an existing user."""
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
# Authenticate user via cookie
|
||||
cookie_header = ws.headers.get("cookie", "")
|
||||
user_id = await get_user_from_cookie_string(cookie_header)
|
||||
|
||||
if not user_id:
|
||||
await ws.send_json({"error": "Authentication required"})
|
||||
return
|
||||
|
||||
# Get user information to get the user_name
|
||||
user = await sql.get_user_by_id(user_id)
|
||||
user_name = user.user_name
|
||||
challenge_ids = await sql.get_user_credentials(user_id)
|
||||
|
||||
# WebAuthn registration
|
||||
credential = await register_chat(
|
||||
ws, user_id, user_name, challenge_ids, origin=origin
|
||||
)
|
||||
# Store the new credential in the database
|
||||
await sql.create_credential_for_user(credential)
|
||||
|
||||
await ws.send_json(
|
||||
{
|
||||
"status": "success",
|
||||
"user_id": str(user_id),
|
||||
"credential_id": credential.credential_id.hex(),
|
||||
"message": "New credential added successfully",
|
||||
}
|
||||
)
|
||||
except ValueError as e:
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
@ws_app.websocket("/add_device_credential")
|
||||
async def websocket_add_device_credential(ws: WebSocket, token: str):
|
||||
"""Add a new credential for an existing user via device addition token."""
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
reset_token = await sql.get_reset_token(token)
|
||||
if not reset_token:
|
||||
await ws.send_json({"error": "Invalid or expired device addition token"})
|
||||
return
|
||||
|
||||
# Check if token is expired (24 hours)
|
||||
expiry_time = reset_token.created_at + timedelta(hours=24)
|
||||
if datetime.now() > expiry_time:
|
||||
await ws.send_json({"error": "Device addition token has expired"})
|
||||
return
|
||||
|
||||
# Get user information
|
||||
user = await sql.get_user_by_id(reset_token.user_id)
|
||||
|
||||
# WebAuthn registration
|
||||
# Fetch challenge IDs for the user
|
||||
challenge_ids = await sql.get_user_credentials(reset_token.user_id)
|
||||
|
||||
credential = await register_chat(
|
||||
ws, reset_token.user_id, user.user_name, challenge_ids, origin=origin
|
||||
)
|
||||
|
||||
# Store the new credential in the database
|
||||
await sql.create_credential_for_user(credential)
|
||||
|
||||
# Delete the device addition token (it's now used)
|
||||
await sql.delete_reset_token(token)
|
||||
|
||||
await ws.send_json(
|
||||
{
|
||||
"status": "success",
|
||||
"user_id": str(reset_token.user_id),
|
||||
"credential_id": credential.credential_id.hex(),
|
||||
"message": "New credential added successfully via device addition token",
|
||||
}
|
||||
)
|
||||
except ValueError as e:
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
|
||||
|
||||
@ws_app.websocket("/authenticate")
|
||||
async def websocket_authenticate(ws: WebSocket):
|
||||
await ws.accept()
|
||||
origin = ws.headers.get("origin")
|
||||
try:
|
||||
options, challenge = passkey.auth_generate_options()
|
||||
await ws.send_json(options)
|
||||
# Wait for the client to use his authenticator to authenticate
|
||||
credential = passkey.auth_parse(await ws.receive_json())
|
||||
# Fetch from the database by credential ID
|
||||
stored_cred = await sql.get_credential_by_id(credential.raw_id)
|
||||
# Verify the credential matches the stored data
|
||||
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
||||
# Update both credential and user's last_seen timestamp
|
||||
await sql.login_user(stored_cred.user_id, stored_cred)
|
||||
|
||||
# Create a session token for the authenticated user
|
||||
session_token = create_session_token(
|
||||
stored_cred.user_id, stored_cred.credential_id
|
||||
)
|
||||
|
||||
await ws.send_json(
|
||||
{
|
||||
"status": "success",
|
||||
"user_id": str(stored_cred.user_id),
|
||||
"session_token": session_token,
|
||||
}
|
||||
)
|
||||
except (ValueError, InvalidAuthenticationResponse) as e:
|
||||
logging.exception("ValueError")
|
||||
await ws.send_json({"error": str(e)})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("Internal Server Error")
|
||||
await ws.send_json({"error": "Internal Server Error"})
|
||||
233
passkey/sansio.py
Normal file
233
passkey/sansio.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
WebAuthn handler class that combines registration and authentication functionality.
|
||||
|
||||
This module provides a unified interface for WebAuthn operations including:
|
||||
- Registration challenge generation and verification
|
||||
- Authentication challenge generation and verification
|
||||
- Credential validation
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from webauthn import (
|
||||
generate_authentication_options,
|
||||
generate_registration_options,
|
||||
verify_authentication_response,
|
||||
verify_registration_response,
|
||||
)
|
||||
from webauthn.authentication.verify_authentication_response import (
|
||||
VerifiedAuthentication,
|
||||
)
|
||||
from webauthn.helpers import (
|
||||
options_to_json,
|
||||
parse_authentication_credential_json,
|
||||
parse_registration_credential_json,
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
from webauthn.helpers.structs import (
|
||||
AttestationConveyancePreference,
|
||||
AuthenticationCredential,
|
||||
AuthenticatorSelectionCriteria,
|
||||
PublicKeyCredentialDescriptor,
|
||||
ResidentKeyRequirement,
|
||||
UserVerificationRequirement,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredCredential:
|
||||
"""Credential data stored in the database."""
|
||||
|
||||
# Fields set only at registration time
|
||||
credential_id: bytes
|
||||
user_id: UUID
|
||||
aaguid: UUID
|
||||
public_key: bytes
|
||||
# Mutable fields that may be updated during authentication
|
||||
sign_count: int
|
||||
created_at: datetime
|
||||
last_used: datetime | None = None
|
||||
last_verified: datetime | None = None
|
||||
|
||||
|
||||
class Passkey:
|
||||
"""WebAuthn handler for registration and authentication operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rp_id: str,
|
||||
rp_name: str,
|
||||
origin: str | None = None,
|
||||
supported_pub_key_algs: list[COSEAlgorithmIdentifier] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the WebAuthn handler.
|
||||
|
||||
Args:
|
||||
rp_id: Your security domain (e.g. "example.com")
|
||||
rp_name: The relying party name (e.g., "My Application" - visible to users)
|
||||
origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included.
|
||||
supported_pub_key_algs: List of supported COSE algorithms (default is EDDSA, ECDSA_SHA_256, RSASSA_PKCS1_v1_5_SHA_256).
|
||||
"""
|
||||
self.rp_id = rp_id
|
||||
self.rp_name = rp_name
|
||||
self.origin = origin or f"https://{rp_id}"
|
||||
self.supported_pub_key_algs = supported_pub_key_algs or [
|
||||
COSEAlgorithmIdentifier.EDDSA,
|
||||
COSEAlgorithmIdentifier.ECDSA_SHA_256,
|
||||
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
|
||||
]
|
||||
|
||||
### Registration Methods ###
|
||||
|
||||
def reg_generate_options(
|
||||
self,
|
||||
user_id: UUID,
|
||||
user_name: str,
|
||||
credential_ids: list[bytes] | None = None,
|
||||
origin: str | None = None,
|
||||
**regopts,
|
||||
) -> tuple[dict, bytes]:
|
||||
"""
|
||||
Generate registration options for WebAuthn registration.
|
||||
|
||||
Args:
|
||||
user_id: The user ID as bytes
|
||||
user_name: The username
|
||||
credential_ids: For an already authenticated user, a list of credential IDs
|
||||
associated with the account. This prevents accidentally adding another
|
||||
credential on an authenticator that already has one of the listed IDs.
|
||||
origin: The origin URL of the application (e.g. "https://app.example.com"). Must be a subdomain or same as rp_id, with port and scheme but no path included.
|
||||
regopts: Additional arguments to generate_registration_options.
|
||||
|
||||
Returns:
|
||||
JSON dict containing options to be sent to client,
|
||||
challenge bytes to keep during the registration process.
|
||||
"""
|
||||
options = generate_registration_options(
|
||||
rp_id=self.rp_id,
|
||||
rp_name=self.rp_name,
|
||||
user_id=user_id.bytes,
|
||||
user_name=user_name,
|
||||
attestation=AttestationConveyancePreference.DIRECT,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
),
|
||||
exclude_credentials=_convert_credential_ids(credential_ids),
|
||||
supported_pub_key_algs=self.supported_pub_key_algs,
|
||||
**regopts,
|
||||
)
|
||||
return json.loads(options_to_json(options)), options.challenge
|
||||
|
||||
def reg_verify(
|
||||
self,
|
||||
response_json: dict | str,
|
||||
expected_challenge: bytes,
|
||||
user_id: UUID,
|
||||
origin: str | None = None,
|
||||
) -> StoredCredential:
|
||||
"""
|
||||
Verify registration response.
|
||||
|
||||
Args:
|
||||
credential: The credential response from the client
|
||||
expected_challenge: The expected challenge bytes
|
||||
|
||||
Returns:
|
||||
Registration verification result
|
||||
"""
|
||||
credential = parse_registration_credential_json(response_json)
|
||||
registration = verify_registration_response(
|
||||
credential=credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_origin=origin or self.origin,
|
||||
expected_rp_id=self.rp_id,
|
||||
)
|
||||
return StoredCredential(
|
||||
credential_id=credential.raw_id,
|
||||
user_id=user_id,
|
||||
aaguid=UUID(registration.aaguid),
|
||||
public_key=registration.credential_public_key,
|
||||
sign_count=registration.sign_count,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
### Authentication Methods ###
|
||||
|
||||
def auth_generate_options(
|
||||
self,
|
||||
*,
|
||||
user_verification_required=False,
|
||||
credential_ids: list[bytes] | None = None,
|
||||
**authopts,
|
||||
) -> tuple[dict, bytes]:
|
||||
"""
|
||||
Generate authentication options for WebAuthn authentication.
|
||||
|
||||
Args:
|
||||
user_verification_required: The user will have to re-enter PIN or use biometrics for this operation. Useful when accessing security settings etc.
|
||||
credential_ids: For an already known user, a list of credential IDs associated with the account (less prompts during authentication).
|
||||
authopts: Additional arguments to generate_authentication_options.
|
||||
|
||||
Returns:
|
||||
Tuple of (JSON to be sent to client, challenge bytes to store)
|
||||
"""
|
||||
options = generate_authentication_options(
|
||||
rp_id=self.rp_id,
|
||||
user_verification=(
|
||||
UserVerificationRequirement.REQUIRED
|
||||
if user_verification_required
|
||||
else UserVerificationRequirement.DISCOURAGED
|
||||
),
|
||||
allow_credentials=_convert_credential_ids(credential_ids),
|
||||
**authopts,
|
||||
)
|
||||
return json.loads(options_to_json(options)), options.challenge
|
||||
|
||||
def auth_parse(self, response: dict | str) -> AuthenticationCredential:
|
||||
return parse_authentication_credential_json(response)
|
||||
|
||||
def auth_verify(
|
||||
self,
|
||||
credential: AuthenticationCredential,
|
||||
expected_challenge: bytes,
|
||||
stored_cred: StoredCredential,
|
||||
origin: str | None = None,
|
||||
) -> VerifiedAuthentication:
|
||||
"""
|
||||
Verify authentication response against locally stored credential data.
|
||||
|
||||
Args:
|
||||
credential: The authentication credential response from the client
|
||||
expected_challenge: The earlier generated challenge bytes
|
||||
stored_cred: The server stored credential record (modified by this function)
|
||||
"""
|
||||
expected_origin = origin or self.origin
|
||||
# Verify the authentication response
|
||||
verification = verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=expected_challenge,
|
||||
expected_origin=expected_origin,
|
||||
expected_rp_id=self.rp_id,
|
||||
credential_public_key=stored_cred.public_key,
|
||||
credential_current_sign_count=stored_cred.sign_count,
|
||||
)
|
||||
stored_cred.sign_count = verification.new_sign_count
|
||||
now = datetime.now()
|
||||
stored_cred.last_used = now
|
||||
if verification.user_verified:
|
||||
stored_cred.last_verified = now
|
||||
return verification
|
||||
|
||||
|
||||
def _convert_credential_ids(
|
||||
credential_ids: list[bytes] | None,
|
||||
) -> list[PublicKeyCredentialDescriptor] | None:
|
||||
"""A helper to convert a list of credential IDs to PublicKeyCredentialDescriptor objects, or pass through None."""
|
||||
if credential_ids is None:
|
||||
return None
|
||||
return [PublicKeyCredentialDescriptor(id) for id in credential_ids]
|
||||
132
passkey/util/jwt.py
Normal file
132
passkey/util/jwt.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
JWT session management for WebAuthn authentication.
|
||||
|
||||
This module provides JWT token generation and validation for managing user sessions
|
||||
after successful WebAuthn authentication. Tokens contain user ID and credential ID
|
||||
for session validation.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
|
||||
SECRET_FILE = Path("server-secret.bin")
|
||||
|
||||
|
||||
def load_or_create_secret() -> bytes:
|
||||
"""Load JWT secret from file or create a new one."""
|
||||
if SECRET_FILE.exists():
|
||||
return SECRET_FILE.read_bytes()
|
||||
else:
|
||||
# Generate a new 16-byte secret
|
||||
secret = secrets.token_bytes(16)
|
||||
SECRET_FILE.write_bytes(secret)
|
||||
return secret
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""Manages JWT tokens for user sessions."""
|
||||
|
||||
def __init__(self, secret_key: bytes, algorithm: str = "HS256"):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
self.token_expiry = timedelta(hours=24) # Tokens expire after 24 hours
|
||||
|
||||
def create_token(self, user_id: UUID, credential_id: bytes) -> str:
|
||||
"""
|
||||
Create a JWT token for a user session.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
credential_id: The credential ID used for authentication
|
||||
|
||||
Returns:
|
||||
JWT token string
|
||||
"""
|
||||
now = datetime.now()
|
||||
payload = {
|
||||
"user_id": str(user_id),
|
||||
"credential_id": credential_id.hex(),
|
||||
"iat": now,
|
||||
"exp": now + self.token_expiry,
|
||||
"iss": "passkeyauth",
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def validate_token(self, token: str) -> Optional[dict]:
|
||||
"""
|
||||
Validate a JWT token and return the payload.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Dictionary with user_id and credential_id, or None if invalid
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.secret_key,
|
||||
algorithms=[self.algorithm],
|
||||
issuer="passkeyauth",
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": UUID(payload["user_id"]),
|
||||
"credential_id": bytes.fromhex(payload["credential_id"]),
|
||||
"issued_at": payload["iat"],
|
||||
"expires_at": payload["exp"],
|
||||
}
|
||||
except jwt.ExpiredSignatureError:
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
def refresh_token(self, token: str) -> Optional[str]:
|
||||
"""
|
||||
Refresh a JWT token if it's still valid.
|
||||
|
||||
Args:
|
||||
token: Current JWT token
|
||||
|
||||
Returns:
|
||||
New JWT token string, or None if the current token is invalid
|
||||
"""
|
||||
payload = self.validate_token(token)
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
return self.create_token(payload["user_id"], payload["credential_id"])
|
||||
|
||||
|
||||
# Global JWT manager instance
|
||||
_jwt_manager: JWTManager | None = None
|
||||
|
||||
|
||||
def get_jwt_manager() -> JWTManager:
|
||||
"""Get the global JWT manager instance."""
|
||||
global _jwt_manager
|
||||
if _jwt_manager is None:
|
||||
secret = load_or_create_secret()
|
||||
_jwt_manager = JWTManager(secret)
|
||||
return _jwt_manager # type: ignore
|
||||
|
||||
|
||||
def create_session_token(user_id: UUID, credential_id: bytes) -> str:
|
||||
"""Create a session token for a user."""
|
||||
return get_jwt_manager().create_token(user_id, credential_id)
|
||||
|
||||
|
||||
def validate_session_token(token: str) -> Optional[dict]:
|
||||
"""Validate a session token."""
|
||||
return get_jwt_manager().validate_token(token)
|
||||
|
||||
|
||||
def refresh_session_token(token: str) -> Optional[str]:
|
||||
"""Refresh a session token."""
|
||||
return get_jwt_manager().refresh_token(token)
|
||||
9
passkey/util/passphrase.py
Normal file
9
passkey/util/passphrase.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import secrets
|
||||
|
||||
from .wordlist import words
|
||||
|
||||
|
||||
def generate(n=4, sep="."):
|
||||
"""Generate a password of random words without repeating any word."""
|
||||
wl = list(words)
|
||||
return sep.join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
|
||||
54
passkey/util/wordlist.py
Normal file
54
passkey/util/wordlist.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# A custom list of 1024 common 3-6 letter words, with unique 3-prefixes and no prefix words, entropy 2.1b/letter 10b/word
|
||||
words: list = """
|
||||
able about absent abuse access acid across act adapt add adjust admit adult advice affair afraid again age agree ahead
|
||||
aim air aisle alarm album alert alien all almost alone alpha also alter always amazed among amused anchor angle animal
|
||||
ankle annual answer any apart appear april arch are argue army around array art ascent ash ask aspect assume asthma atom
|
||||
attack audit august aunt author avoid away awful axis baby back bad bag ball bamboo bank bar base battle beach become
|
||||
beef before begin behind below bench best better beyond bid bike bind bio birth bitter black bleak blind blood blue
|
||||
board body boil bomb bone book border boss bottom bounce bowl box boy brain bread bring brown brush bubble buck budget
|
||||
build bulk bundle burden bus but buyer buzz cable cache cage cake call came can car case catch cause cave celery cement
|
||||
census cereal change check child choice chunk cigar circle city civil class clean client close club coast code coffee
|
||||
coil cold come cool copy core cost cotton couch cover coyote craft cream crime cross cruel cry cube cue cult cup curve
|
||||
custom cute cycle dad damage danger daring dash dawn day deal debate decide deer define degree deity delay demand denial
|
||||
depth derive design detail device dial dice die differ dim dinner direct dish divert dizzy doctor dog dollar domain
|
||||
donate door dose double dove draft dream drive drop drum dry duck dumb dune during dust dutch dwarf eager early east
|
||||
echo eco edge edit effort egg eight either elbow elder elite else embark emerge emily employ enable end enemy engine
|
||||
enjoy enlist enough enrich ensure entire envy equal era erode error erupt escape essay estate ethics evil evoke exact
|
||||
excess exist exotic expect extent eye fabric face fade faith fall family fan far father fault feel female fence fetch
|
||||
fever few fiber field figure file find first fish fit fix flat flesh flight float fluid fly foam focus fog foil follow
|
||||
food force fossil found fox frame fresh friend frog fruit fuel fun fury future gadget gain galaxy game gap garden gas
|
||||
gate gauge gaze genius ghost giant gift giggle ginger girl give glass glide globe glue goal god gold good gospel govern
|
||||
gown grant great grid group grunt guard guess guide gulf gun gym habit hair half hammer hand happy hard hat have hawk
|
||||
hay hazard head hedge height help hen hero hidden high hill hint hip hire hobby hockey hold home honey hood hope horse
|
||||
host hotel hour hover how hub huge human hungry hurt hybrid ice icon idea idle ignore ill image immune impact income
|
||||
index infant inhale inject inmate inner input inside into invest iron island issue italy item ivory jacket jaguar james
|
||||
jar jazz jeans jelly jewel job joe joke joy judge juice july jump june just kansas kate keep kernel key kick kid kind
|
||||
kiss kit kiwi knee knife know labor lady lag lake lamp laptop large later laugh lava law layer lazy leader left legal
|
||||
lemon length lesson letter level liar libya lid life light like limit line lion liquid list little live lizard load
|
||||
local logic long loop lost loud love low loyal lucky lumber lunch lust luxury lyrics mad magic main major make male
|
||||
mammal man map market mass matter maze mccoy meadow media meet melt member men mercy mesh method middle milk mimic mind
|
||||
mirror miss mix mobile model mom monkey moon more mother mouse move much muffin mule must mutual myself myth naive name
|
||||
napkin narrow nasty nation near neck need nephew nerve nest net never news next nice night noble noise noodle normal
|
||||
nose note novel now number nurse nut oak obey object oblige obtain occur ocean odor off often oil okay old olive omit
|
||||
once one onion online open opium oppose option orange orbit order organ orient orphan other outer oval oven own oxygen
|
||||
oyster ozone pact paddle page pair palace panel paper parade past path pause pave paw pay peace pen people pepper permit
|
||||
pet philip phone phrase piano pick piece pig pilot pink pipe pistol pitch pizza place please pluck poem point polar pond
|
||||
pool post pot pound powder praise prefer price profit public pull punch pupil purity push put puzzle qatar quasi queen
|
||||
quite quoted rabbit race radio rail rally ramp range rapid rare rather raven raw razor real rebel recall red reform
|
||||
region reject relief remain rent reopen report result return review reward rhythm rib rich ride rifle right ring riot
|
||||
ripple risk ritual river road robot rocket room rose rotate round row royal rubber rude rug rule run rural sad safe sage
|
||||
sail salad same santa sauce save say scale scene school scope screen scuba sea second seed self semi sense series settle
|
||||
seven shadow she ship shock shrimp shy sick side siege sign silver simple since siren sister six size skate sketch ski
|
||||
skull slab sleep slight slogan slush small smile smooth snake sniff snow soap soccer soda soft solid son soon sort south
|
||||
space speak sphere spirit split spoil spring spy square state step still story strong stuff style submit such sudden
|
||||
suffer sugar suit summer sun supply sure swamp sweet switch sword symbol syntax syria system table tackle tag tail talk
|
||||
tank tape target task tattoo taxi team tell ten term test text that theme this three thumb tibet ticket tide tight tilt
|
||||
time tiny tip tired tissue title toast today toe toilet token tomato tone tool top torch toss total toward toy trade
|
||||
tree trial trophy true try tube tumble tunnel turn twenty twice two type ugly unable uncle under unfair unique unlock
|
||||
until unveil update uphold upon upper upset urban urge usage use usual vacuum vague valid van vapor vast vault vein
|
||||
velvet vendor very vessel viable video view villa violin virus visit vital vivid vocal voice volume vote voyage wage
|
||||
wait wall want war wash water wave way wealth web weird were west wet what when whip wide wife will window wire wish
|
||||
wolf woman wonder wood work wrap wreck write wrong xander xbox xerox xray yang yard year yellow yes yin york you zane
|
||||
zara zebra zen zero zippo zone zoo zorro zulu
|
||||
""".split()
|
||||
assert len(words) == 1024 # Exactly 10 bits of entropy per word
|
||||
Reference in New Issue
Block a user