diff --git a/API.md b/API.md index ed049ab..b9a4365 100644 --- a/API.md +++ b/API.md @@ -271,12 +271,6 @@ Register a new user with a new passkey credential. "message": "User registered successfully", "token": "string (JWT)" } - -// Error response -{ - "status": "error", - "message": "error description" -} ``` #### `WS /auth/ws/add_credential` diff --git a/passkey/db/sql.py b/passkey/db/sql.py index b21c91a..e493656 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -21,7 +21,7 @@ from sqlalchemy import ( select, update, ) -from sqlalchemy.dialects.sqlite import BLOB +from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -52,8 +52,8 @@ class UserModel(Base): class CredentialModel(Base): __tablename__ = "credentials" - - credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True) user_id: Mapped[bytes] = mapped_column( LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE") ) @@ -68,14 +68,18 @@ class CredentialModel(Base): user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials") -class ResetTokenModel(Base): - __tablename__ = "reset_tokens" +class SessionModel(Base): + __tablename__ = "sessions" - token: Mapped[str] = mapped_column(String(64), primary_key=True) + token: Mapped[str] = mapped_column(String(32), primary_key=True) user_id: Mapped[bytes] = mapped_column( LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE") ) + credential_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("credentials.id", ondelete="SET NULL") + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) + info: Mapped[dict | None] = mapped_column(JSON, nullable=True) # Relationship to user user: Mapped["UserModel"] = relationship("UserModel") @@ -90,13 +94,6 @@ class User: visits: int = 0 -@dataclass -class ResetToken: - token: str - user_id: UUID - created_at: datetime - - # Global engine and session factory engine = create_async_engine(DB_PATH, echo=False) async_session_factory = async_sessionmaker(engine, expire_on_commit=False) @@ -243,45 +240,6 @@ class DB: await self.session.execute(stmt) await self.session.commit() - async def create_reset_token(self, user_id: UUID, token: str | None = None) -> str: - """Create a new reset token for a user.""" - if token is None: - token = secrets.token_urlsafe(32) - - reset_token_model = ResetTokenModel( - token=token, - user_id=user_id.bytes, - created_at=datetime.now(), - ) - self.session.add(reset_token_model) - await self.session.flush() - return token - - async def get_reset_token(self, token: str) -> ResetToken | None: - """Get reset token by token string.""" - stmt = select(ResetTokenModel).where(ResetTokenModel.token == token) - result = await self.session.execute(stmt) - token_model = result.scalar_one_or_none() - - if token_model: - return ResetToken( - token=token_model.token, - user_id=UUID(bytes=token_model.user_id), - created_at=token_model.created_at, - ) - return None - - async def delete_reset_token(self, token: str) -> None: - """Delete a reset token (used after successful credential addition).""" - stmt = delete(ResetTokenModel).where(ResetTokenModel.token == token) - await self.session.execute(stmt) - - async def cleanup_expired_tokens(self) -> None: - """Remove expired reset tokens (older than 24 hours).""" - expiry_time = datetime.now() - timedelta(hours=24) - stmt = delete(ResetTokenModel).where(ResetTokenModel.created_at < expiry_time) - await self.session.execute(stmt) - async def get_user_by_username(self, user_name: str) -> User | None: """Get user by username.""" stmt = select(UserModel).where(UserModel.user_name == user_name) @@ -298,6 +256,98 @@ class DB: ) return None + async def create_session( + self, + user_id: UUID, + credential_id: int | None = None, + token: str | None = None, + info: dict | None = None, + ) -> str: + """Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential.""" + if token is None: + token = secrets.token_urlsafe(12) + + session_model = SessionModel( + token=token, + user_id=user_id.bytes, + credential_id=credential_id, + created_at=datetime.now(), + info=info, + ) + self.session.add(session_model) + await self.session.flush() + return token + + async def create_session_by_credential_id( + self, + user_id: UUID, + credential_id: bytes | None, + token: str | None = None, + info: dict | None = None, + ) -> str: + """Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential.""" + if credential_id is None: + return await self.create_session(user_id, None, token, info) + + # Get the database ID from the credential + stmt = select(CredentialModel.id).where( + CredentialModel.credential_id == credential_id + ) + result = await self.session.execute(stmt) + db_credential_id = result.scalar_one() + + return await self.create_session(user_id, db_credential_id, token, info) + + async def get_session(self, token: str) -> dict | None: + """Get session by token string.""" + stmt = select(SessionModel).where(SessionModel.token == token) + result = await self.session.execute(stmt) + session_model = result.scalar_one_or_none() + + if session_model: + # Check if session is expired (24 hours) + expiry_time = session_model.created_at + timedelta(hours=24) + if datetime.now() > expiry_time: + # Clean up expired session + await self.delete_session(token) + return None + + return { + "token": session_model.token, + "user_id": UUID(bytes=session_model.user_id), + "credential_id": session_model.credential_id, + "created_at": session_model.created_at, + "info": session_model.info or {}, + } + return None + + async def delete_session(self, token: str) -> None: + """Delete a session by token.""" + stmt = delete(SessionModel).where(SessionModel.token == token) + await self.session.execute(stmt) + + async def cleanup_expired_sessions(self) -> None: + """Remove expired sessions (older than 24 hours).""" + expiry_time = datetime.now() - timedelta(hours=24) + stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time) + await self.session.execute(stmt) + + async def refresh_session(self, token: str) -> str | None: + """Refresh a session by updating its created_at timestamp.""" + session_data = await self.get_session(token) + if not session_data: + return None + + # Delete old session + await self.delete_session(token) + + # Create new session with same user and credential + return await self.create_session( + session_data["user_id"], + session_data["credential_id"], + info=session_data["info"], + ) + # Standalone functions that handle database connections internally async def init_database() -> None: @@ -358,31 +408,55 @@ async def create_new_session(user_id: UUID, credential: StoredCredential) -> Non await db.create_new_session(user_id, credential) -async def create_reset_token(user_id: UUID, token: str | None = None) -> str: - """Create a reset token for a user.""" - async with connect() as db: - return await db.create_reset_token(user_id, token) - - -async def get_reset_token(token: str) -> ResetToken | None: - """Get reset token by token string.""" - async with connect() as db: - return await db.get_reset_token(token) - - -async def delete_reset_token(token: str) -> None: - """Delete a reset token (used after successful credential addition).""" - async with connect() as db: - await db.delete_reset_token(token) - - -async def cleanup_expired_tokens() -> None: - """Remove expired reset tokens (older than 24 hours).""" - async with connect() as db: - await db.cleanup_expired_tokens() - - async def get_user_by_username(user_name: str) -> User | None: """Get user by username.""" async with connect() as db: return await db.get_user_by_username(user_name) + + +async def create_session( + user_id: UUID, + credential_id: int | None = None, + token: str | None = None, + info: dict | None = None, +) -> str: + """Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential.""" + async with connect() as db: + return await db.create_session(user_id, credential_id, token, info) + + +async def create_session_by_credential_id( + user_id: UUID, + credential_id: bytes | None, + token: str | None = None, + info: dict | None = None, +) -> str: + """Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential.""" + async with connect() as db: + return await db.create_session_by_credential_id( + user_id, credential_id, token, info + ) + + +async def get_session(token: str) -> dict | None: + """Get session by token string.""" + async with connect() as db: + return await db.get_session(token) + + +async def delete_session(token: str) -> None: + """Delete a session by token.""" + async with connect() as db: + await db.delete_session(token) + + +async def cleanup_expired_sessions() -> None: + """Remove expired sessions (older than 24 hours).""" + async with connect() as db: + await db.cleanup_expired_sessions() + + +async def refresh_session(token: str) -> str | None: + """Refresh a session by updating its created_at timestamp.""" + async with connect() as db: + return await db.refresh_session(token) diff --git a/passkey/fastapi/api_handlers.py b/passkey/fastapi/api_handlers.py index ae66cbf..420534e 100644 --- a/passkey/fastapi/api_handlers.py +++ b/passkey/fastapi/api_handlers.py @@ -12,7 +12,7 @@ from fastapi import FastAPI, Request, Response from .. import aaguid from ..db import sql -from ..util.jwt import refresh_session_token, validate_session_token +from ..util.session import refresh_session_token, validate_session_token from .session import ( clear_session_cookie, get_current_user, @@ -37,7 +37,7 @@ def register_api_routes(app: FastAPI): current_credential_id = None session_token = get_session_token_from_cookie(request) if session_token: - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if token_data: current_credential_id = token_data.get("credential_id") @@ -97,9 +97,24 @@ def register_api_routes(app: FastAPI): return {"error": f"Failed to get user info: {str(e)}"} @app.post("/auth/logout") - async def api_logout(response: Response): - """Log out the current user by clearing the session cookie.""" + async def api_logout(request: Request, response: Response): + """Log out the current user by clearing the session cookie and deleting from database.""" + # Get the session token before clearing the cookie + session_token = get_session_token_from_cookie(request) + + # Clear the cookie clear_session_cookie(response) + + # Delete the session from the database if it exists + if session_token: + from ..util.session import logout_session + + try: + await logout_session(session_token) + except Exception: + # Continue even if session deletion fails + pass + return {"status": "success", "message": "Logged out successfully"} @app.post("/auth/set-session") @@ -112,7 +127,7 @@ def register_api_routes(app: FastAPI): return {"error": "No session token provided"} # Validate the session token - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data: return {"error": "Invalid or expired session token"} @@ -162,7 +177,7 @@ def register_api_routes(app: FastAPI): # Check if this is the current session credential session_token = get_session_token_from_cookie(request) if session_token: - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if ( token_data and token_data.get("credential_id") == credential_id_bytes @@ -182,6 +197,67 @@ def register_api_routes(app: FastAPI): except Exception as e: return {"error": f"Failed to delete credential: {str(e)}"} + @app.get("/auth/sessions") + async def api_get_sessions(request: Request): + """Get all active sessions for the current user.""" + try: + user = await get_current_user(request) + if not user: + return {"error": "Authentication required"} + + # Get all sessions for this user + from sqlalchemy import select + + from ..db.sql import SessionModel, connect + + async with connect() as db: + stmt = select(SessionModel).where( + SessionModel.user_id == user.user_id.bytes + ) + result = await db.session.execute(stmt) + session_models = result.scalars().all() + + sessions = [] + current_token = get_session_token_from_cookie(request) + + for session in session_models: + # Check if session is expired + from datetime import datetime, timedelta + + expiry_time = session.created_at + timedelta(hours=24) + is_expired = datetime.now() > expiry_time + + sessions.append( + { + "token": session.token[:8] + + "...", # Only show first 8 chars for security + "created_at": session.created_at.isoformat(), + "client_ip": session.info.get("client_ip") + if session.info + else None, + "user_agent": session.info.get("user_agent") + if session.info + else None, + "connection_type": session.info.get( + "connection_type", "http" + ) + if session.info + else "http", + "is_current": session.token == current_token, + "is_reset_token": session.credential_id is None, + "is_expired": is_expired, + } + ) + + return { + "status": "success", + "sessions": sessions, + "total_sessions": len(sessions), + } + + except Exception as e: + return {"error": f"Failed to get sessions: {str(e)}"} + async def validate_token(request: Request, response: Response) -> dict: """Validate a session token and return user info. Also refreshes the token if valid.""" @@ -191,13 +267,13 @@ async def validate_token(request: Request, response: Response) -> dict: return {"error": "No session token found"} # Validate the session token - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data: clear_session_cookie(response) return {"error": "Invalid or expired session token"} # Refresh the token if valid - new_token = refresh_session_token(session_token) + new_token = await refresh_session_token(session_token) if new_token: set_session_cookie(response, new_token) @@ -206,9 +282,10 @@ async def validate_token(request: Request, response: Response) -> dict: "valid": True, "refreshed": bool(new_token), "user_id": str(token_data["user_id"]), - "credential_id": token_data["credential_id"].hex(), - "issued_at": token_data["issued_at"], - "expires_at": token_data["expires_at"], + "credential_id": token_data["credential_id"].hex() + if token_data["credential_id"] + else None, + "created_at": token_data["created_at"].isoformat(), } except Exception as e: diff --git a/passkey/fastapi/reset_handlers.py b/passkey/fastapi/reset_handlers.py index eb5f88e..498580f 100644 --- a/passkey/fastapi/reset_handlers.py +++ b/passkey/fastapi/reset_handlers.py @@ -7,14 +7,13 @@ This module provides endpoints for authenticated users to: - Add new passkeys to existing accounts via tokens """ -from datetime import datetime, timedelta - from fastapi import FastAPI, Path, Request from fastapi.responses import RedirectResponse from ..db import sql from ..util.passphrase import generate -from .session import get_current_user +from ..util.session import get_client_info +from .session import get_current_user, is_device_addition_session, set_session_cookie def register_reset_routes(app: FastAPI): @@ -32,8 +31,9 @@ def register_reset_routes(app: FastAPI): # Generate a human-readable token token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn" - # Create reset token in database - await sql.create_reset_token(user.user_id, token) + # Create session token in database with credential_id=None for device addition + client_info = get_client_info(request) + await sql.create_session(user.user_id, None, token, client_info) # Generate the device addition link with pretty URL addition_link = f"{request.headers.get('origin', '')}/auth/{token}" @@ -51,38 +51,34 @@ def register_reset_routes(app: FastAPI): @app.get("/auth/device-session-check") async def check_device_session(request: Request): """Check if the current session is for device addition.""" - from .session import is_device_addition_session - is_device_session = await is_device_addition_session(request) return {"device_addition_session": is_device_session} @app.get("/auth/{passphrase}") async def reset_authentication( + request: Request, passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"), ): try: - # Get reset token to validate it exists and get user_id - reset_token = await sql.get_reset_token(passphrase) - if not reset_token: + # Get session token to validate it exists and get user_id + session_data = await sql.get_session(passphrase) + if not session_data: # Token doesn't exist, redirect to home return RedirectResponse(url="/", status_code=303) - # Check if token is expired (24 hours) - expiry_time = reset_token.created_at + timedelta(hours=24) - if datetime.now() > expiry_time: - # Token expired, clean it up and redirect to home - await sql.delete_reset_token(passphrase) + # Check if this is a device addition session (credential_id is None) + if session_data["credential_id"] is not None: + # Not a device addition session, redirect to home return RedirectResponse(url="/", status_code=303) # Create a device addition session token for the user - from ..util.jwt import create_device_addition_token - - session_token = create_device_addition_token(reset_token.user_id) + client_info = get_client_info(request) + session_token = await sql.create_session( + session_data["user_id"], None, None, client_info + ) # Create response and set session cookie response = RedirectResponse(url="/auth/", status_code=303) - from .session import set_session_cookie - set_session_cookie(response, session_token) return response @@ -95,18 +91,17 @@ def register_reset_routes(app: FastAPI): async def use_device_addition_token(token: str) -> dict: """Delete a device addition token after successful use.""" try: - # Get reset token first to validate it exists and is not expired - reset_token = await sql.get_reset_token(token) - if not reset_token: + # Get session token first to validate it exists and is not expired + session_data = await sql.get_session(token) + if not session_data: return {"error": "Invalid or expired device addition token"} - # Check if token is expired (24 hours) - expiry_time = reset_token.created_at + timedelta(hours=24) - if datetime.now() > expiry_time: - return {"error": "Device addition token has expired"} + # Check if this is a device addition session (credential_id is None) + if session_data["credential_id"] is not None: + return {"error": "Invalid device addition token"} # Delete the token (it's now used) - await sql.delete_reset_token(token) + await sql.delete_session(token) return { "status": "success", diff --git a/passkey/fastapi/session.py b/passkey/fastapi/session.py index 570ead9..9a14454 100644 --- a/passkey/fastapi/session.py +++ b/passkey/fastapi/session.py @@ -12,7 +12,7 @@ from uuid import UUID from fastapi import Request, Response from ..db.sql import User, get_user_by_id -from ..util.jwt import validate_session_token +from ..util.session import validate_session_token COOKIE_NAME = "auth" COOKIE_MAX_AGE = 86400 # 24 hours @@ -24,7 +24,7 @@ async def get_current_user(request: Request) -> User | None: if not session_token: return None - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data: return None @@ -63,7 +63,7 @@ async def validate_session_from_request(request: Request) -> dict | None: if not session_token: return None - return validate_session_token(session_token) + return await validate_session_token(session_token) async def get_session_token_from_bearer(request: Request) -> str | None: @@ -91,7 +91,7 @@ async def get_user_from_cookie_string(cookie_header: str) -> UUID | None: if not session_token: return None - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data: return None @@ -104,7 +104,7 @@ async def is_device_addition_session(request: Request) -> bool: if not session_token: return False - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data: return False @@ -117,7 +117,7 @@ async def get_device_addition_user_id(request: Request) -> UUID | None: if not session_token: return None - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data or not token_data.get("device_addition"): return None @@ -141,7 +141,7 @@ async def get_device_addition_user_id_from_cookie(cookie_header: str) -> UUID | if not session_token: return None - token_data = validate_session_token(session_token) + token_data = await validate_session_token(session_token) if not token_data or not token_data.get("device_addition"): return None diff --git a/passkey/fastapi/ws_handlers.py b/passkey/fastapi/ws_handlers.py index 1422b65..7d5bc99 100644 --- a/passkey/fastapi/ws_handlers.py +++ b/passkey/fastapi/ws_handlers.py @@ -19,7 +19,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse from ..db import sql from ..db.sql import User from ..sansio import Passkey -from ..util.jwt import create_session_token +from ..util.session import create_session_token, get_client_info_from_websocket from .session import get_user_from_cookie_string # Create a FastAPI subapp for WebSocket endpoints @@ -69,7 +69,10 @@ async def websocket_register_new(ws: WebSocket, user_name: str): ) # Create a session token for the new user - session_token = create_session_token(user_id, credential.credential_id) + client_info = get_client_info_from_websocket(ws) + session_token = await create_session_token( + user_id, credential.credential_id, client_info + ) await ws.send_json( { @@ -248,8 +251,9 @@ async def websocket_authenticate(ws: WebSocket): await sql.login_user(stored_cred.user_id, stored_cred) # Create a session token for the authenticated user - session_token = create_session_token( - stored_cred.user_id, stored_cred.credential_id + client_info = get_client_info_from_websocket(ws) + session_token = await create_session_token( + stored_cred.user_id, stored_cred.credential_id, client_info ) await ws.send_json( diff --git a/passkey/util/session.py b/passkey/util/session.py new file mode 100644 index 0000000..1f84411 --- /dev/null +++ b/passkey/util/session.py @@ -0,0 +1,88 @@ +""" +Database session management for WebAuthn authentication. + +This module provides session management using database tokens instead of JWT tokens. +Session tokens are stored in the database and validated on each request. +""" + +from datetime import datetime +from typing import Optional +from uuid import UUID + +from fastapi import Request + +from ..db import sql + + +def get_client_info(request: Request) -> dict: + """Extract client information from FastAPI request and return as dict.""" + # Get client IP (handle X-Forwarded-For for proxies) + # Get user agent + return { + "client_ip": request.client.host if request.client else "", + "user_agent": request.headers.get("user-agent", "")[:500], + } + + +def get_client_info_from_websocket(ws) -> dict: + """Extract client information from WebSocket connection and return as dict.""" + # Get client IP from WebSocket + client_ip = None + if hasattr(ws, "client") and ws.client: + client_ip = ws.client.host + + # Check for forwarded headers + if hasattr(ws, "headers"): + forwarded_for = ws.headers.get("x-forwarded-for") + if forwarded_for: + client_ip = forwarded_for.split(",")[0].strip() + + # Get user agent from WebSocket headers + user_agent = None + if hasattr(ws, "headers"): + user_agent = ws.headers.get("user-agent") + # Truncate user agent if too long + if user_agent and len(user_agent) > 500: # Keep some margin + user_agent = user_agent[:500] + + return { + "client_ip": client_ip, + "user_agent": user_agent, + "timestamp": datetime.now().isoformat(), + "connection_type": "websocket", + } + + +async def create_session_token( + user_id: UUID, credential_id: bytes, info: dict | None = None +) -> str: + """Create a session token for a user.""" + return await sql.create_session_by_credential_id(user_id, credential_id, None, info) + + +async def validate_session_token(token: str) -> Optional[dict]: + """Validate a session token.""" + session_data = await sql.get_session(token) + if not session_data: + return None + + return { + "user_id": session_data["user_id"], + "credential_id": session_data["credential_id"], + "created_at": session_data["created_at"], + } + + +async def refresh_session_token(token: str) -> Optional[str]: + """Refresh a session token.""" + return await sql.refresh_session(token) + + +async def delete_session_token(token: str) -> None: + """Delete a session token.""" + await sql.delete_session(token) + + +async def logout_session(token: str) -> None: + """Log out a user by deleting their session token.""" + await sql.delete_session(token)