Database reworked simpler, JWTs replaced by sessions table and random tokens. Accessing Add device link is currently broken.
This commit is contained in:
parent
225d7b7542
commit
dc0b0f4613
6
API.md
6
API.md
@ -271,12 +271,6 @@ Register a new user with a new passkey credential.
|
|||||||
"message": "User registered successfully",
|
"message": "User registered successfully",
|
||||||
"token": "string (JWT)"
|
"token": "string (JWT)"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error response
|
|
||||||
{
|
|
||||||
"status": "error",
|
|
||||||
"message": "error description"
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### `WS /auth/ws/add_credential`
|
#### `WS /auth/ws/add_credential`
|
||||||
|
@ -21,7 +21,7 @@ from sqlalchemy import (
|
|||||||
select,
|
select,
|
||||||
update,
|
update,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.sqlite import BLOB
|
from sqlalchemy.dialects.sqlite import BLOB, JSON
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@ -52,8 +52,8 @@ class UserModel(Base):
|
|||||||
|
|
||||||
class CredentialModel(Base):
|
class CredentialModel(Base):
|
||||||
__tablename__ = "credentials"
|
__tablename__ = "credentials"
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
|
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True)
|
||||||
user_id: Mapped[bytes] = mapped_column(
|
user_id: Mapped[bytes] = mapped_column(
|
||||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
@ -68,14 +68,18 @@ class CredentialModel(Base):
|
|||||||
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
|
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
|
||||||
|
|
||||||
|
|
||||||
class ResetTokenModel(Base):
|
class SessionModel(Base):
|
||||||
__tablename__ = "reset_tokens"
|
__tablename__ = "sessions"
|
||||||
|
|
||||||
token: Mapped[str] = mapped_column(String(64), primary_key=True)
|
token: Mapped[str] = mapped_column(String(32), primary_key=True)
|
||||||
user_id: Mapped[bytes] = mapped_column(
|
user_id: Mapped[bytes] = mapped_column(
|
||||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
|
credential_id: Mapped[int | None] = mapped_column(
|
||||||
|
Integer, ForeignKey("credentials.id", ondelete="SET NULL")
|
||||||
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||||
|
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
|
|
||||||
# Relationship to user
|
# Relationship to user
|
||||||
user: Mapped["UserModel"] = relationship("UserModel")
|
user: Mapped["UserModel"] = relationship("UserModel")
|
||||||
@ -90,13 +94,6 @@ class User:
|
|||||||
visits: int = 0
|
visits: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ResetToken:
|
|
||||||
token: str
|
|
||||||
user_id: UUID
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
# Global engine and session factory
|
# Global engine and session factory
|
||||||
engine = create_async_engine(DB_PATH, echo=False)
|
engine = create_async_engine(DB_PATH, echo=False)
|
||||||
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
@ -243,45 +240,6 @@ class DB:
|
|||||||
await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
async def create_reset_token(self, user_id: UUID, token: str | None = None) -> str:
|
|
||||||
"""Create a new reset token for a user."""
|
|
||||||
if token is None:
|
|
||||||
token = secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
reset_token_model = ResetTokenModel(
|
|
||||||
token=token,
|
|
||||||
user_id=user_id.bytes,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
)
|
|
||||||
self.session.add(reset_token_model)
|
|
||||||
await self.session.flush()
|
|
||||||
return token
|
|
||||||
|
|
||||||
async def get_reset_token(self, token: str) -> ResetToken | None:
|
|
||||||
"""Get reset token by token string."""
|
|
||||||
stmt = select(ResetTokenModel).where(ResetTokenModel.token == token)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
token_model = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if token_model:
|
|
||||||
return ResetToken(
|
|
||||||
token=token_model.token,
|
|
||||||
user_id=UUID(bytes=token_model.user_id),
|
|
||||||
created_at=token_model.created_at,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def delete_reset_token(self, token: str) -> None:
|
|
||||||
"""Delete a reset token (used after successful credential addition)."""
|
|
||||||
stmt = delete(ResetTokenModel).where(ResetTokenModel.token == token)
|
|
||||||
await self.session.execute(stmt)
|
|
||||||
|
|
||||||
async def cleanup_expired_tokens(self) -> None:
|
|
||||||
"""Remove expired reset tokens (older than 24 hours)."""
|
|
||||||
expiry_time = datetime.now() - timedelta(hours=24)
|
|
||||||
stmt = delete(ResetTokenModel).where(ResetTokenModel.created_at < expiry_time)
|
|
||||||
await self.session.execute(stmt)
|
|
||||||
|
|
||||||
async def get_user_by_username(self, user_name: str) -> User | None:
|
async def get_user_by_username(self, user_name: str) -> User | None:
|
||||||
"""Get user by username."""
|
"""Get user by username."""
|
||||||
stmt = select(UserModel).where(UserModel.user_name == user_name)
|
stmt = select(UserModel).where(UserModel.user_name == user_name)
|
||||||
@ -298,6 +256,98 @@ class DB:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def create_session(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
credential_id: int | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
info: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
|
||||||
|
if token is None:
|
||||||
|
token = secrets.token_urlsafe(12)
|
||||||
|
|
||||||
|
session_model = SessionModel(
|
||||||
|
token=token,
|
||||||
|
user_id=user_id.bytes,
|
||||||
|
credential_id=credential_id,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
info=info,
|
||||||
|
)
|
||||||
|
self.session.add(session_model)
|
||||||
|
await self.session.flush()
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def create_session_by_credential_id(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
credential_id: bytes | None,
|
||||||
|
token: str | None = None,
|
||||||
|
info: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
|
||||||
|
if credential_id is None:
|
||||||
|
return await self.create_session(user_id, None, token, info)
|
||||||
|
|
||||||
|
# Get the database ID from the credential
|
||||||
|
stmt = select(CredentialModel.id).where(
|
||||||
|
CredentialModel.credential_id == credential_id
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
db_credential_id = result.scalar_one()
|
||||||
|
|
||||||
|
return await self.create_session(user_id, db_credential_id, token, info)
|
||||||
|
|
||||||
|
async def get_session(self, token: str) -> dict | None:
|
||||||
|
"""Get session by token string."""
|
||||||
|
stmt = select(SessionModel).where(SessionModel.token == token)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
session_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if session_model:
|
||||||
|
# Check if session is expired (24 hours)
|
||||||
|
expiry_time = session_model.created_at + timedelta(hours=24)
|
||||||
|
if datetime.now() > expiry_time:
|
||||||
|
# Clean up expired session
|
||||||
|
await self.delete_session(token)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"token": session_model.token,
|
||||||
|
"user_id": UUID(bytes=session_model.user_id),
|
||||||
|
"credential_id": session_model.credential_id,
|
||||||
|
"created_at": session_model.created_at,
|
||||||
|
"info": session_model.info or {},
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete_session(self, token: str) -> None:
|
||||||
|
"""Delete a session by token."""
|
||||||
|
stmt = delete(SessionModel).where(SessionModel.token == token)
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
|
async def cleanup_expired_sessions(self) -> None:
|
||||||
|
"""Remove expired sessions (older than 24 hours)."""
|
||||||
|
expiry_time = datetime.now() - timedelta(hours=24)
|
||||||
|
stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time)
|
||||||
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
|
async def refresh_session(self, token: str) -> str | None:
|
||||||
|
"""Refresh a session by updating its created_at timestamp."""
|
||||||
|
session_data = await self.get_session(token)
|
||||||
|
if not session_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Delete old session
|
||||||
|
await self.delete_session(token)
|
||||||
|
|
||||||
|
# Create new session with same user and credential
|
||||||
|
return await self.create_session(
|
||||||
|
session_data["user_id"],
|
||||||
|
session_data["credential_id"],
|
||||||
|
info=session_data["info"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Standalone functions that handle database connections internally
|
# Standalone functions that handle database connections internally
|
||||||
async def init_database() -> None:
|
async def init_database() -> None:
|
||||||
@ -358,31 +408,55 @@ async def create_new_session(user_id: UUID, credential: StoredCredential) -> Non
|
|||||||
await db.create_new_session(user_id, credential)
|
await db.create_new_session(user_id, credential)
|
||||||
|
|
||||||
|
|
||||||
async def create_reset_token(user_id: UUID, token: str | None = None) -> str:
|
|
||||||
"""Create a reset token for a user."""
|
|
||||||
async with connect() as db:
|
|
||||||
return await db.create_reset_token(user_id, token)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_reset_token(token: str) -> ResetToken | None:
|
|
||||||
"""Get reset token by token string."""
|
|
||||||
async with connect() as db:
|
|
||||||
return await db.get_reset_token(token)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_reset_token(token: str) -> None:
|
|
||||||
"""Delete a reset token (used after successful credential addition)."""
|
|
||||||
async with connect() as db:
|
|
||||||
await db.delete_reset_token(token)
|
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_expired_tokens() -> None:
|
|
||||||
"""Remove expired reset tokens (older than 24 hours)."""
|
|
||||||
async with connect() as db:
|
|
||||||
await db.cleanup_expired_tokens()
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_username(user_name: str) -> User | None:
|
async def get_user_by_username(user_name: str) -> User | None:
|
||||||
"""Get user by username."""
|
"""Get user by username."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.get_user_by_username(user_name)
|
return await db.get_user_by_username(user_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_session(
|
||||||
|
user_id: UUID,
|
||||||
|
credential_id: int | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
info: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.create_session(user_id, credential_id, token, info)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_session_by_credential_id(
|
||||||
|
user_id: UUID,
|
||||||
|
credential_id: bytes | None,
|
||||||
|
token: str | None = None,
|
||||||
|
info: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.create_session_by_credential_id(
|
||||||
|
user_id, credential_id, token, info
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session(token: str) -> dict | None:
|
||||||
|
"""Get session by token string."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.get_session(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_session(token: str) -> None:
|
||||||
|
"""Delete a session by token."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.delete_session(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_expired_sessions() -> None:
|
||||||
|
"""Remove expired sessions (older than 24 hours)."""
|
||||||
|
async with connect() as db:
|
||||||
|
await db.cleanup_expired_sessions()
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_session(token: str) -> str | None:
|
||||||
|
"""Refresh a session by updating its created_at timestamp."""
|
||||||
|
async with connect() as db:
|
||||||
|
return await db.refresh_session(token)
|
||||||
|
@ -12,7 +12,7 @@ from fastapi import FastAPI, Request, Response
|
|||||||
|
|
||||||
from .. import aaguid
|
from .. import aaguid
|
||||||
from ..db import sql
|
from ..db import sql
|
||||||
from ..util.jwt import refresh_session_token, validate_session_token
|
from ..util.session import refresh_session_token, validate_session_token
|
||||||
from .session import (
|
from .session import (
|
||||||
clear_session_cookie,
|
clear_session_cookie,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
@ -37,7 +37,7 @@ def register_api_routes(app: FastAPI):
|
|||||||
current_credential_id = None
|
current_credential_id = None
|
||||||
session_token = get_session_token_from_cookie(request)
|
session_token = get_session_token_from_cookie(request)
|
||||||
if session_token:
|
if session_token:
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if token_data:
|
if token_data:
|
||||||
current_credential_id = token_data.get("credential_id")
|
current_credential_id = token_data.get("credential_id")
|
||||||
|
|
||||||
@ -97,9 +97,24 @@ def register_api_routes(app: FastAPI):
|
|||||||
return {"error": f"Failed to get user info: {str(e)}"}
|
return {"error": f"Failed to get user info: {str(e)}"}
|
||||||
|
|
||||||
@app.post("/auth/logout")
|
@app.post("/auth/logout")
|
||||||
async def api_logout(response: Response):
|
async def api_logout(request: Request, response: Response):
|
||||||
"""Log out the current user by clearing the session cookie."""
|
"""Log out the current user by clearing the session cookie and deleting from database."""
|
||||||
|
# Get the session token before clearing the cookie
|
||||||
|
session_token = get_session_token_from_cookie(request)
|
||||||
|
|
||||||
|
# Clear the cookie
|
||||||
clear_session_cookie(response)
|
clear_session_cookie(response)
|
||||||
|
|
||||||
|
# Delete the session from the database if it exists
|
||||||
|
if session_token:
|
||||||
|
from ..util.session import logout_session
|
||||||
|
|
||||||
|
try:
|
||||||
|
await logout_session(session_token)
|
||||||
|
except Exception:
|
||||||
|
# Continue even if session deletion fails
|
||||||
|
pass
|
||||||
|
|
||||||
return {"status": "success", "message": "Logged out successfully"}
|
return {"status": "success", "message": "Logged out successfully"}
|
||||||
|
|
||||||
@app.post("/auth/set-session")
|
@app.post("/auth/set-session")
|
||||||
@ -112,7 +127,7 @@ def register_api_routes(app: FastAPI):
|
|||||||
return {"error": "No session token provided"}
|
return {"error": "No session token provided"}
|
||||||
|
|
||||||
# Validate the session token
|
# Validate the session token
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data:
|
if not token_data:
|
||||||
return {"error": "Invalid or expired session token"}
|
return {"error": "Invalid or expired session token"}
|
||||||
|
|
||||||
@ -162,7 +177,7 @@ def register_api_routes(app: FastAPI):
|
|||||||
# Check if this is the current session credential
|
# Check if this is the current session credential
|
||||||
session_token = get_session_token_from_cookie(request)
|
session_token = get_session_token_from_cookie(request)
|
||||||
if session_token:
|
if session_token:
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if (
|
if (
|
||||||
token_data
|
token_data
|
||||||
and token_data.get("credential_id") == credential_id_bytes
|
and token_data.get("credential_id") == credential_id_bytes
|
||||||
@ -182,6 +197,67 @@ def register_api_routes(app: FastAPI):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": f"Failed to delete credential: {str(e)}"}
|
return {"error": f"Failed to delete credential: {str(e)}"}
|
||||||
|
|
||||||
|
@app.get("/auth/sessions")
|
||||||
|
async def api_get_sessions(request: Request):
|
||||||
|
"""Get all active sessions for the current user."""
|
||||||
|
try:
|
||||||
|
user = await get_current_user(request)
|
||||||
|
if not user:
|
||||||
|
return {"error": "Authentication required"}
|
||||||
|
|
||||||
|
# Get all sessions for this user
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from ..db.sql import SessionModel, connect
|
||||||
|
|
||||||
|
async with connect() as db:
|
||||||
|
stmt = select(SessionModel).where(
|
||||||
|
SessionModel.user_id == user.user_id.bytes
|
||||||
|
)
|
||||||
|
result = await db.session.execute(stmt)
|
||||||
|
session_models = result.scalars().all()
|
||||||
|
|
||||||
|
sessions = []
|
||||||
|
current_token = get_session_token_from_cookie(request)
|
||||||
|
|
||||||
|
for session in session_models:
|
||||||
|
# Check if session is expired
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
expiry_time = session.created_at + timedelta(hours=24)
|
||||||
|
is_expired = datetime.now() > expiry_time
|
||||||
|
|
||||||
|
sessions.append(
|
||||||
|
{
|
||||||
|
"token": session.token[:8]
|
||||||
|
+ "...", # Only show first 8 chars for security
|
||||||
|
"created_at": session.created_at.isoformat(),
|
||||||
|
"client_ip": session.info.get("client_ip")
|
||||||
|
if session.info
|
||||||
|
else None,
|
||||||
|
"user_agent": session.info.get("user_agent")
|
||||||
|
if session.info
|
||||||
|
else None,
|
||||||
|
"connection_type": session.info.get(
|
||||||
|
"connection_type", "http"
|
||||||
|
)
|
||||||
|
if session.info
|
||||||
|
else "http",
|
||||||
|
"is_current": session.token == current_token,
|
||||||
|
"is_reset_token": session.credential_id is None,
|
||||||
|
"is_expired": is_expired,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"sessions": sessions,
|
||||||
|
"total_sessions": len(sessions),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Failed to get sessions: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
async def validate_token(request: Request, response: Response) -> dict:
|
async def validate_token(request: Request, response: Response) -> dict:
|
||||||
"""Validate a session token and return user info. Also refreshes the token if valid."""
|
"""Validate a session token and return user info. Also refreshes the token if valid."""
|
||||||
@ -191,13 +267,13 @@ async def validate_token(request: Request, response: Response) -> dict:
|
|||||||
return {"error": "No session token found"}
|
return {"error": "No session token found"}
|
||||||
|
|
||||||
# Validate the session token
|
# Validate the session token
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data:
|
if not token_data:
|
||||||
clear_session_cookie(response)
|
clear_session_cookie(response)
|
||||||
return {"error": "Invalid or expired session token"}
|
return {"error": "Invalid or expired session token"}
|
||||||
|
|
||||||
# Refresh the token if valid
|
# Refresh the token if valid
|
||||||
new_token = refresh_session_token(session_token)
|
new_token = await refresh_session_token(session_token)
|
||||||
if new_token:
|
if new_token:
|
||||||
set_session_cookie(response, new_token)
|
set_session_cookie(response, new_token)
|
||||||
|
|
||||||
@ -206,9 +282,10 @@ async def validate_token(request: Request, response: Response) -> dict:
|
|||||||
"valid": True,
|
"valid": True,
|
||||||
"refreshed": bool(new_token),
|
"refreshed": bool(new_token),
|
||||||
"user_id": str(token_data["user_id"]),
|
"user_id": str(token_data["user_id"]),
|
||||||
"credential_id": token_data["credential_id"].hex(),
|
"credential_id": token_data["credential_id"].hex()
|
||||||
"issued_at": token_data["issued_at"],
|
if token_data["credential_id"]
|
||||||
"expires_at": token_data["expires_at"],
|
else None,
|
||||||
|
"created_at": token_data["created_at"].isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -7,14 +7,13 @@ This module provides endpoints for authenticated users to:
|
|||||||
- Add new passkeys to existing accounts via tokens
|
- Add new passkeys to existing accounts via tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Path, Request
|
from fastapi import FastAPI, Path, Request
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
from ..db import sql
|
from ..db import sql
|
||||||
from ..util.passphrase import generate
|
from ..util.passphrase import generate
|
||||||
from .session import get_current_user
|
from ..util.session import get_client_info
|
||||||
|
from .session import get_current_user, is_device_addition_session, set_session_cookie
|
||||||
|
|
||||||
|
|
||||||
def register_reset_routes(app: FastAPI):
|
def register_reset_routes(app: FastAPI):
|
||||||
@ -32,8 +31,9 @@ def register_reset_routes(app: FastAPI):
|
|||||||
# Generate a human-readable token
|
# Generate a human-readable token
|
||||||
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
|
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
|
||||||
|
|
||||||
# Create reset token in database
|
# Create session token in database with credential_id=None for device addition
|
||||||
await sql.create_reset_token(user.user_id, token)
|
client_info = get_client_info(request)
|
||||||
|
await sql.create_session(user.user_id, None, token, client_info)
|
||||||
|
|
||||||
# Generate the device addition link with pretty URL
|
# Generate the device addition link with pretty URL
|
||||||
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
|
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
|
||||||
@ -51,38 +51,34 @@ def register_reset_routes(app: FastAPI):
|
|||||||
@app.get("/auth/device-session-check")
|
@app.get("/auth/device-session-check")
|
||||||
async def check_device_session(request: Request):
|
async def check_device_session(request: Request):
|
||||||
"""Check if the current session is for device addition."""
|
"""Check if the current session is for device addition."""
|
||||||
from .session import is_device_addition_session
|
|
||||||
|
|
||||||
is_device_session = await is_device_addition_session(request)
|
is_device_session = await is_device_addition_session(request)
|
||||||
return {"device_addition_session": is_device_session}
|
return {"device_addition_session": is_device_session}
|
||||||
|
|
||||||
@app.get("/auth/{passphrase}")
|
@app.get("/auth/{passphrase}")
|
||||||
async def reset_authentication(
|
async def reset_authentication(
|
||||||
|
request: Request,
|
||||||
passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"),
|
passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# Get reset token to validate it exists and get user_id
|
# Get session token to validate it exists and get user_id
|
||||||
reset_token = await sql.get_reset_token(passphrase)
|
session_data = await sql.get_session(passphrase)
|
||||||
if not reset_token:
|
if not session_data:
|
||||||
# Token doesn't exist, redirect to home
|
# Token doesn't exist, redirect to home
|
||||||
return RedirectResponse(url="/", status_code=303)
|
return RedirectResponse(url="/", status_code=303)
|
||||||
|
|
||||||
# Check if token is expired (24 hours)
|
# Check if this is a device addition session (credential_id is None)
|
||||||
expiry_time = reset_token.created_at + timedelta(hours=24)
|
if session_data["credential_id"] is not None:
|
||||||
if datetime.now() > expiry_time:
|
# Not a device addition session, redirect to home
|
||||||
# Token expired, clean it up and redirect to home
|
|
||||||
await sql.delete_reset_token(passphrase)
|
|
||||||
return RedirectResponse(url="/", status_code=303)
|
return RedirectResponse(url="/", status_code=303)
|
||||||
|
|
||||||
# Create a device addition session token for the user
|
# Create a device addition session token for the user
|
||||||
from ..util.jwt import create_device_addition_token
|
client_info = get_client_info(request)
|
||||||
|
session_token = await sql.create_session(
|
||||||
session_token = create_device_addition_token(reset_token.user_id)
|
session_data["user_id"], None, None, client_info
|
||||||
|
)
|
||||||
|
|
||||||
# Create response and set session cookie
|
# Create response and set session cookie
|
||||||
response = RedirectResponse(url="/auth/", status_code=303)
|
response = RedirectResponse(url="/auth/", status_code=303)
|
||||||
from .session import set_session_cookie
|
|
||||||
|
|
||||||
set_session_cookie(response, session_token)
|
set_session_cookie(response, session_token)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@ -95,18 +91,17 @@ def register_reset_routes(app: FastAPI):
|
|||||||
async def use_device_addition_token(token: str) -> dict:
|
async def use_device_addition_token(token: str) -> dict:
|
||||||
"""Delete a device addition token after successful use."""
|
"""Delete a device addition token after successful use."""
|
||||||
try:
|
try:
|
||||||
# Get reset token first to validate it exists and is not expired
|
# Get session token first to validate it exists and is not expired
|
||||||
reset_token = await sql.get_reset_token(token)
|
session_data = await sql.get_session(token)
|
||||||
if not reset_token:
|
if not session_data:
|
||||||
return {"error": "Invalid or expired device addition token"}
|
return {"error": "Invalid or expired device addition token"}
|
||||||
|
|
||||||
# Check if token is expired (24 hours)
|
# Check if this is a device addition session (credential_id is None)
|
||||||
expiry_time = reset_token.created_at + timedelta(hours=24)
|
if session_data["credential_id"] is not None:
|
||||||
if datetime.now() > expiry_time:
|
return {"error": "Invalid device addition token"}
|
||||||
return {"error": "Device addition token has expired"}
|
|
||||||
|
|
||||||
# Delete the token (it's now used)
|
# Delete the token (it's now used)
|
||||||
await sql.delete_reset_token(token)
|
await sql.delete_session(token)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
|
@ -12,7 +12,7 @@ from uuid import UUID
|
|||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
|
|
||||||
from ..db.sql import User, get_user_by_id
|
from ..db.sql import User, get_user_by_id
|
||||||
from ..util.jwt import validate_session_token
|
from ..util.session import validate_session_token
|
||||||
|
|
||||||
COOKIE_NAME = "auth"
|
COOKIE_NAME = "auth"
|
||||||
COOKIE_MAX_AGE = 86400 # 24 hours
|
COOKIE_MAX_AGE = 86400 # 24 hours
|
||||||
@ -24,7 +24,7 @@ async def get_current_user(request: Request) -> User | None:
|
|||||||
if not session_token:
|
if not session_token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data:
|
if not token_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ async def validate_session_from_request(request: Request) -> dict | None:
|
|||||||
if not session_token:
|
if not session_token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return validate_session_token(session_token)
|
return await validate_session_token(session_token)
|
||||||
|
|
||||||
|
|
||||||
async def get_session_token_from_bearer(request: Request) -> str | None:
|
async def get_session_token_from_bearer(request: Request) -> str | None:
|
||||||
@ -91,7 +91,7 @@ async def get_user_from_cookie_string(cookie_header: str) -> UUID | None:
|
|||||||
if not session_token:
|
if not session_token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data:
|
if not token_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ async def is_device_addition_session(request: Request) -> bool:
|
|||||||
if not session_token:
|
if not session_token:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data:
|
if not token_data:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -117,7 +117,7 @@ async def get_device_addition_user_id(request: Request) -> UUID | None:
|
|||||||
if not session_token:
|
if not session_token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data or not token_data.get("device_addition"):
|
if not token_data or not token_data.get("device_addition"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ async def get_device_addition_user_id_from_cookie(cookie_header: str) -> UUID |
|
|||||||
if not session_token:
|
if not session_token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token_data = validate_session_token(session_token)
|
token_data = await validate_session_token(session_token)
|
||||||
if not token_data or not token_data.get("device_addition"):
|
if not token_data or not token_data.get("device_addition"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
|||||||
from ..db import sql
|
from ..db import sql
|
||||||
from ..db.sql import User
|
from ..db.sql import User
|
||||||
from ..sansio import Passkey
|
from ..sansio import Passkey
|
||||||
from ..util.jwt import create_session_token
|
from ..util.session import create_session_token, get_client_info_from_websocket
|
||||||
from .session import get_user_from_cookie_string
|
from .session import get_user_from_cookie_string
|
||||||
|
|
||||||
# Create a FastAPI subapp for WebSocket endpoints
|
# Create a FastAPI subapp for WebSocket endpoints
|
||||||
@ -69,7 +69,10 @@ async def websocket_register_new(ws: WebSocket, user_name: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create a session token for the new user
|
# Create a session token for the new user
|
||||||
session_token = create_session_token(user_id, credential.credential_id)
|
client_info = get_client_info_from_websocket(ws)
|
||||||
|
session_token = await create_session_token(
|
||||||
|
user_id, credential.credential_id, client_info
|
||||||
|
)
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
@ -248,8 +251,9 @@ async def websocket_authenticate(ws: WebSocket):
|
|||||||
await sql.login_user(stored_cred.user_id, stored_cred)
|
await sql.login_user(stored_cred.user_id, stored_cred)
|
||||||
|
|
||||||
# Create a session token for the authenticated user
|
# Create a session token for the authenticated user
|
||||||
session_token = create_session_token(
|
client_info = get_client_info_from_websocket(ws)
|
||||||
stored_cred.user_id, stored_cred.credential_id
|
session_token = await create_session_token(
|
||||||
|
stored_cred.user_id, stored_cred.credential_id, client_info
|
||||||
)
|
)
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
|
88
passkey/util/session.py
Normal file
88
passkey/util/session.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
Database session management for WebAuthn authentication.
|
||||||
|
|
||||||
|
This module provides session management using database tokens instead of JWT tokens.
|
||||||
|
Session tokens are stored in the database and validated on each request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from ..db import sql
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_info(request: Request) -> dict:
|
||||||
|
"""Extract client information from FastAPI request and return as dict."""
|
||||||
|
# Get client IP (handle X-Forwarded-For for proxies)
|
||||||
|
# Get user agent
|
||||||
|
return {
|
||||||
|
"client_ip": request.client.host if request.client else "",
|
||||||
|
"user_agent": request.headers.get("user-agent", "")[:500],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_info_from_websocket(ws) -> dict:
|
||||||
|
"""Extract client information from WebSocket connection and return as dict."""
|
||||||
|
# Get client IP from WebSocket
|
||||||
|
client_ip = None
|
||||||
|
if hasattr(ws, "client") and ws.client:
|
||||||
|
client_ip = ws.client.host
|
||||||
|
|
||||||
|
# Check for forwarded headers
|
||||||
|
if hasattr(ws, "headers"):
|
||||||
|
forwarded_for = ws.headers.get("x-forwarded-for")
|
||||||
|
if forwarded_for:
|
||||||
|
client_ip = forwarded_for.split(",")[0].strip()
|
||||||
|
|
||||||
|
# Get user agent from WebSocket headers
|
||||||
|
user_agent = None
|
||||||
|
if hasattr(ws, "headers"):
|
||||||
|
user_agent = ws.headers.get("user-agent")
|
||||||
|
# Truncate user agent if too long
|
||||||
|
if user_agent and len(user_agent) > 500: # Keep some margin
|
||||||
|
user_agent = user_agent[:500]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"client_ip": client_ip,
|
||||||
|
"user_agent": user_agent,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"connection_type": "websocket",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def create_session_token(
|
||||||
|
user_id: UUID, credential_id: bytes, info: dict | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Create a session token for a user."""
|
||||||
|
return await sql.create_session_by_credential_id(user_id, credential_id, None, info)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_session_token(token: str) -> Optional[dict]:
|
||||||
|
"""Validate a session token."""
|
||||||
|
session_data = await sql.get_session(token)
|
||||||
|
if not session_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": session_data["user_id"],
|
||||||
|
"credential_id": session_data["credential_id"],
|
||||||
|
"created_at": session_data["created_at"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_session_token(token: str) -> Optional[str]:
|
||||||
|
"""Refresh a session token."""
|
||||||
|
return await sql.refresh_session(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_session_token(token: str) -> None:
|
||||||
|
"""Delete a session token."""
|
||||||
|
await sql.delete_session(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def logout_session(token: str) -> None:
|
||||||
|
"""Log out a user by deleting their session token."""
|
||||||
|
await sql.delete_session(token)
|
Loading…
x
Reference in New Issue
Block a user