diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 2f252b4..3b8127a 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -22,7 +22,12 @@ import AddCredentialView from '@/components/AddCredentialView.vue' const store = useAuthStore() onMounted(async () => { - // Check for device addition session first + // Was an error message passed in the URL? + const message = location.hash.substring(1) + if (message) { + store.showMessage(decodeURIComponent(message), 'error') + history.replaceState(null, '', location.pathname) + } try { await store.loadUserInfo() } catch (error) { diff --git a/frontend/src/components/AddCredentialView.vue b/frontend/src/components/AddCredentialView.vue index d339514..4d0a98f 100644 --- a/frontend/src/components/AddCredentialView.vue +++ b/frontend/src/components/AddCredentialView.vue @@ -15,7 +15,7 @@ diff --git a/frontend/src/stores/auth.js b/frontend/src/stores/auth.js index 869c717..9139d79 100644 --- a/frontend/src/stores/auth.js +++ b/frontend/src/stores/auth.js @@ -33,10 +33,7 @@ export const useAuthStore = defineStore('auth', { async setSessionCookie(sessionToken) { const response = await fetch('/auth/set-session', { method: 'POST', - headers: { - 'Authorization': `Bearer ${sessionToken}`, - 'Content-Type': 'application/json' - }, + headers: {'Authorization': `Bearer ${sessionToken}`}, }) const result = await response.json() if (result.error) { diff --git a/frontend/src/utils/passkey.js b/frontend/src/utils/passkey.js index 503e1ed..d97828d 100644 --- a/frontend/src/utils/passkey.js +++ b/frontend/src/utils/passkey.js @@ -22,10 +22,7 @@ export async function registerCredential() { return register('/auth/ws/add_credential') } export async function registerWithToken(token) { - return register('/auth/ws/add_device_credential', { token }) -} -export async function registerWithSession() { - return register('/auth/ws/add_device_credential_session') + return register('/auth/ws/add_credential', { token }) } export async function authenticateUser() { diff --git a/passkey/db/__init__.py b/passkey/db/__init__.py new file mode 100644 index 0000000..979644b --- /dev/null +++ b/passkey/db/__init__.py @@ -0,0 +1,50 @@ +""" +Database module for WebAuthn passkey authentication. + +This module provides dataclasses and database abstractions for managing +users, credentials, and sessions in a WebAuthn authentication system. +""" + +from dataclasses import dataclass +from datetime import datetime +from uuid import UUID + + +@dataclass +class User: + """User data structure.""" + + user_uuid: UUID + user_name: str + created_at: datetime | None = None + last_seen: datetime | None = None + visits: int = 0 + + +@dataclass +class Credential: + """Credential data structure.""" + + uuid: UUID + credential_id: bytes + user_uuid: UUID + aaguid: UUID + public_key: bytes + sign_count: int + created_at: datetime + last_used: datetime | None = None + last_verified: datetime | None = None + + +@dataclass +class Session: + """Session data structure.""" + + key: bytes + user_uuid: UUID + expires: datetime + credential_uuid: UUID | None = None + info: dict | None = None + + +__all__ = ["User", "Credential", "Session"] diff --git a/passkey/db/sql.py b/passkey/db/sql.py index fc0d4b9..3f0d671 100644 --- a/passkey/db/sql.py +++ b/passkey/db/sql.py @@ -5,10 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode for managing users and credentials in a WebAuthn authentication system. """ -import secrets from contextlib import asynccontextmanager -from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import datetime from uuid import UUID from sqlalchemy import ( @@ -25,7 +23,7 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from ..sansio import StoredCredential +from . import Credential, Session, User DB_PATH = "sqlite+aiosqlite:///webauthn.db" @@ -38,7 +36,7 @@ class Base(DeclarativeBase): class UserModel(Base): __tablename__ = "users" - user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) + user_uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) user_name: Mapped[str] = mapped_column(String, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) @@ -52,10 +50,12 @@ class UserModel(Base): class CredentialModel(Base): __tablename__ = "credentials" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True) - user_id: Mapped[bytes] = mapped_column( - LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE") + uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) + credential_id: Mapped[bytes] = mapped_column( + LargeBinary(64), unique=True, index=True + ) + user_uuid: Mapped[bytes] = mapped_column( + LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE") ) aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False) public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False) @@ -71,29 +71,20 @@ class CredentialModel(Base): class SessionModel(Base): __tablename__ = "sessions" - token: Mapped[str] = mapped_column(String(32), primary_key=True) - user_id: Mapped[bytes] = mapped_column( - LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE") + key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True) + user_uuid: Mapped[bytes] = mapped_column( + LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE") ) - credential_id: Mapped[int | None] = mapped_column( - Integer, ForeignKey("credentials.id", ondelete="SET NULL") + credential_uuid: Mapped[bytes | None] = mapped_column( + LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE") ) - created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now) + expires: Mapped[datetime] = mapped_column(DateTime, nullable=False) info: Mapped[dict | None] = mapped_column(JSON, nullable=True) # Relationship to user user: Mapped["UserModel"] = relationship("UserModel") -@dataclass -class User: - user_id: UUID - user_name: str - created_at: datetime | None = None - last_seen: datetime | None = None - visits: int = 0 - - # Global engine and session factory engine = create_async_engine(DB_PATH, echo=False) async_session_factory = async_sessionmaker(engine, expire_on_commit=False) @@ -116,15 +107,15 @@ class DB: async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - async def get_user_by_user_id(self, user_id: UUID) -> User: - """Get user record by WebAuthn user ID.""" - stmt = select(UserModel).where(UserModel.user_id == user_id.bytes) + async def get_user_by_user_uuid(self, user_uuid: UUID) -> User: + """Get user record by WebAuthn user UUID.""" + stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes) result = await self.session.execute(stmt) user_model = result.scalar_one_or_none() if user_model: return User( - user_id=UUID(bytes=user_model.user_id), + user_uuid=UUID(bytes=user_model.user_uuid), user_name=user_model.user_name, created_at=user_model.created_at, last_seen=user_model.last_seen, @@ -135,7 +126,7 @@ class DB: async def create_user(self, user: User) -> None: """Create a new user.""" user_model = UserModel( - user_id=user.user_id.bytes, + user_uuid=user.user_uuid.bytes, user_name=user.user_name, created_at=user.created_at or datetime.now(), last_seen=user.last_seen, @@ -144,11 +135,12 @@ class DB: self.session.add(user_model) await self.session.flush() - async def create_credential(self, credential: StoredCredential) -> None: + async def create_credential(self, credential: Credential) -> None: """Store a credential for a user.""" credential_model = CredentialModel( + uuid=credential.uuid.bytes, credential_id=credential.credential_id, - user_id=credential.user_id.bytes, + user_uuid=credential.user_uuid.bytes, aaguid=credential.aaguid.bytes, public_key=credential.public_key, sign_count=credential.sign_count, @@ -159,7 +151,7 @@ class DB: self.session.add(credential_model) await self.session.flush() - async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential: + async def get_credential_by_id(self, credential_id: bytes) -> Credential: """Get credential by credential ID.""" stmt = select(CredentialModel).where( CredentialModel.credential_id == credential_id @@ -167,28 +159,29 @@ class DB: result = await self.session.execute(stmt) credential_model = result.scalar_one_or_none() - if credential_model: - return StoredCredential( - credential_id=credential_model.credential_id, - user_id=UUID(bytes=credential_model.user_id), - aaguid=UUID(bytes=credential_model.aaguid), - public_key=credential_model.public_key, - sign_count=credential_model.sign_count, - created_at=credential_model.created_at, - last_used=credential_model.last_used, - last_verified=credential_model.last_verified, - ) - raise ValueError("Credential not registered") + if not credential_model: + raise ValueError("Credential not registered") + return Credential( + uuid=UUID(bytes=credential_model.uuid), + credential_id=credential_model.credential_id, + user_uuid=UUID(bytes=credential_model.user_uuid), + aaguid=UUID(bytes=credential_model.aaguid), + public_key=credential_model.public_key, + sign_count=credential_model.sign_count, + created_at=credential_model.created_at, + last_used=credential_model.last_used, + last_verified=credential_model.last_verified, + ) - async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]: + async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]: """Get all credential IDs for a user.""" stmt = select(CredentialModel.credential_id).where( - CredentialModel.user_id == user_id.bytes + CredentialModel.user_uuid == user_uuid.bytes ) result = await self.session.execute(stmt) return [row[0] for row in result.fetchall()] - async def update_credential(self, credential: StoredCredential) -> None: + async def update_credential(self, credential: Credential) -> None: """Update the sign count, created_at, last_used, and last_verified for a credential.""" stmt = ( update(CredentialModel) @@ -202,7 +195,7 @@ class DB: ) await self.session.execute(stmt) - async def login(self, user_id: UUID, credential: StoredCredential) -> None: + async def login(self, user_uuid: UUID, credential: Credential) -> None: """Update the last_seen timestamp for a user and the credential record used for logging in.""" async with self.session.begin(): # Update credential @@ -211,137 +204,77 @@ class DB: # Update user's last_seen and increment visits stmt = ( update(UserModel) - .where(UserModel.user_id == user_id.bytes) + .where(UserModel.user_uuid == user_uuid.bytes) .values(last_seen=credential.last_used, visits=UserModel.visits + 1) ) await self.session.execute(stmt) - async def create_new_session( - self, user_id: UUID, credential: StoredCredential - ) -> None: - """Create a new session for a user by incrementing visits and updating last_seen.""" - async with self.session.begin(): - # Update credential - await self.update_credential(credential) - - # Update user's last_seen and increment visits - stmt = ( - update(UserModel) - .where(UserModel.user_id == user_id.bytes) - .values(last_seen=credential.last_used, visits=UserModel.visits + 1) - ) - await self.session.execute(stmt) - - async def delete_credential(self, credential_id: bytes) -> None: + async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None: """Delete a credential by its ID.""" - stmt = delete(CredentialModel).where( - CredentialModel.credential_id == credential_id + stmt = ( + delete(CredentialModel) + .where(CredentialModel.uuid == uuid.bytes) + .where(CredentialModel.user_uuid == user_uuid.bytes) ) await self.session.execute(stmt) await self.session.commit() - async def get_user_by_username(self, user_name: str) -> User | None: - """Get user by username.""" - stmt = select(UserModel).where(UserModel.user_name == user_name) - result = await self.session.execute(stmt) - user_model = result.scalar_one_or_none() - - if user_model: - return User( - user_id=UUID(bytes=user_model.user_id), - user_name=user_model.user_name, - created_at=user_model.created_at, - last_seen=user_model.last_seen, - visits=user_model.visits, - ) - return None - async def create_session( self, - user_id: UUID, - credential_id: int | None = None, - token: str | None = None, - info: dict | None = None, - ) -> str: - """Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential.""" - if token is None: - token = secrets.token_urlsafe(12) - + user_uuid: UUID, + key: bytes, + expires: datetime, + info: dict, + credential_uuid: UUID | None = None, + ) -> bytes: + """Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential.""" session_model = SessionModel( - token=token, - user_id=user_id.bytes, - credential_id=credential_id, - created_at=datetime.now(), + key=key, + user_uuid=user_uuid.bytes, + credential_uuid=credential_uuid.bytes if credential_uuid else None, + expires=expires, info=info, ) self.session.add(session_model) await self.session.flush() - return token + return key - async def create_session_by_credential_id( - self, - user_id: UUID, - credential_id: bytes | None, - token: str | None = None, - info: dict | None = None, - ) -> str: - """Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential.""" - if credential_id is None: - return await self.create_session(user_id, None, token, info) - - # Get the database ID from the credential - stmt = select(CredentialModel.id).where( - CredentialModel.credential_id == credential_id - ) + async def get_session(self, key: bytes) -> Session | None: + """Get session by 16-byte key.""" + stmt = select(SessionModel).where(SessionModel.key == key) result = await self.session.execute(stmt) - db_credential_id = result.scalar_one() + session_model = result.scalar_one_or_none() - return await self.create_session(user_id, db_credential_id, token, info) - - async def get_session(self, token: str) -> SessionModel | None: - """Get session by token string.""" - stmt = select(SessionModel).where(SessionModel.token == token) - result = await self.session.execute(stmt) - session = result.scalar_one_or_none() - - if session: - # Check if session is expired (24 hours) - expiry_time = session.created_at + timedelta(hours=24) - if datetime.now() > expiry_time: - # Clean up expired session - await self.delete_session(token) - return None - - return session + if session_model: + return Session( + key=session_model.key, + user_uuid=UUID(bytes=session_model.user_uuid), + credential_uuid=UUID(bytes=session_model.credential_uuid) + if session_model.credential_uuid + else None, + expires=session_model.expires, + info=session_model.info, + ) return None - async def delete_session(self, token: str) -> None: - """Delete a session by token.""" - stmt = delete(SessionModel).where(SessionModel.token == token) - await self.session.execute(stmt) + async def delete_session(self, key: bytes) -> None: + """Delete a session by 16-byte key.""" + await self.session.execute(delete(SessionModel).where(SessionModel.key == key)) + + async def update_session(self, key: bytes, expires: datetime, info: dict) -> None: + """Update session expiration time and/or info.""" + await self.session.execute( + update(SessionModel) + .where(SessionModel.key == key) + .values(expires=expires, info=info) + ) async def cleanup_expired_sessions(self) -> None: - """Remove expired sessions (older than 24 hours).""" - expiry_time = datetime.now() - timedelta(hours=24) - stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time) + """Remove expired sessions.""" + current_time = datetime.now() + stmt = delete(SessionModel).where(SessionModel.expires < current_time) await self.session.execute(stmt) - async def refresh_session(self, token: str) -> str | None: - """Refresh a session by updating its created_at timestamp.""" - session = await self.get_session(token) - if not session: - return None - - # Delete old session - await self.delete_session(token) - - # Create new session with same user and credential - return await self.create_session( - user_id=UUID(bytes=session.user_id), - credential_id=session.credential_id, - info=session.info, - ) - # Standalone functions that handle database connections internally async def init_database() -> None: @@ -350,7 +283,7 @@ async def init_database() -> None: await db.init_db() -async def create_user_and_credential(user: User, credential: StoredCredential) -> None: +async def create_user_and_credential(user: User, credential: Credential) -> None: """Create a new user and their first credential in a single transaction.""" async with connect() as db: await db.session.begin() @@ -360,97 +293,73 @@ async def create_user_and_credential(user: User, credential: StoredCredential) - await db.create_credential(credential) -async def get_user_by_id(user_id: UUID) -> User: - """Get user record by WebAuthn user ID.""" +async def get_user_by_uuid(user_uuid: UUID) -> User: + """Get user record by WebAuthn user UUID.""" async with connect() as db: - return await db.get_user_by_user_id(user_id) + return await db.get_user_by_user_uuid(user_uuid) -async def create_credential_for_user(credential: StoredCredential) -> None: +async def create_credential_for_user(credential: Credential) -> None: """Store a credential for an existing user.""" async with connect() as db: await db.create_credential(credential) -async def get_credential_by_id(credential_id: bytes) -> StoredCredential: +async def get_credential_by_id(credential_id: bytes) -> Credential: """Get credential by credential ID.""" async with connect() as db: return await db.get_credential_by_id(credential_id) -async def get_user_credentials(user_id: UUID) -> list[bytes]: +async def get_user_credentials(user_uuid: UUID) -> list[bytes]: """Get all credential IDs for a user.""" async with connect() as db: - return await db.get_credentials_by_user_id(user_id) + return await db.get_credentials_by_user_uuid(user_uuid) -async def login_user(user_id: UUID, credential: StoredCredential) -> None: +async def login_user(user_uuid: UUID, credential: Credential) -> None: """Update the last_seen timestamp for a user and the credential record used for logging in.""" async with connect() as db: - await db.login(user_id, credential) + await db.login(user_uuid, credential) -async def delete_user_credential(credential_id: bytes) -> None: +async def delete_credential(uuid: UUID, user_uuid: UUID) -> None: """Delete a credential by its ID.""" async with connect() as db: - await db.delete_credential(credential_id) - - -async def create_new_session(user_id: UUID, credential: StoredCredential) -> None: - """Create a new session for a user by incrementing visits and updating last_seen.""" - async with connect() as db: - await db.create_new_session(user_id, credential) - - -async def get_user_by_username(user_name: str) -> User | None: - """Get user by username.""" - async with connect() as db: - return await db.get_user_by_username(user_name) + await db.delete_credential(uuid, user_uuid) async def create_session( - user_id: UUID, - credential_id: int | None = None, - token: str | None = None, - info: dict | None = None, -) -> str: - """Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential.""" + user_uuid: UUID, + key: bytes, + expires: datetime, + info: dict, + credential_uuid: UUID | None = None, +) -> bytes: + """Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential.""" async with connect() as db: - return await db.create_session(user_id, credential_id, token, info) + return await db.create_session(user_uuid, key, expires, info, credential_uuid) -async def create_session_by_credential_id( - user_id: UUID, - credential_id: bytes | None, - token: str | None = None, - info: dict | None = None, -) -> str: - """Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential.""" +async def get_session(key: bytes) -> Session | None: + """Get session by 16-byte key.""" async with connect() as db: - return await db.create_session_by_credential_id( - user_id, credential_id, token, info - ) + return await db.get_session(key) -async def get_session(token: str) -> SessionModel | None: - """Get session by token string.""" +async def delete_session(key: bytes) -> None: + """Delete a session by 16-byte key.""" async with connect() as db: - return await db.get_session(token) + await db.delete_session(key) -async def delete_session(token: str) -> None: - """Delete a session by token.""" +async def update_session(key: bytes, expires: datetime, info: dict) -> None: + """Update session expiration time and/or info.""" async with connect() as db: - await db.delete_session(token) + await db.update_session(key, expires, info) async def cleanup_expired_sessions() -> None: - """Remove expired sessions (older than 24 hours).""" + """Remove expired sessions.""" async with connect() as db: await db.cleanup_expired_sessions() - - -async def refresh_session(token: str) -> str | None: - """Refresh a session by updating its created_at timestamp.""" - async with connect() as db: - return await db.refresh_session(token) diff --git a/passkey/fastapi/api.py b/passkey/fastapi/api.py index 420534e..fe597aa 100644 --- a/passkey/fastapi/api.py +++ b/passkey/fastapi/api.py @@ -8,67 +8,67 @@ This module contains all the HTTP API endpoints for: - Login/logout functionality """ -from fastapi import FastAPI, Request, Response +from uuid import UUID + +from fastapi import Cookie, Depends, FastAPI, Request, Response +from fastapi.security import HTTPBearer from .. import aaguid from ..db import sql -from ..util.session import refresh_session_token, validate_session_token -from .session import ( - clear_session_cookie, - get_current_user, - get_session_token_from_bearer, - get_session_token_from_cookie, - set_session_cookie, -) +from ..util.tokens import session_key +from . import session + +bearer_auth = HTTPBearer(auto_error=True) def register_api_routes(app: FastAPI): """Register all API routes on the FastAPI app.""" - @app.post("/auth/user-info") - async def api_user_info(request: Request, response: Response): - """Get user information and credentials from session cookie.""" + @app.post("/auth/validate") + async def validate_token(request: Request, response: Response, auth=Cookie(None)): + """Lightweight token validation endpoint.""" try: - user = await get_current_user(request) - if not user: - return {"error": "Not authenticated"} - - # Get current session credential ID - current_credential_id = None - session_token = get_session_token_from_cookie(request) - if session_token: - token_data = await validate_session_token(session_token) - if token_data: - current_credential_id = token_data.get("credential_id") + s = await session.get_session(auth) + return { + "status": "success", + "valid": True, + "user_uuid": str(s.user_uuid), + } + except ValueError: + return {"status": "error", "valid": False} + @app.post("/auth/user-info") + async def api_user_info(request: Request, response: Response, auth=Cookie(None)): + """Get full user information for the authenticated user.""" + try: + s = await session.get_session(auth, reset_allowed=True) + u = await sql.get_user_by_uuid(s.user_uuid) # Get all credentials for the user - credential_ids = await sql.get_user_credentials(user.user_id) + credential_ids = await sql.get_user_credentials(s.user_uuid) credentials = [] user_aaguids = set() for cred_id in credential_ids: - stored_cred = await sql.get_credential_by_id(cred_id) + c = await sql.get_credential_by_id(cred_id) # Convert AAGUID to string format - aaguid_str = str(stored_cred.aaguid) + aaguid_str = str(c.aaguid) user_aaguids.add(aaguid_str) # Check if this is the current session credential - is_current_session = current_credential_id == stored_cred.credential_id + is_current_session = s.credential_uuid == c.uuid credentials.append( { - "credential_id": stored_cred.credential_id.hex(), + "credential_uuid": str(c.uuid), "aaguid": aaguid_str, - "created_at": stored_cred.created_at.isoformat(), - "last_used": stored_cred.last_used.isoformat() - if stored_cred.last_used + "created_at": c.created_at.isoformat(), + "last_used": c.last_used.isoformat() if c.last_used else None, + "last_verified": c.last_verified.isoformat() + if c.last_verified else None, - "last_verified": stored_cred.last_verified.isoformat() - if stored_cred.last_verified - else None, - "sign_count": stored_cred.sign_count, + "sign_count": c.sign_count, "is_current_session": is_current_session, } ) @@ -82,13 +82,11 @@ def register_api_routes(app: FastAPI): return { "status": "success", "user": { - "user_id": str(user.user_id), - "user_name": user.user_name, - "created_at": user.created_at.isoformat() - if user.created_at - else None, - "last_seen": user.last_seen.isoformat() if user.last_seen else None, - "visits": user.visits, + "user_uuid": str(u.user_uuid), + "user_name": u.user_name, + "created_at": u.created_at.isoformat() if u.created_at else None, + "last_seen": u.last_seen.isoformat() if u.last_seen else None, + "visits": u.visits, }, "credentials": credentials, "aaguid_info": aaguid_info, @@ -97,196 +95,44 @@ def register_api_routes(app: FastAPI): return {"error": f"Failed to get user info: {str(e)}"} @app.post("/auth/logout") - async def api_logout(request: Request, response: Response): + async def api_logout(response: Response, auth=Cookie(None)): """Log out the current user by clearing the session cookie and deleting from database.""" - # Get the session token before clearing the cookie - session_token = get_session_token_from_cookie(request) - - # Clear the cookie - clear_session_cookie(response) - - # Delete the session from the database if it exists - if session_token: - from ..util.session import logout_session - - try: - await logout_session(session_token) - except Exception: - # Continue even if session deletion fails - pass - + if not auth: + return {"status": "success", "message": "Already logged out"} + await sql.delete_session(session_key(auth)) + response.delete_cookie("auth") return {"status": "success", "message": "Logged out successfully"} @app.post("/auth/set-session") - async def api_set_session(request: Request, response: Response): - """Set session cookie using JWT token from request body or Authorization header.""" + async def api_set_session( + request: Request, response: Response, auth=Depends(bearer_auth) + ): + """Set session cookie from Authorization header. Fetched after login by WebSocket.""" try: - session_token = await get_session_token_from_bearer(request) - - if not session_token: - return {"error": "No session token provided"} - - # Validate the session token - token_data = await validate_session_token(session_token) - if not token_data: - return {"error": "Invalid or expired session token"} - - # Set the HTTP-only cookie - set_session_cookie(response, session_token) + user = await session.get_session(auth.credentials) + if not user: + raise ValueError("Invalid Authorization header.") + session.set_session_cookie(response, auth.credentials) return { "status": "success", "message": "Session cookie set successfully", - "user_id": str(token_data["user_id"]), + "user_uuid": str(user.user_uuid), } + except ValueError as e: + return {"error": str(e)} except Exception as e: return {"error": f"Failed to set session: {str(e)}"} - @app.post("/auth/delete-credential") - async def api_delete_credential(request: Request): + @app.delete("/auth/credential/{uuid}") + async def api_delete_credential(uuid: UUID, auth: str = Cookie(None)): """Delete a specific credential for the current user.""" try: - user = await get_current_user(request) - if not user: - return {"error": "Not authenticated"} - - # Get the credential ID from the request body - try: - body = await request.json() - credential_id = body.get("credential_id") - if not credential_id: - return {"error": "credential_id is required"} - except Exception: - return {"error": "Invalid request body"} - - # Convert credential_id from hex string to bytes - try: - credential_id_bytes = bytes.fromhex(credential_id) - except ValueError: - return {"error": "Invalid credential_id format"} - - # First, verify the credential belongs to the current user - try: - stored_cred = await sql.get_credential_by_id(credential_id_bytes) - if stored_cred.user_id != user.user_id: - return {"error": "Credential not found or access denied"} - except ValueError: - return {"error": "Credential not found"} - - # Check if this is the current session credential - session_token = get_session_token_from_cookie(request) - if session_token: - token_data = await validate_session_token(session_token) - if ( - token_data - and token_data.get("credential_id") == credential_id_bytes - ): - return {"error": "Cannot delete current session credential"} - - # Get user's remaining credentials count - remaining_credentials = await sql.get_user_credentials(user.user_id) - if len(remaining_credentials) <= 1: - return {"error": "Cannot delete last remaining credential"} - - # Delete the credential - await sql.delete_user_credential(credential_id_bytes) - + await session.delete_credential(uuid, auth) return {"status": "success", "message": "Credential deleted successfully"} - except Exception as e: - return {"error": f"Failed to delete credential: {str(e)}"} - - @app.get("/auth/sessions") - async def api_get_sessions(request: Request): - """Get all active sessions for the current user.""" - try: - user = await get_current_user(request) - if not user: - return {"error": "Authentication required"} - - # Get all sessions for this user - from sqlalchemy import select - - from ..db.sql import SessionModel, connect - - async with connect() as db: - stmt = select(SessionModel).where( - SessionModel.user_id == user.user_id.bytes - ) - result = await db.session.execute(stmt) - session_models = result.scalars().all() - - sessions = [] - current_token = get_session_token_from_cookie(request) - - for session in session_models: - # Check if session is expired - from datetime import datetime, timedelta - - expiry_time = session.created_at + timedelta(hours=24) - is_expired = datetime.now() > expiry_time - - sessions.append( - { - "token": session.token[:8] - + "...", # Only show first 8 chars for security - "created_at": session.created_at.isoformat(), - "client_ip": session.info.get("client_ip") - if session.info - else None, - "user_agent": session.info.get("user_agent") - if session.info - else None, - "connection_type": session.info.get( - "connection_type", "http" - ) - if session.info - else "http", - "is_current": session.token == current_token, - "is_reset_token": session.credential_id is None, - "is_expired": is_expired, - } - ) - - return { - "status": "success", - "sessions": sessions, - "total_sessions": len(sessions), - } - - except Exception as e: - return {"error": f"Failed to get sessions: {str(e)}"} - - -async def validate_token(request: Request, response: Response) -> dict: - """Validate a session token and return user info. Also refreshes the token if valid.""" - try: - session_token = get_session_token_from_cookie(request) - if not session_token: - return {"error": "No session token found"} - - # Validate the session token - token_data = await validate_session_token(session_token) - if not token_data: - clear_session_cookie(response) - return {"error": "Invalid or expired session token"} - - # Refresh the token if valid - new_token = await refresh_session_token(session_token) - if new_token: - set_session_cookie(response, new_token) - - return { - "status": "success", - "valid": True, - "refreshed": bool(new_token), - "user_id": str(token_data["user_id"]), - "credential_id": token_data["credential_id"].hex() - if token_data["credential_id"] - else None, - "created_at": token_data["created_at"].isoformat(), - } - - except Exception as e: - return {"error": f"Failed to validate token: {str(e)}"} + except ValueError as e: + return {"error": str(e)} + except Exception: + return {"error": "Failed to delete credential"} diff --git a/passkey/fastapi/main.py b/passkey/fastapi/main.py index c04b61d..2f7faa7 100644 --- a/passkey/fastapi/main.py +++ b/passkey/fastapi/main.py @@ -9,15 +9,12 @@ This module provides a simple WebAuthn implementation that: - Enables true passwordless authentication where users don't need to enter a user_name """ +import contextlib import logging from contextlib import asynccontextmanager from pathlib import Path -from fastapi import ( - FastAPI, - Request, - Response, -) +from fastapi import Cookie, FastAPI, Request, Response from fastapi.responses import ( FileResponse, JSONResponse, @@ -25,12 +22,9 @@ from fastapi.responses import ( from fastapi.staticfiles import StaticFiles from ..db import sql -from .api import ( - register_api_routes, - validate_token, -) +from . import session, ws +from .api import register_api_routes from .reset import register_reset_routes -from .ws import ws_app STATIC_DIR = Path(__file__).parent.parent / "frontend-build" @@ -44,7 +38,7 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) # Mount the WebSocket subapp -app.mount("/auth/ws", ws_app) +app.mount("/auth/ws", ws.app) # Register API routes register_api_routes(app) @@ -52,24 +46,19 @@ register_reset_routes(app) @app.get("/auth/forward-auth") -async def forward_authentication(request: Request): - """A verification endpoint to use with Caddy forward_auth or Nginx auth_request.""" - # Create a dummy response object for internal validation (we won't use it for cookies) - response = Response() +async def forward_authentication(request: Request, auth=Cookie(None)): + """A validation endpoint to use with Caddy forward_auth or Nginx auth_request.""" + with contextlib.suppress(ValueError): + s = await session.get_session(auth) + # If authenticated, return a success response + if s.info and s.info["type"] == "authenticated": + return Response(status_code=204, headers={"x-auth-user": str(s.user_uuid)}) - result = await validate_token(request, response) - if result.get("status") != "success": - # Serve the index.html of the authentication app if not authenticated - return FileResponse( - STATIC_DIR / "index.html", - status_code=401, - headers={"www-authenticate": "PrivateToken"}, - ) - - # If authenticated, return a success response - return Response( - status_code=204, - headers={"x-auth-user-id": result["user_id"]}, + # Serve the index.html of the authentication app if not authenticated + return FileResponse( + STATIC_DIR / "index.html", + status_code=401, + headers={"www-authenticate": "PrivateToken"}, ) diff --git a/passkey/fastapi/reset.py b/passkey/fastapi/reset.py index 354d82f..a5cbd15 100644 --- a/passkey/fastapi/reset.py +++ b/passkey/fastapi/reset.py @@ -1,114 +1,74 @@ -""" -Device addition API handlers for WebAuthn authentication. +import logging -This module provides endpoints for authenticated users to: -- Generate device addition links with human-readable tokens -- Validate device addition tokens -- Add new passkeys to existing accounts via tokens -""" - -from uuid import UUID - -from fastapi import FastAPI, Path, Request +from fastapi import Cookie, HTTPException, Request from fastapi.responses import RedirectResponse from ..db import sql -from ..util.passphrase import generate -from ..util.session import get_client_info -from .session import get_current_user, is_device_addition_session, set_session_cookie +from ..util import passphrase, tokens +from . import session -def register_reset_routes(app: FastAPI): +def register_reset_routes(app): """Register all device addition/reset routes on the FastAPI app.""" - @app.post("/auth/create-device-link") - async def api_create_device_link(request: Request): + @app.post("/auth/create-link") + async def api_create_link(request: Request, auth=Cookie(None)): """Create a device addition link for the authenticated user.""" try: # Require authentication - user = await get_current_user(request) - if not user: - return {"error": "Authentication required"} + s = await session.get_session(auth) # Generate a human-readable token - token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn" - - # Create session token in database with credential_id=None for device addition - client_info = get_client_info(request) - await sql.create_session(user.user_id, None, token, client_info) + token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke" + await sql.create_session( + user_uuid=s.user_uuid, + key=tokens.reset_key(token), + expires=session.expires(), + info=session.infodict(request, "device addition"), + ) # Generate the device addition link with pretty URL - addition_link = f"{request.headers.get('origin', '')}/auth/{token}" + path = request.url.path.removesuffix("create-link") + token + url = f"{request.headers['origin']}{path}" return { "status": "success", - "message": "Device addition link generated successfully", - "addition_link": addition_link, - "expires_in_hours": 24, + "message": "Registration link generated successfully", + "url": url, + "expires": session.expires().isoformat(), } + except ValueError: + return {"error": "Authentication required"} except Exception as e: - return {"error": f"Failed to create device addition link: {str(e)}"} + return {"error": f"Failed to create registration link: {str(e)}"} - @app.get("/auth/device-session-check") - async def check_device_session(request: Request): - """Check if the current session is for device addition.""" - is_device_session = await is_device_addition_session(request) - return {"device_addition_session": is_device_session} - - @app.get("/auth/{passphrase}") + @app.get("/auth/{reset_token}") async def reset_authentication( request: Request, - passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"), + reset_token: str, ): + """Verifies the token and redirects to auth app for credential registration.""" + # This route should only match to exact passphrases + print(f"Reset handler called with url: {request.url.path}") + if not passphrase.is_well_formed(reset_token): + raise HTTPException(status_code=404) try: # Get session token to validate it exists and get user_id - session = await sql.get_session(passphrase) - if not session: - # Token doesn't exist, redirect to home - return RedirectResponse(url="/", status_code=303) + key = tokens.reset_key(reset_token) + sess = await sql.get_session(key) + if not sess: + raise ValueError("Invalid or expired registration token") - # Check if this is a device addition session (credential_id is None) - if session.credential_id is not None: - # Not a device addition session, redirect to home - return RedirectResponse(url="/", status_code=303) - - # Create a device addition session token for the user - client_info = get_client_info(request) - session_token = await sql.create_session( - UUID(bytes=session.user_id), None, None, client_info - ) - - # Create response and set session cookie response = RedirectResponse(url="/auth/", status_code=303) - set_session_cookie(response, session_token) - + session.set_session_cookie(response, reset_token) return response - except Exception: - # On any error, redirect to home - return RedirectResponse(url="/", status_code=303) - - -async def use_reset_token(token: str) -> dict: - """Delete a device addition token after successful use.""" - try: - # Get session token first to validate it exists and is not expired - session = await sql.get_session(token) - if not session: - return {"error": "Invalid or expired device addition token"} - - # Check if this is a device addition session (credential_id is None) - if session.credential_id is not None: - return {"error": "Invalid device addition token"} - - # Delete the token (it's now used) - await sql.delete_session(token) - - return { - "status": "success", - "message": "Device addition token used successfully", - } - - except Exception as e: - return {"error": f"Failed to use device addition token: {str(e)}"} + except Exception as e: + # On any error, redirect to auth app + if isinstance(e, ValueError): + msg = str(e) + else: + logging.exception("Internal Server Error in reset_authentication") + msg = "Internal Server Error" + return RedirectResponse(url=f"/auth/#{msg}", status_code=303) diff --git a/passkey/fastapi/session.py b/passkey/fastapi/session.py index 9a14454..62e8af5 100644 --- a/passkey/fastapi/session.py +++ b/passkey/fastapi/session.py @@ -5,144 +5,85 @@ This module provides session management functionality including: - Getting current user from session cookies - Setting and clearing HTTP-only cookies - Session validation and token handling +- Device addition token management +- Device addition route handlers """ +from datetime import datetime, timedelta from uuid import UUID from fastapi import Request, Response -from ..db.sql import User, get_user_by_id -from ..util.session import validate_session_token +from ..db import Session, sql +from ..util import passphrase +from ..util.tokens import create_token, reset_key, session_key -COOKIE_NAME = "auth" -COOKIE_MAX_AGE = 86400 # 24 hours +EXPIRES = timedelta(hours=24) -async def get_current_user(request: Request) -> User | None: - """Get the current user from the session cookie.""" - session_token = request.cookies.get(COOKIE_NAME) - if not session_token: - return None - - token_data = await validate_session_token(session_token) - if not token_data: - return None - - try: - user = await get_user_by_id(token_data["user_id"]) - return user - except Exception: - return None +def expires() -> datetime: + return datetime.now() + EXPIRES -def set_session_cookie(response: Response, session_token: str) -> None: +def infodict(request: Request, type: str) -> dict: + """Extract client information from request.""" + return { + "ip": request.client.host if request.client else "", + "user_agent": request.headers.get("user-agent", "")[:500], + "type": type, + } + + +async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str: + """Create a new session and return a session token.""" + token = create_token() + await sql.create_session( + user_uuid=user_uuid, + key=session_key(token), + expires=datetime.now() + EXPIRES, + info=info, + credential_uuid=credential_uuid, + ) + return token + + +async def get_session(token: str, reset_allowed=False) -> Session: + """Validate a session token and return session data if valid.""" + if passphrase.is_well_formed(token): + if not reset_allowed: + raise ValueError("Reset link is not allowed for this endpoint") + key = reset_key(token) + else: + key = session_key(token) + + session = await sql.get_session(key) + if not session: + raise ValueError("Invalid or expired session token") + return session + + +async def refresh_session_token(token: str): + """Refresh a session extending its expiry.""" + # Get the current session + s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {}) + + if not s: + raise ValueError("Session not found or expired") + + +def set_session_cookie(response: Response, token: str) -> None: """Set the session token as an HTTP-only cookie.""" response.set_cookie( - key=COOKIE_NAME, - value=session_token, - max_age=COOKIE_MAX_AGE, + key="auth", + value=token, + max_age=int(EXPIRES.total_seconds()), httponly=True, secure=True, - samesite="lax", + path="/auth/", ) -def clear_session_cookie(response: Response) -> None: - """Clear the session cookie.""" - response.delete_cookie(key=COOKIE_NAME) - - -def get_session_token_from_cookie(request: Request) -> str | None: - """Extract session token from request cookies.""" - return request.cookies.get(COOKIE_NAME) - - -async def validate_session_from_request(request: Request) -> dict | None: - """Validate session token from request and return token data.""" - session_token = get_session_token_from_cookie(request) - if not session_token: - return None - - return await validate_session_token(session_token) - - -async def get_session_token_from_bearer(request: Request) -> str | None: - """Extract session token from Authorization header or request body.""" - # Try to get token from Authorization header first - auth_header = request.headers.get("Authorization") - if auth_header and auth_header.startswith("Bearer "): - return auth_header.removeprefix("Bearer ") - - -async def get_user_from_cookie_string(cookie_header: str) -> UUID | None: - """Parse cookie header and return user ID if valid session exists.""" - if not cookie_header: - return None - - # Parse cookies from header (simple implementation) - cookies = {} - for cookie in cookie_header.split(";"): - cookie = cookie.strip() - if "=" in cookie: - name, value = cookie.split("=", 1) - cookies[name] = value - - session_token = cookies.get(COOKIE_NAME) - if not session_token: - return None - - token_data = await validate_session_token(session_token) - if not token_data: - return None - - return token_data["user_id"] - - -async def is_device_addition_session(request: Request) -> bool: - """Check if the current session is for device addition.""" - session_token = request.cookies.get(COOKIE_NAME) - if not session_token: - return False - - token_data = await validate_session_token(session_token) - if not token_data: - return False - - return token_data.get("device_addition", False) - - -async def get_device_addition_user_id(request: Request) -> UUID | None: - """Get user ID from device addition session.""" - session_token = request.cookies.get(COOKIE_NAME) - if not session_token: - return None - - token_data = await validate_session_token(session_token) - if not token_data or not token_data.get("device_addition"): - return None - - return token_data.get("user_id") - - -async def get_device_addition_user_id_from_cookie(cookie_header: str) -> UUID | None: - """Parse cookie header and return user ID if valid device addition session exists.""" - if not cookie_header: - return None - - # Parse cookies from header (simple implementation) - cookies = {} - for cookie in cookie_header.split(";"): - cookie = cookie.strip() - if "=" in cookie: - name, value = cookie.split("=", 1) - cookies[name] = value - - session_token = cookies.get(COOKIE_NAME) - if not session_token: - return None - - token_data = await validate_session_token(session_token) - if not token_data or not token_data.get("device_addition"): - return None - - return token_data["user_id"] +async def delete_credential(credential_uuid: UUID, auth: str): + """Delete a specific credential for the current user.""" + s = await get_session(auth) + await sql.delete_credential(credential_uuid, s.user_uuid) diff --git a/passkey/fastapi/ws.py b/passkey/fastapi/ws.py index 1e5c662..e8805ad 100644 --- a/passkey/fastapi/ws.py +++ b/passkey/fastapi/ws.py @@ -13,17 +13,18 @@ from datetime import datetime from uuid import UUID import uuid7 -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import Cookie, FastAPI, Query, Request, WebSocket, WebSocketDisconnect from webauthn.helpers.exceptions import InvalidAuthenticationResponse -from ..db import sql -from ..db.sql import User +from passkey.fastapi import session + +from ..db import User, sql from ..sansio import Passkey -from ..util.session import create_session_token, get_client_info_from_websocket -from .session import get_user_from_cookie_string +from ..util.tokens import create_token, reset_key, session_key +from .session import create_session, infodict # Create a FastAPI subapp for WebSocket endpoints -ws_app = FastAPI() +app = FastAPI() # Initialize the passkey instance passkey = Passkey( @@ -34,51 +35,55 @@ passkey = Passkey( async def register_chat( ws: WebSocket, - user_id: UUID, + user_uuid: UUID, user_name: str, credential_ids: list[bytes] | None = None, origin: str | None = None, ): """Generate registration options and send them to the client.""" options, challenge = passkey.reg_generate_options( - user_id=user_id, + user_id=user_uuid, user_name=user_name, credential_ids=credential_ids, origin=origin, ) await ws.send_json(options) response = await ws.receive_json() - return passkey.reg_verify(response, challenge, user_id, origin=origin) + return passkey.reg_verify(response, challenge, user_uuid, origin=origin) -@ws_app.websocket("/register_new") -async def websocket_register_new(ws: WebSocket, user_name: str): +@app.websocket("/register") +async def websocket_register_new( + request: Request, ws: WebSocket, user_name: str = Query(""), auth=Cookie(None) +): """Register a new user and with a new passkey credential.""" await ws.accept() origin = ws.headers.get("origin") try: - user_id = uuid7.create() - + user_uuid = uuid7.create() # WebAuthn registration - credential = await register_chat(ws, user_id, user_name, origin=origin) + credential = await register_chat(ws, user_uuid, user_name, origin=origin) # Store the user and credential in the database await sql.create_user_and_credential( - User(user_id, user_name, created_at=datetime.now()), + User(user_uuid, user_name, created_at=datetime.now()), credential, ) - # Create a session token for the new user - client_info = get_client_info_from_websocket(ws) - session_token = await create_session_token( - user_id, credential.credential_id, client_info + token = create_token() + await sql.create_session( + user_uuid=user_uuid, + key=session_key(token), + expires=datetime.now() + session.EXPIRES, + info=infodict(request, "authenticated"), + credential_uuid=credential.uuid, ) await ws.send_json( { "status": "success", - "user_id": str(user_id), - "session_token": session_token, + "user_uuid": str(user_uuid), + "session_token": token, } ) except ValueError as e: @@ -90,28 +95,31 @@ async def websocket_register_new(ws: WebSocket, user_name: str): await ws.send_json({"error": "Internal Server Error"}) -@ws_app.websocket("/add_credential") -async def websocket_register_add(ws: WebSocket): +@app.websocket("/add_credential") +async def websocket_register_add(ws: WebSocket, token: str | None = None): """Register a new credential for an existing user.""" await ws.accept() origin = ws.headers.get("origin") try: - # Authenticate user via cookie - cookie_header = ws.headers.get("cookie", "") - user_id = await get_user_from_cookie_string(cookie_header) - - if not user_id: - await ws.send_json({"error": "Authentication required"}) + if not token: + await ws.send_json({"error": "Token is required"}) return + # If a token is provided, use it to look up the session + key = reset_key(token) + s = await sql.get_session(key) + if not s: + await ws.send_json({"error": "Invalid or expired token"}) + return + user_uuid = s.user_uuid # Get user information to get the user_name - user = await sql.get_user_by_id(user_id) + user = await sql.get_user_by_uuid(user_uuid) user_name = user.user_name - challenge_ids = await sql.get_user_credentials(user_id) + challenge_ids = await sql.get_user_credentials(user_uuid) # WebAuthn registration credential = await register_chat( - ws, user_id, user_name, challenge_ids, origin=origin + ws, user_uuid, user_name, challenge_ids, origin=origin ) # Store the new credential in the database await sql.create_credential_for_user(credential) @@ -119,7 +127,7 @@ async def websocket_register_add(ws: WebSocket): await ws.send_json( { "status": "success", - "user_id": str(user_id), + "user_uuid": str(user_uuid), "credential_id": credential.credential_id.hex(), "message": "New credential added successfully", } @@ -133,103 +141,8 @@ async def websocket_register_add(ws: WebSocket): await ws.send_json({"error": "Internal Server Error"}) -@ws_app.websocket("/add_device_credential") -async def websocket_add_device_credential(ws: WebSocket, token: str): - """Add a new credential for an existing user via device addition token.""" - await ws.accept() - origin = ws.headers.get("origin") - try: - reset_token = await sql.get_session(token) - if not reset_token: - await ws.send_json({"error": "Invalid or expired device addition token"}) - return - - # Get user information - user = await sql.get_user_by_id(reset_token.user_id) - - # WebAuthn registration - # Fetch challenge IDs for the user - challenge_ids = await sql.get_user_credentials(reset_token.user_id) - - credential = await register_chat( - ws, reset_token.user_id, user.user_name, challenge_ids, origin=origin - ) - - # Store the new credential in the database - await sql.create_credential_for_user(credential) - - # Delete the device addition token (it's now used) - await sql.delete_reset_token(token) - - await ws.send_json( - { - "status": "success", - "user_id": str(reset_token.user_id), - "credential_id": credential.credential_id.hex(), - "message": "New credential added successfully via device addition token", - } - ) - except ValueError as e: - await ws.send_json({"error": str(e)}) - except WebSocketDisconnect: - pass - except Exception: - logging.exception("Internal Server Error") - await ws.send_json({"error": "Internal Server Error"}) - - -@ws_app.websocket("/add_device_credential_session") -async def websocket_add_device_credential_session(ws: WebSocket): - """Add a new credential for an existing user via device addition session.""" - await ws.accept() - origin = ws.headers.get("origin") - try: - # Get device addition user ID from session cookie - cookie_header = ws.headers.get("cookie", "") - from .session import get_device_addition_user_id_from_cookie - - user_id = await get_device_addition_user_id_from_cookie(cookie_header) - - if not user_id: - await ws.send_json({"error": "No valid device addition session found"}) - return - - # Get user information - user = await sql.get_user_by_id(user_id) - if not user: - await ws.send_json({"error": "User not found"}) - return - - # WebAuthn registration - # Fetch challenge IDs for the user - challenge_ids = await sql.get_user_credentials(user_id) - - credential = await register_chat( - ws, user_id, user.user_name, challenge_ids, origin=origin - ) - - # Store the new credential in the database - await sql.create_credential_for_user(credential) - - await ws.send_json( - { - "status": "success", - "user_id": str(user_id), - "credential_id": credential.credential_id.hex(), - "message": "New credential added successfully via device addition session", - } - ) - except ValueError as e: - await ws.send_json({"error": str(e)}) - except WebSocketDisconnect: - pass - except Exception: - logging.exception("Internal Server Error") - await ws.send_json({"error": "Internal Server Error"}) - - -@ws_app.websocket("/authenticate") -async def websocket_authenticate(ws: WebSocket): +@app.websocket("/authenticate") +async def websocket_authenticate(request: Request, ws: WebSocket): await ws.accept() origin = ws.headers.get("origin") try: @@ -242,19 +155,21 @@ async def websocket_authenticate(ws: WebSocket): # Verify the credential matches the stored data passkey.auth_verify(credential, challenge, stored_cred, origin=origin) # Update both credential and user's last_seen timestamp - await sql.login_user(stored_cred.user_id, stored_cred) + await sql.login_user(stored_cred.user_uuid, stored_cred) # Create a session token for the authenticated user - client_info = get_client_info_from_websocket(ws) - session_token = await create_session_token( - stored_cred.user_id, stored_cred.credential_id, client_info + assert stored_cred.uuid is not None + token = await create_session( + user_uuid=stored_cred.user_uuid, + info=infodict(request, "auth"), + credential_uuid=stored_cred.uuid, ) await ws.send_json( { "status": "success", - "user_id": str(stored_cred.user_id), - "session_token": session_token, + "user_uuid": str(stored_cred.user_uuid), + "session_token": token, } ) except (ValueError, InvalidAuthenticationResponse) as e: diff --git a/passkey/sansio.py b/passkey/sansio.py index e3a85dc..d81c801 100644 --- a/passkey/sansio.py +++ b/passkey/sansio.py @@ -8,7 +8,6 @@ This module provides a unified interface for WebAuthn operations including: """ import json -from dataclasses import dataclass from datetime import datetime from uuid import UUID @@ -36,21 +35,7 @@ from webauthn.helpers.structs import ( UserVerificationRequirement, ) - -@dataclass -class StoredCredential: - """Credential data stored in the database.""" - - # Fields set only at registration time - credential_id: bytes - user_id: UUID - aaguid: UUID - public_key: bytes - # Mutable fields that may be updated during authentication - sign_count: int - created_at: datetime - last_used: datetime | None = None - last_verified: datetime | None = None +from .db import Credential class Passkey: @@ -129,7 +114,7 @@ class Passkey: expected_challenge: bytes, user_id: UUID, origin: str | None = None, - ) -> StoredCredential: + ) -> Credential: """ Verify registration response. @@ -147,7 +132,7 @@ class Passkey: expected_origin=origin or self.origin, expected_rp_id=self.rp_id, ) - return StoredCredential( + return Credential( credential_id=credential.raw_id, user_id=user_id, aaguid=UUID(registration.aaguid), @@ -195,7 +180,7 @@ class Passkey: self, credential: AuthenticationCredential, expected_challenge: bytes, - stored_cred: StoredCredential, + stored_cred: Credential, origin: str | None = None, ) -> VerifiedAuthentication: """ diff --git a/passkey/util/passphrase.py b/passkey/util/passphrase.py index 9c2b055..70ec6e5 100644 --- a/passkey/util/passphrase.py +++ b/passkey/util/passphrase.py @@ -2,8 +2,18 @@ import secrets from .wordlist import words +N_WORDS = 5 -def generate(n=4, sep="."): +wset = set(words) + + +def generate(n=N_WORDS, sep="."): """Generate a password of random words without repeating any word.""" - wl = list(words) + wl = words.copy() return sep.join(wl.pop(secrets.randbelow(len(wl))) for i in range(n)) + + +def is_well_formed(passphrase: str, n=N_WORDS, sep=".") -> bool: + """Check if the passphrase is well-formed according to the regex pattern.""" + p = passphrase.split(sep) + return len(p) == n and all(w in wset for w in passphrase.split(".")) diff --git a/passkey/util/session.py b/passkey/util/session.py deleted file mode 100644 index 1f84411..0000000 --- a/passkey/util/session.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Database session management for WebAuthn authentication. - -This module provides session management using database tokens instead of JWT tokens. -Session tokens are stored in the database and validated on each request. -""" - -from datetime import datetime -from typing import Optional -from uuid import UUID - -from fastapi import Request - -from ..db import sql - - -def get_client_info(request: Request) -> dict: - """Extract client information from FastAPI request and return as dict.""" - # Get client IP (handle X-Forwarded-For for proxies) - # Get user agent - return { - "client_ip": request.client.host if request.client else "", - "user_agent": request.headers.get("user-agent", "")[:500], - } - - -def get_client_info_from_websocket(ws) -> dict: - """Extract client information from WebSocket connection and return as dict.""" - # Get client IP from WebSocket - client_ip = None - if hasattr(ws, "client") and ws.client: - client_ip = ws.client.host - - # Check for forwarded headers - if hasattr(ws, "headers"): - forwarded_for = ws.headers.get("x-forwarded-for") - if forwarded_for: - client_ip = forwarded_for.split(",")[0].strip() - - # Get user agent from WebSocket headers - user_agent = None - if hasattr(ws, "headers"): - user_agent = ws.headers.get("user-agent") - # Truncate user agent if too long - if user_agent and len(user_agent) > 500: # Keep some margin - user_agent = user_agent[:500] - - return { - "client_ip": client_ip, - "user_agent": user_agent, - "timestamp": datetime.now().isoformat(), - "connection_type": "websocket", - } - - -async def create_session_token( - user_id: UUID, credential_id: bytes, info: dict | None = None -) -> str: - """Create a session token for a user.""" - return await sql.create_session_by_credential_id(user_id, credential_id, None, info) - - -async def validate_session_token(token: str) -> Optional[dict]: - """Validate a session token.""" - session_data = await sql.get_session(token) - if not session_data: - return None - - return { - "user_id": session_data["user_id"], - "credential_id": session_data["credential_id"], - "created_at": session_data["created_at"], - } - - -async def refresh_session_token(token: str) -> Optional[str]: - """Refresh a session token.""" - return await sql.refresh_session(token) - - -async def delete_session_token(token: str) -> None: - """Delete a session token.""" - await sql.delete_session(token) - - -async def logout_session(token: str) -> None: - """Log out a user by deleting their session token.""" - await sql.delete_session(token) diff --git a/passkey/util/tokens.py b/passkey/util/tokens.py new file mode 100644 index 0000000..90c78c2 --- /dev/null +++ b/passkey/util/tokens.py @@ -0,0 +1,17 @@ +import base64 +import hashlib +import secrets + + +def create_token() -> str: + return secrets.token_urlsafe(12) # 16 characters Base64 + + +def session_key(token: str) -> bytes: + if len(token) != 16: + raise ValueError("Session token must be exactly 16 characters long") + return b"sess" + base64.urlsafe_b64decode(token) + + +def reset_key(passphrase: str) -> bytes: + return b"rset" + hashlib.sha512(passphrase.encode()).digest()[:12]