Database cleanup, base class, separated from FastAPI app.
This commit is contained in:
parent
00693c56fa
commit
c5733eefd6
@ -5,7 +5,8 @@ This module provides dataclasses and database abstractions for managing
|
|||||||
users, credentials, and sessions in a WebAuthn authentication system.
|
users, credentials, and sessions in a WebAuthn authentication system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -47,4 +48,108 @@ class Session:
|
|||||||
credential_uuid: UUID | None = None
|
credential_uuid: UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["User", "Credential", "Session"]
|
class DatabaseInterface(ABC):
|
||||||
|
"""Abstract base class defining the database interface.
|
||||||
|
|
||||||
|
This class defines the public API that database implementations should provide.
|
||||||
|
Implementations may use decorators like @with_session that modify method signatures
|
||||||
|
at runtime, so this interface focuses on the logical operations rather than
|
||||||
|
exact parameter matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def init_db(self) -> None:
|
||||||
|
"""Initialize database tables."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# User operations
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
||||||
|
"""Get user record by WebAuthn user UUID."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_user(self, user: User) -> None:
|
||||||
|
"""Create a new user."""
|
||||||
|
|
||||||
|
# Credential operations
|
||||||
|
@abstractmethod
|
||||||
|
async def create_credential(self, credential: Credential) -> None:
|
||||||
|
"""Store a credential for a user."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
|
||||||
|
"""Get credential by credential ID."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
|
||||||
|
"""Get all credential IDs for a user."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_credential(self, credential: Credential) -> None:
|
||||||
|
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
|
||||||
|
"""Delete a specific credential for a user."""
|
||||||
|
|
||||||
|
# Session operations
|
||||||
|
@abstractmethod
|
||||||
|
async def create_session(
|
||||||
|
self,
|
||||||
|
user_uuid: UUID,
|
||||||
|
key: bytes,
|
||||||
|
expires: datetime,
|
||||||
|
info: dict,
|
||||||
|
credential_uuid: UUID | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Create a new session."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_session(self, key: bytes) -> Session | None:
|
||||||
|
"""Get session by key."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_session(self, key: bytes) -> None:
|
||||||
|
"""Delete session by key."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_session(
|
||||||
|
self, key: bytes, expires: datetime, info: dict
|
||||||
|
) -> Session | None:
|
||||||
|
"""Update session expiry and info."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Called periodically to clean up expired records."""
|
||||||
|
|
||||||
|
# Combined operations
|
||||||
|
@abstractmethod
|
||||||
|
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
||||||
|
"""Update user and credential timestamps after successful login."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_user_and_credential(
|
||||||
|
self, user: User, credential: Credential
|
||||||
|
) -> None:
|
||||||
|
"""Create a new user and their first credential in a transaction."""
|
||||||
|
|
||||||
|
|
||||||
|
# Global DB instance
|
||||||
|
database_instance: DatabaseInterface | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def database() -> DatabaseInterface:
|
||||||
|
"""Get the global database instance."""
|
||||||
|
if database_instance is None:
|
||||||
|
raise RuntimeError("Database not initialized. Call e.g. db.sql.init() first.")
|
||||||
|
return database_instance
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"User",
|
||||||
|
"Credential",
|
||||||
|
"Session",
|
||||||
|
"DatabaseInterface",
|
||||||
|
"database_instance",
|
||||||
|
"database",
|
||||||
|
]
|
||||||
|
@ -5,8 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
|
|||||||
for managing users and credentials in a WebAuthn authentication system.
|
for managing users and credentials in a WebAuthn authentication system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import wraps
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
@ -23,24 +23,15 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
|
|||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from . import Credential, Session, User
|
from . import Credential, DatabaseInterface, Session, User
|
||||||
|
|
||||||
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
||||||
|
|
||||||
|
|
||||||
def with_session(func):
|
def init(*args, **kwargs):
|
||||||
"""Decorator that provides a database session with transaction to the method."""
|
from .. import db
|
||||||
|
|
||||||
@wraps(func)
|
db.database_instance = DB()
|
||||||
async def wrapper(self, *args, **kwargs):
|
|
||||||
async with self.async_session_factory() as session:
|
|
||||||
async with session.begin():
|
|
||||||
result = await func(self, session, *args, **kwargs)
|
|
||||||
await session.flush()
|
|
||||||
return result
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
# SQLAlchemy Models
|
# SQLAlchemy Models
|
||||||
@ -100,7 +91,7 @@ class SessionModel(Base):
|
|||||||
user: Mapped["UserModel"] = relationship("UserModel")
|
user: Mapped["UserModel"] = relationship("UserModel")
|
||||||
|
|
||||||
|
|
||||||
class DB:
|
class DB(DatabaseInterface):
|
||||||
"""Database class that handles its own connections."""
|
"""Database class that handles its own connections."""
|
||||||
|
|
||||||
def __init__(self, db_path: str = DB_PATH):
|
def __init__(self, db_path: str = DB_PATH):
|
||||||
@ -110,14 +101,22 @@ class DB:
|
|||||||
self.engine, expire_on_commit=False
|
self.engine, expire_on_commit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def session(self):
|
||||||
|
"""Async context manager that provides a database session with transaction."""
|
||||||
|
async with self.async_session_factory() as session:
|
||||||
|
async with session.begin():
|
||||||
|
yield session
|
||||||
|
await session.flush()
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def init_db(self) -> None:
|
async def init_db(self) -> None:
|
||||||
"""Initialize database tables."""
|
"""Initialize database tables."""
|
||||||
async with self.engine.begin() as conn:
|
async with self.engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
@with_session
|
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
||||||
async def get_user_by_user_uuid(self, session, user_uuid: UUID) -> User:
|
async with self.session() as session:
|
||||||
"""Get user record by WebAuthn user UUID."""
|
|
||||||
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
user_model = result.scalar_one_or_none()
|
user_model = result.scalar_one_or_none()
|
||||||
@ -132,9 +131,8 @@ class DB:
|
|||||||
)
|
)
|
||||||
raise ValueError("User not found")
|
raise ValueError("User not found")
|
||||||
|
|
||||||
@with_session
|
async def create_user(self, user: User) -> None:
|
||||||
async def create_user(self, session, user: User) -> None:
|
async with self.session() as session:
|
||||||
"""Create a new user."""
|
|
||||||
user_model = UserModel(
|
user_model = UserModel(
|
||||||
user_uuid=user.user_uuid.bytes,
|
user_uuid=user.user_uuid.bytes,
|
||||||
user_name=user.user_name,
|
user_name=user.user_name,
|
||||||
@ -144,9 +142,8 @@ class DB:
|
|||||||
)
|
)
|
||||||
session.add(user_model)
|
session.add(user_model)
|
||||||
|
|
||||||
@with_session
|
async def create_credential(self, credential: Credential) -> None:
|
||||||
async def create_credential(self, session, credential: Credential) -> None:
|
async with self.session() as session:
|
||||||
"""Store a credential for a user."""
|
|
||||||
credential_model = CredentialModel(
|
credential_model = CredentialModel(
|
||||||
uuid=credential.uuid.bytes,
|
uuid=credential.uuid.bytes,
|
||||||
credential_id=credential.credential_id,
|
credential_id=credential.credential_id,
|
||||||
@ -160,9 +157,8 @@ class DB:
|
|||||||
)
|
)
|
||||||
session.add(credential_model)
|
session.add(credential_model)
|
||||||
|
|
||||||
@with_session
|
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
|
||||||
async def get_credential_by_id(self, session, credential_id: bytes) -> Credential:
|
async with self.session() as session:
|
||||||
"""Get credential by credential ID."""
|
|
||||||
stmt = select(CredentialModel).where(
|
stmt = select(CredentialModel).where(
|
||||||
CredentialModel.credential_id == credential_id
|
CredentialModel.credential_id == credential_id
|
||||||
)
|
)
|
||||||
@ -183,20 +179,16 @@ class DB:
|
|||||||
last_verified=credential_model.last_verified,
|
last_verified=credential_model.last_verified,
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_session
|
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
|
||||||
async def get_credentials_by_user_uuid(
|
async with self.session() as session:
|
||||||
self, session, user_uuid: UUID
|
|
||||||
) -> list[bytes]:
|
|
||||||
"""Get all credential IDs for a user."""
|
|
||||||
stmt = select(CredentialModel.credential_id).where(
|
stmt = select(CredentialModel.credential_id).where(
|
||||||
CredentialModel.user_uuid == user_uuid.bytes
|
CredentialModel.user_uuid == user_uuid.bytes
|
||||||
)
|
)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [row[0] for row in result.fetchall()]
|
return [row[0] for row in result.fetchall()]
|
||||||
|
|
||||||
@with_session
|
async def update_credential(self, credential: Credential) -> None:
|
||||||
async def update_credential(self, session, credential: Credential) -> None:
|
async with self.session() as session:
|
||||||
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(CredentialModel)
|
update(CredentialModel)
|
||||||
.where(CredentialModel.credential_id == credential.credential_id)
|
.where(CredentialModel.credential_id == credential.credential_id)
|
||||||
@ -209,9 +201,8 @@ class DB:
|
|||||||
)
|
)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
|
||||||
@with_session
|
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
||||||
async def login(self, session, user_uuid: UUID, credential: Credential) -> None:
|
async with self.session() as session:
|
||||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
|
||||||
# Update credential
|
# Update credential
|
||||||
stmt = (
|
stmt = (
|
||||||
update(CredentialModel)
|
update(CredentialModel)
|
||||||
@ -233,11 +224,10 @@ class DB:
|
|||||||
)
|
)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
|
||||||
@with_session
|
|
||||||
async def create_user_and_credential(
|
async def create_user_and_credential(
|
||||||
self, session, user: User, credential: Credential
|
self, user: User, credential: Credential
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a new user and their first credential in a single transaction."""
|
async with self.session() as session:
|
||||||
# Set visits to 1 for the new user since they're creating their first session
|
# Set visits to 1 for the new user since they're creating their first session
|
||||||
user.visits = 1
|
user.visits = 1
|
||||||
|
|
||||||
@ -265,9 +255,8 @@ class DB:
|
|||||||
)
|
)
|
||||||
session.add(credential_model)
|
session.add(credential_model)
|
||||||
|
|
||||||
@with_session
|
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
|
||||||
async def delete_credential(self, session, uuid: UUID, user_uuid: UUID) -> None:
|
async with self.session() as session:
|
||||||
"""Delete a credential by its ID."""
|
|
||||||
stmt = (
|
stmt = (
|
||||||
delete(CredentialModel)
|
delete(CredentialModel)
|
||||||
.where(CredentialModel.uuid == uuid.bytes)
|
.where(CredentialModel.uuid == uuid.bytes)
|
||||||
@ -275,17 +264,15 @@ class DB:
|
|||||||
)
|
)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
|
||||||
@with_session
|
|
||||||
async def create_session(
|
async def create_session(
|
||||||
self,
|
self,
|
||||||
session,
|
|
||||||
user_uuid: UUID,
|
user_uuid: UUID,
|
||||||
key: bytes,
|
key: bytes,
|
||||||
expires: datetime,
|
expires: datetime,
|
||||||
info: dict,
|
info: dict,
|
||||||
credential_uuid: UUID | None = None,
|
credential_uuid: UUID | None = None,
|
||||||
) -> bytes:
|
) -> None:
|
||||||
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
|
async with self.session() as session:
|
||||||
session_model = SessionModel(
|
session_model = SessionModel(
|
||||||
key=key,
|
key=key,
|
||||||
user_uuid=user_uuid.bytes,
|
user_uuid=user_uuid.bytes,
|
||||||
@ -294,11 +281,9 @@ class DB:
|
|||||||
info=info,
|
info=info,
|
||||||
)
|
)
|
||||||
session.add(session_model)
|
session.add(session_model)
|
||||||
return key
|
|
||||||
|
|
||||||
@with_session
|
async def get_session(self, key: bytes) -> Session | None:
|
||||||
async def get_session(self, session, key: bytes) -> Session | None:
|
async with self.session() as session:
|
||||||
"""Get session by 16-byte key."""
|
|
||||||
stmt = select(SessionModel).where(SessionModel.key == key)
|
stmt = select(SessionModel).where(SessionModel.key == key)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
session_model = result.scalar_one_or_none()
|
session_model = result.scalar_one_or_none()
|
||||||
@ -315,25 +300,20 @@ class DB:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@with_session
|
async def delete_session(self, key: bytes) -> None:
|
||||||
async def delete_session(self, session, key: bytes) -> None:
|
async with self.session() as session:
|
||||||
"""Delete a session by 16-byte key."""
|
|
||||||
await session.execute(delete(SessionModel).where(SessionModel.key == key))
|
await session.execute(delete(SessionModel).where(SessionModel.key == key))
|
||||||
|
|
||||||
@with_session
|
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
|
||||||
async def update_session(
|
async with self.session() as session:
|
||||||
self, session, key: bytes, expires: datetime, info: dict
|
|
||||||
) -> None:
|
|
||||||
"""Update session expiration time and/or info."""
|
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(SessionModel)
|
update(SessionModel)
|
||||||
.where(SessionModel.key == key)
|
.where(SessionModel.key == key)
|
||||||
.values(expires=expires, info=info)
|
.values(expires=expires, info=info)
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_session
|
async def cleanup(self) -> None:
|
||||||
async def cleanup_expired_sessions(self, session) -> None:
|
async with self.session() as session:
|
||||||
"""Remove expired sessions."""
|
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
@ -14,7 +14,7 @@ from fastapi import Cookie, Depends, FastAPI, Request, Response
|
|||||||
from fastapi.security import HTTPBearer
|
from fastapi.security import HTTPBearer
|
||||||
|
|
||||||
from .. import aaguid
|
from .. import aaguid
|
||||||
from ..db import sql
|
from ..db import database
|
||||||
from ..util.tokens import session_key
|
from ..util.tokens import session_key
|
||||||
from . import session
|
from . import session
|
||||||
|
|
||||||
@ -38,19 +38,19 @@ def register_api_routes(app: FastAPI):
|
|||||||
return {"status": "error", "valid": False}
|
return {"status": "error", "valid": False}
|
||||||
|
|
||||||
@app.post("/auth/user-info")
|
@app.post("/auth/user-info")
|
||||||
async def api_user_info(request: Request, response: Response, auth=Cookie(None)):
|
async def api_user_info(auth=Cookie(None)):
|
||||||
"""Get full user information for the authenticated user."""
|
"""Get full user information for the authenticated user."""
|
||||||
try:
|
try:
|
||||||
s = await session.get_session(auth, reset_allowed=True)
|
s = await session.get_session(auth, reset_allowed=True)
|
||||||
u = await sql.get_user_by_uuid(s.user_uuid)
|
u = await database().get_user_by_user_uuid(s.user_uuid)
|
||||||
# Get all credentials for the user
|
# Get all credentials for the user
|
||||||
credential_ids = await sql.get_user_credentials(s.user_uuid)
|
credential_ids = await database().get_credentials_by_user_uuid(s.user_uuid)
|
||||||
|
|
||||||
credentials = []
|
credentials = []
|
||||||
user_aaguids = set()
|
user_aaguids = set()
|
||||||
|
|
||||||
for cred_id in credential_ids:
|
for cred_id in credential_ids:
|
||||||
c = await sql.get_credential_by_id(cred_id)
|
c = await database().get_credential_by_id(cred_id)
|
||||||
|
|
||||||
# Convert AAGUID to string format
|
# Convert AAGUID to string format
|
||||||
aaguid_str = str(c.aaguid)
|
aaguid_str = str(c.aaguid)
|
||||||
@ -102,14 +102,12 @@ def register_api_routes(app: FastAPI):
|
|||||||
"""Log out the current user by clearing the session cookie and deleting from database."""
|
"""Log out the current user by clearing the session cookie and deleting from database."""
|
||||||
if not auth:
|
if not auth:
|
||||||
return {"status": "success", "message": "Already logged out"}
|
return {"status": "success", "message": "Already logged out"}
|
||||||
await sql.delete_session(session_key(auth))
|
await database().delete_session(session_key(auth))
|
||||||
response.delete_cookie("auth")
|
response.delete_cookie("auth")
|
||||||
return {"status": "success", "message": "Logged out successfully"}
|
return {"status": "success", "message": "Logged out successfully"}
|
||||||
|
|
||||||
@app.post("/auth/set-session")
|
@app.post("/auth/set-session")
|
||||||
async def api_set_session(
|
async def api_set_session(response: Response, auth=Depends(bearer_auth)):
|
||||||
request: Request, response: Response, auth=Depends(bearer_auth)
|
|
||||||
):
|
|
||||||
"""Set session cookie from Authorization header. Fetched after login by WebSocket."""
|
"""Set session cookie from Authorization header. Fetched after login by WebSocket."""
|
||||||
try:
|
try:
|
||||||
user = await session.get_session(auth.credentials)
|
user = await session.get_session(auth.credentials)
|
||||||
|
@ -20,7 +20,7 @@ from fastapi.responses import (
|
|||||||
)
|
)
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from ..db import sql
|
from ..db import close_db, init_db
|
||||||
from . import session, ws
|
from . import session, ws
|
||||||
from .api import register_api_routes
|
from .api import register_api_routes
|
||||||
from .reset import register_reset_routes
|
from .reset import register_reset_routes
|
||||||
@ -30,8 +30,9 @@ STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
await sql.init_database()
|
await init_db()
|
||||||
yield
|
yield
|
||||||
|
await close_db()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
@ -3,7 +3,7 @@ import logging
|
|||||||
from fastapi import Cookie, HTTPException, Request
|
from fastapi import Cookie, HTTPException, Request
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
from ..db import sql
|
from ..db import database
|
||||||
from ..util import passphrase, tokens
|
from ..util import passphrase, tokens
|
||||||
from . import session
|
from . import session
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ def register_reset_routes(app):
|
|||||||
|
|
||||||
# Generate a human-readable token
|
# Generate a human-readable token
|
||||||
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
|
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
|
||||||
await sql.create_session(
|
await database().create_session(
|
||||||
user_uuid=s.user_uuid,
|
user_uuid=s.user_uuid,
|
||||||
key=tokens.reset_key(token),
|
key=tokens.reset_key(token),
|
||||||
expires=session.expires(),
|
expires=session.expires(),
|
||||||
@ -56,7 +56,7 @@ def register_reset_routes(app):
|
|||||||
try:
|
try:
|
||||||
# Get session token to validate it exists and get user_id
|
# Get session token to validate it exists and get user_id
|
||||||
key = tokens.reset_key(reset_token)
|
key = tokens.reset_key(reset_token)
|
||||||
sess = await sql.get_session(key)
|
sess = await database().get_session(key)
|
||||||
if not sess:
|
if not sess:
|
||||||
raise ValueError("Invalid or expired registration token")
|
raise ValueError("Invalid or expired registration token")
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import Request, Response, WebSocket
|
from fastapi import Request, Response, WebSocket
|
||||||
|
|
||||||
from ..db import Session, sql
|
from ..db import Session, database
|
||||||
from ..util import passphrase
|
from ..util import passphrase
|
||||||
from ..util.tokens import create_token, reset_key, session_key
|
from ..util.tokens import create_token, reset_key, session_key
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ def infodict(request: Request | WebSocket, type: str) -> dict:
|
|||||||
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
|
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
|
||||||
"""Create a new session and return a session token."""
|
"""Create a new session and return a session token."""
|
||||||
token = create_token()
|
token = create_token()
|
||||||
await sql.create_session(
|
await database().create_session(
|
||||||
user_uuid=user_uuid,
|
user_uuid=user_uuid,
|
||||||
key=session_key(token),
|
key=session_key(token),
|
||||||
expires=datetime.now() + EXPIRES,
|
expires=datetime.now() + EXPIRES,
|
||||||
@ -56,7 +56,7 @@ async def get_session(token: str, reset_allowed=False) -> Session:
|
|||||||
else:
|
else:
|
||||||
key = session_key(token)
|
key = session_key(token)
|
||||||
|
|
||||||
session = await sql.get_session(key)
|
session = await database().get_session(key)
|
||||||
if not session:
|
if not session:
|
||||||
raise ValueError("Invalid or expired session token")
|
raise ValueError("Invalid or expired session token")
|
||||||
return session
|
return session
|
||||||
@ -65,7 +65,9 @@ async def get_session(token: str, reset_allowed=False) -> Session:
|
|||||||
async def refresh_session_token(token: str):
|
async def refresh_session_token(token: str):
|
||||||
"""Refresh a session extending its expiry."""
|
"""Refresh a session extending its expiry."""
|
||||||
# Get the current session
|
# Get the current session
|
||||||
s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {})
|
s = await database().update_session(
|
||||||
|
session_key(token), datetime.now() + EXPIRES, {}
|
||||||
|
)
|
||||||
|
|
||||||
if not s:
|
if not s:
|
||||||
raise ValueError("Session not found or expired")
|
raise ValueError("Session not found or expired")
|
||||||
@ -86,4 +88,4 @@ def set_session_cookie(response: Response, token: str) -> None:
|
|||||||
async def delete_credential(credential_uuid: UUID, auth: str):
|
async def delete_credential(credential_uuid: UUID, auth: str):
|
||||||
"""Delete a specific credential for the current user."""
|
"""Delete a specific credential for the current user."""
|
||||||
s = await get_session(auth)
|
s = await get_session(auth)
|
||||||
await sql.delete_credential(credential_uuid, s.user_uuid)
|
await database().delete_credential(credential_uuid, s.user_uuid)
|
||||||
|
@ -18,7 +18,7 @@ from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
|||||||
|
|
||||||
from passkey.fastapi import session
|
from passkey.fastapi import session
|
||||||
|
|
||||||
from ..db import User, sql
|
from ..db import User, database
|
||||||
from ..sansio import Passkey
|
from ..sansio import Passkey
|
||||||
from ..util.tokens import create_token, session_key
|
from ..util.tokens import create_token, session_key
|
||||||
from .session import create_session, infodict
|
from .session import create_session, infodict
|
||||||
@ -65,13 +65,13 @@ async def websocket_register_new(
|
|||||||
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
|
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
|
||||||
|
|
||||||
# Store the user and credential in the database
|
# Store the user and credential in the database
|
||||||
await sql.create_user_and_credential(
|
await database().create_user_and_credential(
|
||||||
User(user_uuid, user_name, created_at=datetime.now()),
|
User(user_uuid, user_name, created_at=datetime.now()),
|
||||||
credential,
|
credential,
|
||||||
)
|
)
|
||||||
# Create a session token for the new user
|
# Create a session token for the new user
|
||||||
token = create_token()
|
token = create_token()
|
||||||
await sql.create_session(
|
await database().create_session(
|
||||||
user_uuid=user_uuid,
|
user_uuid=user_uuid,
|
||||||
key=session_key(token),
|
key=session_key(token),
|
||||||
expires=datetime.now() + session.EXPIRES,
|
expires=datetime.now() + session.EXPIRES,
|
||||||
@ -106,16 +106,16 @@ async def websocket_register_add(ws: WebSocket, auth=Cookie(None)):
|
|||||||
user_uuid = s.user_uuid
|
user_uuid = s.user_uuid
|
||||||
|
|
||||||
# Get user information to get the user_name
|
# Get user information to get the user_name
|
||||||
user = await sql.get_user_by_uuid(user_uuid)
|
user = await database().get_user_by_user_uuid(user_uuid)
|
||||||
user_name = user.user_name
|
user_name = user.user_name
|
||||||
challenge_ids = await sql.get_user_credentials(user_uuid)
|
challenge_ids = await database().get_credentials_by_user_uuid(user_uuid)
|
||||||
|
|
||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(
|
credential = await register_chat(
|
||||||
ws, user_uuid, user_name, challenge_ids, origin
|
ws, user_uuid, user_name, challenge_ids, origin
|
||||||
)
|
)
|
||||||
# Store the new credential in the database
|
# Store the new credential in the database
|
||||||
await sql.create_credential_for_user(credential)
|
await database().create_credential(credential)
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
@ -144,11 +144,11 @@ async def websocket_authenticate(ws: WebSocket):
|
|||||||
# Wait for the client to use his authenticator to authenticate
|
# Wait for the client to use his authenticator to authenticate
|
||||||
credential = passkey.auth_parse(await ws.receive_json())
|
credential = passkey.auth_parse(await ws.receive_json())
|
||||||
# Fetch from the database by credential ID
|
# Fetch from the database by credential ID
|
||||||
stored_cred = await sql.get_credential_by_id(credential.raw_id)
|
stored_cred = await database().get_credential_by_id(credential.raw_id)
|
||||||
# Verify the credential matches the stored data
|
# Verify the credential matches the stored data
|
||||||
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
||||||
# Update both credential and user's last_seen timestamp
|
# Update both credential and user's last_seen timestamp
|
||||||
await sql.login_user(stored_cred.user_uuid, stored_cred)
|
await database().login(stored_cred.user_uuid, stored_cred)
|
||||||
|
|
||||||
# Create a session token for the authenticated user
|
# Create a session token for the authenticated user
|
||||||
assert stored_cred.uuid is not None
|
assert stored_cred.uuid is not None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user