Database reworked simpler, JWTs replaced by sessions table and random tokens. Accessing Add device link is currently broken.

This commit is contained in:
Leo Vasanko 2025-07-14 17:29:48 -06:00
parent 225d7b7542
commit dc0b0f4613
7 changed files with 364 additions and 132 deletions

6
API.md
View File

@ -271,12 +271,6 @@ Register a new user with a new passkey credential.
"message": "User registered successfully",
"token": "string (JWT)"
}
// Error response
{
"status": "error",
"message": "error description"
}
```
#### `WS /auth/ws/add_credential`

View File

@ -21,7 +21,7 @@ from sqlalchemy import (
select,
update,
)
from sqlalchemy.dialects.sqlite import BLOB
from sqlalchemy.dialects.sqlite import BLOB, JSON
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@ -52,8 +52,8 @@ class UserModel(Base):
class CredentialModel(Base):
__tablename__ = "credentials"
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True)
user_id: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
)
@ -68,14 +68,18 @@ class CredentialModel(Base):
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
class ResetTokenModel(Base):
__tablename__ = "reset_tokens"
class SessionModel(Base):
__tablename__ = "sessions"
token: Mapped[str] = mapped_column(String(64), primary_key=True)
token: Mapped[str] = mapped_column(String(32), primary_key=True)
user_id: Mapped[bytes] = mapped_column(
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
)
credential_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("credentials.id", ondelete="SET NULL")
)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
# Relationship to user
user: Mapped["UserModel"] = relationship("UserModel")
@ -90,13 +94,6 @@ class User:
visits: int = 0
@dataclass
class ResetToken:
token: str
user_id: UUID
created_at: datetime
# Global engine and session factory
engine = create_async_engine(DB_PATH, echo=False)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
@ -243,45 +240,6 @@ class DB:
await self.session.execute(stmt)
await self.session.commit()
async def create_reset_token(self, user_id: UUID, token: str | None = None) -> str:
"""Create a new reset token for a user."""
if token is None:
token = secrets.token_urlsafe(32)
reset_token_model = ResetTokenModel(
token=token,
user_id=user_id.bytes,
created_at=datetime.now(),
)
self.session.add(reset_token_model)
await self.session.flush()
return token
async def get_reset_token(self, token: str) -> ResetToken | None:
"""Get reset token by token string."""
stmt = select(ResetTokenModel).where(ResetTokenModel.token == token)
result = await self.session.execute(stmt)
token_model = result.scalar_one_or_none()
if token_model:
return ResetToken(
token=token_model.token,
user_id=UUID(bytes=token_model.user_id),
created_at=token_model.created_at,
)
return None
async def delete_reset_token(self, token: str) -> None:
"""Delete a reset token (used after successful credential addition)."""
stmt = delete(ResetTokenModel).where(ResetTokenModel.token == token)
await self.session.execute(stmt)
async def cleanup_expired_tokens(self) -> None:
"""Remove expired reset tokens (older than 24 hours)."""
expiry_time = datetime.now() - timedelta(hours=24)
stmt = delete(ResetTokenModel).where(ResetTokenModel.created_at < expiry_time)
await self.session.execute(stmt)
async def get_user_by_username(self, user_name: str) -> User | None:
"""Get user by username."""
stmt = select(UserModel).where(UserModel.user_name == user_name)
@ -298,6 +256,98 @@ class DB:
)
return None
async def create_session(
self,
user_id: UUID,
credential_id: int | None = None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
if token is None:
token = secrets.token_urlsafe(12)
session_model = SessionModel(
token=token,
user_id=user_id.bytes,
credential_id=credential_id,
created_at=datetime.now(),
info=info,
)
self.session.add(session_model)
await self.session.flush()
return token
async def create_session_by_credential_id(
self,
user_id: UUID,
credential_id: bytes | None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
if credential_id is None:
return await self.create_session(user_id, None, token, info)
# Get the database ID from the credential
stmt = select(CredentialModel.id).where(
CredentialModel.credential_id == credential_id
)
result = await self.session.execute(stmt)
db_credential_id = result.scalar_one()
return await self.create_session(user_id, db_credential_id, token, info)
async def get_session(self, token: str) -> dict | None:
"""Get session by token string."""
stmt = select(SessionModel).where(SessionModel.token == token)
result = await self.session.execute(stmt)
session_model = result.scalar_one_or_none()
if session_model:
# Check if session is expired (24 hours)
expiry_time = session_model.created_at + timedelta(hours=24)
if datetime.now() > expiry_time:
# Clean up expired session
await self.delete_session(token)
return None
return {
"token": session_model.token,
"user_id": UUID(bytes=session_model.user_id),
"credential_id": session_model.credential_id,
"created_at": session_model.created_at,
"info": session_model.info or {},
}
return None
async def delete_session(self, token: str) -> None:
"""Delete a session by token."""
stmt = delete(SessionModel).where(SessionModel.token == token)
await self.session.execute(stmt)
async def cleanup_expired_sessions(self) -> None:
"""Remove expired sessions (older than 24 hours)."""
expiry_time = datetime.now() - timedelta(hours=24)
stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time)
await self.session.execute(stmt)
async def refresh_session(self, token: str) -> str | None:
"""Refresh a session by updating its created_at timestamp."""
session_data = await self.get_session(token)
if not session_data:
return None
# Delete old session
await self.delete_session(token)
# Create new session with same user and credential
return await self.create_session(
session_data["user_id"],
session_data["credential_id"],
info=session_data["info"],
)
# Standalone functions that handle database connections internally
async def init_database() -> None:
@ -358,31 +408,55 @@ async def create_new_session(user_id: UUID, credential: StoredCredential) -> Non
await db.create_new_session(user_id, credential)
async def create_reset_token(user_id: UUID, token: str | None = None) -> str:
"""Create a reset token for a user."""
async with connect() as db:
return await db.create_reset_token(user_id, token)
async def get_reset_token(token: str) -> ResetToken | None:
"""Get reset token by token string."""
async with connect() as db:
return await db.get_reset_token(token)
async def delete_reset_token(token: str) -> None:
"""Delete a reset token (used after successful credential addition)."""
async with connect() as db:
await db.delete_reset_token(token)
async def cleanup_expired_tokens() -> None:
"""Remove expired reset tokens (older than 24 hours)."""
async with connect() as db:
await db.cleanup_expired_tokens()
async def get_user_by_username(user_name: str) -> User | None:
"""Get user by username."""
async with connect() as db:
return await db.get_user_by_username(user_name)
async def create_session(
user_id: UUID,
credential_id: int | None = None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
async with connect() as db:
return await db.create_session(user_id, credential_id, token, info)
async def create_session_by_credential_id(
user_id: UUID,
credential_id: bytes | None,
token: str | None = None,
info: dict | None = None,
) -> str:
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
async with connect() as db:
return await db.create_session_by_credential_id(
user_id, credential_id, token, info
)
async def get_session(token: str) -> dict | None:
"""Get session by token string."""
async with connect() as db:
return await db.get_session(token)
async def delete_session(token: str) -> None:
"""Delete a session by token."""
async with connect() as db:
await db.delete_session(token)
async def cleanup_expired_sessions() -> None:
"""Remove expired sessions (older than 24 hours)."""
async with connect() as db:
await db.cleanup_expired_sessions()
async def refresh_session(token: str) -> str | None:
"""Refresh a session by updating its created_at timestamp."""
async with connect() as db:
return await db.refresh_session(token)

View File

@ -12,7 +12,7 @@ from fastapi import FastAPI, Request, Response
from .. import aaguid
from ..db import sql
from ..util.jwt import refresh_session_token, validate_session_token
from ..util.session import refresh_session_token, validate_session_token
from .session import (
clear_session_cookie,
get_current_user,
@ -37,7 +37,7 @@ def register_api_routes(app: FastAPI):
current_credential_id = None
session_token = get_session_token_from_cookie(request)
if session_token:
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if token_data:
current_credential_id = token_data.get("credential_id")
@ -97,9 +97,24 @@ def register_api_routes(app: FastAPI):
return {"error": f"Failed to get user info: {str(e)}"}
@app.post("/auth/logout")
async def api_logout(response: Response):
"""Log out the current user by clearing the session cookie."""
async def api_logout(request: Request, response: Response):
"""Log out the current user by clearing the session cookie and deleting from database."""
# Get the session token before clearing the cookie
session_token = get_session_token_from_cookie(request)
# Clear the cookie
clear_session_cookie(response)
# Delete the session from the database if it exists
if session_token:
from ..util.session import logout_session
try:
await logout_session(session_token)
except Exception:
# Continue even if session deletion fails
pass
return {"status": "success", "message": "Logged out successfully"}
@app.post("/auth/set-session")
@ -112,7 +127,7 @@ def register_api_routes(app: FastAPI):
return {"error": "No session token provided"}
# Validate the session token
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data:
return {"error": "Invalid or expired session token"}
@ -162,7 +177,7 @@ def register_api_routes(app: FastAPI):
# Check if this is the current session credential
session_token = get_session_token_from_cookie(request)
if session_token:
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if (
token_data
and token_data.get("credential_id") == credential_id_bytes
@ -182,6 +197,67 @@ def register_api_routes(app: FastAPI):
except Exception as e:
return {"error": f"Failed to delete credential: {str(e)}"}
@app.get("/auth/sessions")
async def api_get_sessions(request: Request):
"""Get all active sessions for the current user."""
try:
user = await get_current_user(request)
if not user:
return {"error": "Authentication required"}
# Get all sessions for this user
from sqlalchemy import select
from ..db.sql import SessionModel, connect
async with connect() as db:
stmt = select(SessionModel).where(
SessionModel.user_id == user.user_id.bytes
)
result = await db.session.execute(stmt)
session_models = result.scalars().all()
sessions = []
current_token = get_session_token_from_cookie(request)
for session in session_models:
# Check if session is expired
from datetime import datetime, timedelta
expiry_time = session.created_at + timedelta(hours=24)
is_expired = datetime.now() > expiry_time
sessions.append(
{
"token": session.token[:8]
+ "...", # Only show first 8 chars for security
"created_at": session.created_at.isoformat(),
"client_ip": session.info.get("client_ip")
if session.info
else None,
"user_agent": session.info.get("user_agent")
if session.info
else None,
"connection_type": session.info.get(
"connection_type", "http"
)
if session.info
else "http",
"is_current": session.token == current_token,
"is_reset_token": session.credential_id is None,
"is_expired": is_expired,
}
)
return {
"status": "success",
"sessions": sessions,
"total_sessions": len(sessions),
}
except Exception as e:
return {"error": f"Failed to get sessions: {str(e)}"}
async def validate_token(request: Request, response: Response) -> dict:
"""Validate a session token and return user info. Also refreshes the token if valid."""
@ -191,13 +267,13 @@ async def validate_token(request: Request, response: Response) -> dict:
return {"error": "No session token found"}
# Validate the session token
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data:
clear_session_cookie(response)
return {"error": "Invalid or expired session token"}
# Refresh the token if valid
new_token = refresh_session_token(session_token)
new_token = await refresh_session_token(session_token)
if new_token:
set_session_cookie(response, new_token)
@ -206,9 +282,10 @@ async def validate_token(request: Request, response: Response) -> dict:
"valid": True,
"refreshed": bool(new_token),
"user_id": str(token_data["user_id"]),
"credential_id": token_data["credential_id"].hex(),
"issued_at": token_data["issued_at"],
"expires_at": token_data["expires_at"],
"credential_id": token_data["credential_id"].hex()
if token_data["credential_id"]
else None,
"created_at": token_data["created_at"].isoformat(),
}
except Exception as e:

View File

@ -7,14 +7,13 @@ This module provides endpoints for authenticated users to:
- Add new passkeys to existing accounts via tokens
"""
from datetime import datetime, timedelta
from fastapi import FastAPI, Path, Request
from fastapi.responses import RedirectResponse
from ..db import sql
from ..util.passphrase import generate
from .session import get_current_user
from ..util.session import get_client_info
from .session import get_current_user, is_device_addition_session, set_session_cookie
def register_reset_routes(app: FastAPI):
@ -32,8 +31,9 @@ def register_reset_routes(app: FastAPI):
# Generate a human-readable token
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
# Create reset token in database
await sql.create_reset_token(user.user_id, token)
# Create session token in database with credential_id=None for device addition
client_info = get_client_info(request)
await sql.create_session(user.user_id, None, token, client_info)
# Generate the device addition link with pretty URL
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
@ -51,38 +51,34 @@ def register_reset_routes(app: FastAPI):
@app.get("/auth/device-session-check")
async def check_device_session(request: Request):
"""Check if the current session is for device addition."""
from .session import is_device_addition_session
is_device_session = await is_device_addition_session(request)
return {"device_addition_session": is_device_session}
@app.get("/auth/{passphrase}")
async def reset_authentication(
request: Request,
passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"),
):
try:
# Get reset token to validate it exists and get user_id
reset_token = await sql.get_reset_token(passphrase)
if not reset_token:
# Get session token to validate it exists and get user_id
session_data = await sql.get_session(passphrase)
if not session_data:
# Token doesn't exist, redirect to home
return RedirectResponse(url="/", status_code=303)
# Check if token is expired (24 hours)
expiry_time = reset_token.created_at + timedelta(hours=24)
if datetime.now() > expiry_time:
# Token expired, clean it up and redirect to home
await sql.delete_reset_token(passphrase)
# Check if this is a device addition session (credential_id is None)
if session_data["credential_id"] is not None:
# Not a device addition session, redirect to home
return RedirectResponse(url="/", status_code=303)
# Create a device addition session token for the user
from ..util.jwt import create_device_addition_token
session_token = create_device_addition_token(reset_token.user_id)
client_info = get_client_info(request)
session_token = await sql.create_session(
session_data["user_id"], None, None, client_info
)
# Create response and set session cookie
response = RedirectResponse(url="/auth/", status_code=303)
from .session import set_session_cookie
set_session_cookie(response, session_token)
return response
@ -95,18 +91,17 @@ def register_reset_routes(app: FastAPI):
async def use_device_addition_token(token: str) -> dict:
"""Delete a device addition token after successful use."""
try:
# Get reset token first to validate it exists and is not expired
reset_token = await sql.get_reset_token(token)
if not reset_token:
# Get session token first to validate it exists and is not expired
session_data = await sql.get_session(token)
if not session_data:
return {"error": "Invalid or expired device addition token"}
# Check if token is expired (24 hours)
expiry_time = reset_token.created_at + timedelta(hours=24)
if datetime.now() > expiry_time:
return {"error": "Device addition token has expired"}
# Check if this is a device addition session (credential_id is None)
if session_data["credential_id"] is not None:
return {"error": "Invalid device addition token"}
# Delete the token (it's now used)
await sql.delete_reset_token(token)
await sql.delete_session(token)
return {
"status": "success",

View File

@ -12,7 +12,7 @@ from uuid import UUID
from fastapi import Request, Response
from ..db.sql import User, get_user_by_id
from ..util.jwt import validate_session_token
from ..util.session import validate_session_token
COOKIE_NAME = "auth"
COOKIE_MAX_AGE = 86400 # 24 hours
@ -24,7 +24,7 @@ async def get_current_user(request: Request) -> User | None:
if not session_token:
return None
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data:
return None
@ -63,7 +63,7 @@ async def validate_session_from_request(request: Request) -> dict | None:
if not session_token:
return None
return validate_session_token(session_token)
return await validate_session_token(session_token)
async def get_session_token_from_bearer(request: Request) -> str | None:
@ -91,7 +91,7 @@ async def get_user_from_cookie_string(cookie_header: str) -> UUID | None:
if not session_token:
return None
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data:
return None
@ -104,7 +104,7 @@ async def is_device_addition_session(request: Request) -> bool:
if not session_token:
return False
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data:
return False
@ -117,7 +117,7 @@ async def get_device_addition_user_id(request: Request) -> UUID | None:
if not session_token:
return None
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data or not token_data.get("device_addition"):
return None
@ -141,7 +141,7 @@ async def get_device_addition_user_id_from_cookie(cookie_header: str) -> UUID |
if not session_token:
return None
token_data = validate_session_token(session_token)
token_data = await validate_session_token(session_token)
if not token_data or not token_data.get("device_addition"):
return None

View File

@ -19,7 +19,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from ..db import sql
from ..db.sql import User
from ..sansio import Passkey
from ..util.jwt import create_session_token
from ..util.session import create_session_token, get_client_info_from_websocket
from .session import get_user_from_cookie_string
# Create a FastAPI subapp for WebSocket endpoints
@ -69,7 +69,10 @@ async def websocket_register_new(ws: WebSocket, user_name: str):
)
# Create a session token for the new user
session_token = create_session_token(user_id, credential.credential_id)
client_info = get_client_info_from_websocket(ws)
session_token = await create_session_token(
user_id, credential.credential_id, client_info
)
await ws.send_json(
{
@ -248,8 +251,9 @@ async def websocket_authenticate(ws: WebSocket):
await sql.login_user(stored_cred.user_id, stored_cred)
# Create a session token for the authenticated user
session_token = create_session_token(
stored_cred.user_id, stored_cred.credential_id
client_info = get_client_info_from_websocket(ws)
session_token = await create_session_token(
stored_cred.user_id, stored_cred.credential_id, client_info
)
await ws.send_json(

88
passkey/util/session.py Normal file
View File

@ -0,0 +1,88 @@
"""
Database session management for WebAuthn authentication.
This module provides session management using database tokens instead of JWT tokens.
Session tokens are stored in the database and validated on each request.
"""
from datetime import datetime
from typing import Optional
from uuid import UUID
from fastapi import Request
from ..db import sql
def get_client_info(request: Request) -> dict:
"""Extract client information from FastAPI request and return as dict."""
# Get client IP (handle X-Forwarded-For for proxies)
# Get user agent
return {
"client_ip": request.client.host if request.client else "",
"user_agent": request.headers.get("user-agent", "")[:500],
}
def get_client_info_from_websocket(ws) -> dict:
"""Extract client information from WebSocket connection and return as dict."""
# Get client IP from WebSocket
client_ip = None
if hasattr(ws, "client") and ws.client:
client_ip = ws.client.host
# Check for forwarded headers
if hasattr(ws, "headers"):
forwarded_for = ws.headers.get("x-forwarded-for")
if forwarded_for:
client_ip = forwarded_for.split(",")[0].strip()
# Get user agent from WebSocket headers
user_agent = None
if hasattr(ws, "headers"):
user_agent = ws.headers.get("user-agent")
# Truncate user agent if too long
if user_agent and len(user_agent) > 500: # Keep some margin
user_agent = user_agent[:500]
return {
"client_ip": client_ip,
"user_agent": user_agent,
"timestamp": datetime.now().isoformat(),
"connection_type": "websocket",
}
async def create_session_token(
user_id: UUID, credential_id: bytes, info: dict | None = None
) -> str:
"""Create a session token for a user."""
return await sql.create_session_by_credential_id(user_id, credential_id, None, info)
async def validate_session_token(token: str) -> Optional[dict]:
"""Validate a session token."""
session_data = await sql.get_session(token)
if not session_data:
return None
return {
"user_id": session_data["user_id"],
"credential_id": session_data["credential_id"],
"created_at": session_data["created_at"],
}
async def refresh_session_token(token: str) -> Optional[str]:
"""Refresh a session token."""
return await sql.refresh_session(token)
async def delete_session_token(token: str) -> None:
"""Delete a session token."""
await sql.delete_session(token)
async def logout_session(token: str) -> None:
"""Log out a user by deleting their session token."""
await sql.delete_session(token)