Database reworked simpler, JWTs replaced by sessions table and random tokens. Accessing Add device link is currently broken.
This commit is contained in:
@@ -21,7 +21,7 @@ from sqlalchemy import (
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.dialects.sqlite import BLOB
|
||||
from sqlalchemy.dialects.sqlite import BLOB, JSON
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
@@ -52,8 +52,8 @@ class UserModel(Base):
|
||||
|
||||
class CredentialModel(Base):
|
||||
__tablename__ = "credentials"
|
||||
|
||||
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), primary_key=True)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True)
|
||||
user_id: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||
)
|
||||
@@ -68,14 +68,18 @@ class CredentialModel(Base):
|
||||
user: Mapped["UserModel"] = relationship("UserModel", back_populates="credentials")
|
||||
|
||||
|
||||
class ResetTokenModel(Base):
|
||||
__tablename__ = "reset_tokens"
|
||||
class SessionModel(Base):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
token: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
token: Mapped[str] = mapped_column(String(32), primary_key=True)
|
||||
user_id: Mapped[bytes] = mapped_column(
|
||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
||||
)
|
||||
credential_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("credentials.id", ondelete="SET NULL")
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
|
||||
# Relationship to user
|
||||
user: Mapped["UserModel"] = relationship("UserModel")
|
||||
@@ -90,13 +94,6 @@ class User:
|
||||
visits: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetToken:
|
||||
token: str
|
||||
user_id: UUID
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# Global engine and session factory
|
||||
engine = create_async_engine(DB_PATH, echo=False)
|
||||
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
@@ -243,45 +240,6 @@ class DB:
|
||||
await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
|
||||
async def create_reset_token(self, user_id: UUID, token: str | None = None) -> str:
|
||||
"""Create a new reset token for a user."""
|
||||
if token is None:
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
reset_token_model = ResetTokenModel(
|
||||
token=token,
|
||||
user_id=user_id.bytes,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self.session.add(reset_token_model)
|
||||
await self.session.flush()
|
||||
return token
|
||||
|
||||
async def get_reset_token(self, token: str) -> ResetToken | None:
|
||||
"""Get reset token by token string."""
|
||||
stmt = select(ResetTokenModel).where(ResetTokenModel.token == token)
|
||||
result = await self.session.execute(stmt)
|
||||
token_model = result.scalar_one_or_none()
|
||||
|
||||
if token_model:
|
||||
return ResetToken(
|
||||
token=token_model.token,
|
||||
user_id=UUID(bytes=token_model.user_id),
|
||||
created_at=token_model.created_at,
|
||||
)
|
||||
return None
|
||||
|
||||
async def delete_reset_token(self, token: str) -> None:
|
||||
"""Delete a reset token (used after successful credential addition)."""
|
||||
stmt = delete(ResetTokenModel).where(ResetTokenModel.token == token)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def cleanup_expired_tokens(self) -> None:
|
||||
"""Remove expired reset tokens (older than 24 hours)."""
|
||||
expiry_time = datetime.now() - timedelta(hours=24)
|
||||
stmt = delete(ResetTokenModel).where(ResetTokenModel.created_at < expiry_time)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def get_user_by_username(self, user_name: str) -> User | None:
|
||||
"""Get user by username."""
|
||||
stmt = select(UserModel).where(UserModel.user_name == user_name)
|
||||
@@ -298,6 +256,98 @@ class DB:
|
||||
)
|
||||
return None
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
user_id: UUID,
|
||||
credential_id: int | None = None,
|
||||
token: str | None = None,
|
||||
info: dict | None = None,
|
||||
) -> str:
|
||||
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
|
||||
if token is None:
|
||||
token = secrets.token_urlsafe(12)
|
||||
|
||||
session_model = SessionModel(
|
||||
token=token,
|
||||
user_id=user_id.bytes,
|
||||
credential_id=credential_id,
|
||||
created_at=datetime.now(),
|
||||
info=info,
|
||||
)
|
||||
self.session.add(session_model)
|
||||
await self.session.flush()
|
||||
return token
|
||||
|
||||
async def create_session_by_credential_id(
|
||||
self,
|
||||
user_id: UUID,
|
||||
credential_id: bytes | None,
|
||||
token: str | None = None,
|
||||
info: dict | None = None,
|
||||
) -> str:
|
||||
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
|
||||
if credential_id is None:
|
||||
return await self.create_session(user_id, None, token, info)
|
||||
|
||||
# Get the database ID from the credential
|
||||
stmt = select(CredentialModel.id).where(
|
||||
CredentialModel.credential_id == credential_id
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
db_credential_id = result.scalar_one()
|
||||
|
||||
return await self.create_session(user_id, db_credential_id, token, info)
|
||||
|
||||
async def get_session(self, token: str) -> dict | None:
|
||||
"""Get session by token string."""
|
||||
stmt = select(SessionModel).where(SessionModel.token == token)
|
||||
result = await self.session.execute(stmt)
|
||||
session_model = result.scalar_one_or_none()
|
||||
|
||||
if session_model:
|
||||
# Check if session is expired (24 hours)
|
||||
expiry_time = session_model.created_at + timedelta(hours=24)
|
||||
if datetime.now() > expiry_time:
|
||||
# Clean up expired session
|
||||
await self.delete_session(token)
|
||||
return None
|
||||
|
||||
return {
|
||||
"token": session_model.token,
|
||||
"user_id": UUID(bytes=session_model.user_id),
|
||||
"credential_id": session_model.credential_id,
|
||||
"created_at": session_model.created_at,
|
||||
"info": session_model.info or {},
|
||||
}
|
||||
return None
|
||||
|
||||
async def delete_session(self, token: str) -> None:
|
||||
"""Delete a session by token."""
|
||||
stmt = delete(SessionModel).where(SessionModel.token == token)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def cleanup_expired_sessions(self) -> None:
|
||||
"""Remove expired sessions (older than 24 hours)."""
|
||||
expiry_time = datetime.now() - timedelta(hours=24)
|
||||
stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def refresh_session(self, token: str) -> str | None:
|
||||
"""Refresh a session by updating its created_at timestamp."""
|
||||
session_data = await self.get_session(token)
|
||||
if not session_data:
|
||||
return None
|
||||
|
||||
# Delete old session
|
||||
await self.delete_session(token)
|
||||
|
||||
# Create new session with same user and credential
|
||||
return await self.create_session(
|
||||
session_data["user_id"],
|
||||
session_data["credential_id"],
|
||||
info=session_data["info"],
|
||||
)
|
||||
|
||||
|
||||
# Standalone functions that handle database connections internally
|
||||
async def init_database() -> None:
|
||||
@@ -358,31 +408,55 @@ async def create_new_session(user_id: UUID, credential: StoredCredential) -> Non
|
||||
await db.create_new_session(user_id, credential)
|
||||
|
||||
|
||||
async def create_reset_token(user_id: UUID, token: str | None = None) -> str:
|
||||
"""Create a reset token for a user."""
|
||||
async with connect() as db:
|
||||
return await db.create_reset_token(user_id, token)
|
||||
|
||||
|
||||
async def get_reset_token(token: str) -> ResetToken | None:
|
||||
"""Get reset token by token string."""
|
||||
async with connect() as db:
|
||||
return await db.get_reset_token(token)
|
||||
|
||||
|
||||
async def delete_reset_token(token: str) -> None:
|
||||
"""Delete a reset token (used after successful credential addition)."""
|
||||
async with connect() as db:
|
||||
await db.delete_reset_token(token)
|
||||
|
||||
|
||||
async def cleanup_expired_tokens() -> None:
|
||||
"""Remove expired reset tokens (older than 24 hours)."""
|
||||
async with connect() as db:
|
||||
await db.cleanup_expired_tokens()
|
||||
|
||||
|
||||
async def get_user_by_username(user_name: str) -> User | None:
|
||||
"""Get user by username."""
|
||||
async with connect() as db:
|
||||
return await db.get_user_by_username(user_name)
|
||||
|
||||
|
||||
async def create_session(
|
||||
user_id: UUID,
|
||||
credential_id: int | None = None,
|
||||
token: str | None = None,
|
||||
info: dict | None = None,
|
||||
) -> str:
|
||||
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
|
||||
async with connect() as db:
|
||||
return await db.create_session(user_id, credential_id, token, info)
|
||||
|
||||
|
||||
async def create_session_by_credential_id(
|
||||
user_id: UUID,
|
||||
credential_id: bytes | None,
|
||||
token: str | None = None,
|
||||
info: dict | None = None,
|
||||
) -> str:
|
||||
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
|
||||
async with connect() as db:
|
||||
return await db.create_session_by_credential_id(
|
||||
user_id, credential_id, token, info
|
||||
)
|
||||
|
||||
|
||||
async def get_session(token: str) -> dict | None:
|
||||
"""Get session by token string."""
|
||||
async with connect() as db:
|
||||
return await db.get_session(token)
|
||||
|
||||
|
||||
async def delete_session(token: str) -> None:
|
||||
"""Delete a session by token."""
|
||||
async with connect() as db:
|
||||
await db.delete_session(token)
|
||||
|
||||
|
||||
async def cleanup_expired_sessions() -> None:
|
||||
"""Remove expired sessions (older than 24 hours)."""
|
||||
async with connect() as db:
|
||||
await db.cleanup_expired_sessions()
|
||||
|
||||
|
||||
async def refresh_session(token: str) -> str | None:
|
||||
"""Refresh a session by updating its created_at timestamp."""
|
||||
async with connect() as db:
|
||||
return await db.refresh_session(token)
|
||||
|
||||
Reference in New Issue
Block a user