Major cleanup and refactoring of the backend (frontend not fully updated).
This commit is contained in:
parent
0cfa622bf1
commit
c5e5fe23e3
@ -22,7 +22,12 @@ import AddCredentialView from '@/components/AddCredentialView.vue'
|
|||||||
const store = useAuthStore()
|
const store = useAuthStore()
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
// Check for device addition session first
|
// Was an error message passed in the URL?
|
||||||
|
const message = location.hash.substring(1)
|
||||||
|
if (message) {
|
||||||
|
store.showMessage(decodeURIComponent(message), 'error')
|
||||||
|
history.replaceState(null, '', location.pathname)
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
await store.loadUserInfo()
|
await store.loadUserInfo()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import { useAuthStore } from '@/stores/auth'
|
import { useAuthStore } from '@/stores/auth'
|
||||||
import { registerWithSession } from '@/utils/passkey'
|
import { registerCredential } from '@/utils/passkey'
|
||||||
import { ref, onMounted } from 'vue'
|
import { ref, onMounted } from 'vue'
|
||||||
|
|
||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
@ -25,9 +25,7 @@ const hasDeviceSession = ref(false)
|
|||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
try {
|
try {
|
||||||
// Check if we have a device addition session
|
// Check if we have a device addition session
|
||||||
const response = await fetch('/auth/device-session-check', {
|
const response = await fetch('/auth/device-session-check')
|
||||||
credentials: 'include'
|
|
||||||
})
|
|
||||||
const data = await response.json()
|
const data = await response.json()
|
||||||
|
|
||||||
if (data.device_addition_session) {
|
if (data.device_addition_session) {
|
||||||
@ -50,7 +48,7 @@ function register() {
|
|||||||
|
|
||||||
authStore.isLoading = true
|
authStore.isLoading = true
|
||||||
authStore.showMessage('Starting registration...', 'info')
|
authStore.showMessage('Starting registration...', 'info')
|
||||||
registerWithSession().finally(() => {
|
registerCredential().finally(() => {
|
||||||
authStore.isLoading = false
|
authStore.isLoading = false
|
||||||
}).then(() => {
|
}).then(() => {
|
||||||
authStore.showMessage('Passkey registered successfully!', 'success', 2000)
|
authStore.showMessage('Passkey registered successfully!', 'success', 2000)
|
||||||
|
@ -46,7 +46,7 @@ const copyLink = async (event) => {
|
|||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/auth/create-device-link', { method: 'POST' })
|
const response = await fetch('/auth/create-link', { method: 'POST' })
|
||||||
const result = await response.json()
|
const result = await response.json()
|
||||||
if (result.error) throw new Error(result.error)
|
if (result.error) throw new Error(result.error)
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ onMounted(async () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to fetch device link:', error)
|
console.error('Failed to create link:', error)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
@ -33,10 +33,7 @@ export const useAuthStore = defineStore('auth', {
|
|||||||
async setSessionCookie(sessionToken) {
|
async setSessionCookie(sessionToken) {
|
||||||
const response = await fetch('/auth/set-session', {
|
const response = await fetch('/auth/set-session', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {'Authorization': `Bearer ${sessionToken}`},
|
||||||
'Authorization': `Bearer ${sessionToken}`,
|
|
||||||
'Content-Type': 'application/json'
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
const result = await response.json()
|
const result = await response.json()
|
||||||
if (result.error) {
|
if (result.error) {
|
||||||
|
@ -22,10 +22,7 @@ export async function registerCredential() {
|
|||||||
return register('/auth/ws/add_credential')
|
return register('/auth/ws/add_credential')
|
||||||
}
|
}
|
||||||
export async function registerWithToken(token) {
|
export async function registerWithToken(token) {
|
||||||
return register('/auth/ws/add_device_credential', { token })
|
return register('/auth/ws/add_credential', { token })
|
||||||
}
|
|
||||||
export async function registerWithSession() {
|
|
||||||
return register('/auth/ws/add_device_credential_session')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function authenticateUser() {
|
export async function authenticateUser() {
|
||||||
|
50
passkey/db/__init__.py
Normal file
50
passkey/db/__init__.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Database module for WebAuthn passkey authentication.
|
||||||
|
|
||||||
|
This module provides dataclasses and database abstractions for managing
|
||||||
|
users, credentials, and sessions in a WebAuthn authentication system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class User:
|
||||||
|
"""User data structure."""
|
||||||
|
|
||||||
|
user_uuid: UUID
|
||||||
|
user_name: str
|
||||||
|
created_at: datetime | None = None
|
||||||
|
last_seen: datetime | None = None
|
||||||
|
visits: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Credential:
|
||||||
|
"""Credential data structure."""
|
||||||
|
|
||||||
|
uuid: UUID
|
||||||
|
credential_id: bytes
|
||||||
|
user_uuid: UUID
|
||||||
|
aaguid: UUID
|
||||||
|
public_key: bytes
|
||||||
|
sign_count: int
|
||||||
|
created_at: datetime
|
||||||
|
last_used: datetime | None = None
|
||||||
|
last_verified: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Session:
|
||||||
|
"""Session data structure."""
|
||||||
|
|
||||||
|
key: bytes
|
||||||
|
user_uuid: UUID
|
||||||
|
expires: datetime
|
||||||
|
credential_uuid: UUID | None = None
|
||||||
|
info: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["User", "Credential", "Session"]
|
@ -5,10 +5,8 @@ This module provides an async database layer using SQLAlchemy async mode
|
|||||||
for managing users and credentials in a WebAuthn authentication system.
|
for managing users and credentials in a WebAuthn authentication system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import secrets
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from datetime import datetime
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
@ -25,7 +23,7 @@ from sqlalchemy.dialects.sqlite import BLOB, JSON
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from ..sansio import StoredCredential
|
from . import Credential, Session, User
|
||||||
|
|
||||||
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
DB_PATH = "sqlite+aiosqlite:///webauthn.db"
|
||||||
|
|
||||||
@ -38,7 +36,7 @@ class Base(DeclarativeBase):
|
|||||||
class UserModel(Base):
|
class UserModel(Base):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
user_id: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
user_uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
user_name: Mapped[str] = mapped_column(String, nullable=False)
|
user_name: Mapped[str] = mapped_column(String, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
||||||
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
last_seen: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
@ -52,10 +50,12 @@ class UserModel(Base):
|
|||||||
|
|
||||||
class CredentialModel(Base):
|
class CredentialModel(Base):
|
||||||
__tablename__ = "credentials"
|
__tablename__ = "credentials"
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
uuid: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
credential_id: Mapped[bytes] = mapped_column(LargeBinary(64), unique=True)
|
credential_id: Mapped[bytes] = mapped_column(
|
||||||
user_id: Mapped[bytes] = mapped_column(
|
LargeBinary(64), unique=True, index=True
|
||||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
)
|
||||||
|
user_uuid: Mapped[bytes] = mapped_column(
|
||||||
|
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
aaguid: Mapped[bytes] = mapped_column(LargeBinary(16), nullable=False)
|
||||||
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
public_key: Mapped[bytes] = mapped_column(BLOB, nullable=False)
|
||||||
@ -71,29 +71,20 @@ class CredentialModel(Base):
|
|||||||
class SessionModel(Base):
|
class SessionModel(Base):
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions"
|
||||||
|
|
||||||
token: Mapped[str] = mapped_column(String(32), primary_key=True)
|
key: Mapped[bytes] = mapped_column(LargeBinary(16), primary_key=True)
|
||||||
user_id: Mapped[bytes] = mapped_column(
|
user_uuid: Mapped[bytes] = mapped_column(
|
||||||
LargeBinary(16), ForeignKey("users.user_id", ondelete="CASCADE")
|
LargeBinary(16), ForeignKey("users.user_uuid", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
credential_id: Mapped[int | None] = mapped_column(
|
credential_uuid: Mapped[bytes | None] = mapped_column(
|
||||||
Integer, ForeignKey("credentials.id", ondelete="SET NULL")
|
LargeBinary(16), ForeignKey("credentials.uuid", ondelete="CASCADE")
|
||||||
)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
|
expires: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||||
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
info: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
|
|
||||||
# Relationship to user
|
# Relationship to user
|
||||||
user: Mapped["UserModel"] = relationship("UserModel")
|
user: Mapped["UserModel"] = relationship("UserModel")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class User:
|
|
||||||
user_id: UUID
|
|
||||||
user_name: str
|
|
||||||
created_at: datetime | None = None
|
|
||||||
last_seen: datetime | None = None
|
|
||||||
visits: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
# Global engine and session factory
|
# Global engine and session factory
|
||||||
engine = create_async_engine(DB_PATH, echo=False)
|
engine = create_async_engine(DB_PATH, echo=False)
|
||||||
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
@ -116,15 +107,15 @@ class DB:
|
|||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
async def get_user_by_user_id(self, user_id: UUID) -> User:
|
async def get_user_by_user_uuid(self, user_uuid: UUID) -> User:
|
||||||
"""Get user record by WebAuthn user ID."""
|
"""Get user record by WebAuthn user UUID."""
|
||||||
stmt = select(UserModel).where(UserModel.user_id == user_id.bytes)
|
stmt = select(UserModel).where(UserModel.user_uuid == user_uuid.bytes)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
user_model = result.scalar_one_or_none()
|
user_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
if user_model:
|
if user_model:
|
||||||
return User(
|
return User(
|
||||||
user_id=UUID(bytes=user_model.user_id),
|
user_uuid=UUID(bytes=user_model.user_uuid),
|
||||||
user_name=user_model.user_name,
|
user_name=user_model.user_name,
|
||||||
created_at=user_model.created_at,
|
created_at=user_model.created_at,
|
||||||
last_seen=user_model.last_seen,
|
last_seen=user_model.last_seen,
|
||||||
@ -135,7 +126,7 @@ class DB:
|
|||||||
async def create_user(self, user: User) -> None:
|
async def create_user(self, user: User) -> None:
|
||||||
"""Create a new user."""
|
"""Create a new user."""
|
||||||
user_model = UserModel(
|
user_model = UserModel(
|
||||||
user_id=user.user_id.bytes,
|
user_uuid=user.user_uuid.bytes,
|
||||||
user_name=user.user_name,
|
user_name=user.user_name,
|
||||||
created_at=user.created_at or datetime.now(),
|
created_at=user.created_at or datetime.now(),
|
||||||
last_seen=user.last_seen,
|
last_seen=user.last_seen,
|
||||||
@ -144,11 +135,12 @@ class DB:
|
|||||||
self.session.add(user_model)
|
self.session.add(user_model)
|
||||||
await self.session.flush()
|
await self.session.flush()
|
||||||
|
|
||||||
async def create_credential(self, credential: StoredCredential) -> None:
|
async def create_credential(self, credential: Credential) -> None:
|
||||||
"""Store a credential for a user."""
|
"""Store a credential for a user."""
|
||||||
credential_model = CredentialModel(
|
credential_model = CredentialModel(
|
||||||
|
uuid=credential.uuid.bytes,
|
||||||
credential_id=credential.credential_id,
|
credential_id=credential.credential_id,
|
||||||
user_id=credential.user_id.bytes,
|
user_uuid=credential.user_uuid.bytes,
|
||||||
aaguid=credential.aaguid.bytes,
|
aaguid=credential.aaguid.bytes,
|
||||||
public_key=credential.public_key,
|
public_key=credential.public_key,
|
||||||
sign_count=credential.sign_count,
|
sign_count=credential.sign_count,
|
||||||
@ -159,7 +151,7 @@ class DB:
|
|||||||
self.session.add(credential_model)
|
self.session.add(credential_model)
|
||||||
await self.session.flush()
|
await self.session.flush()
|
||||||
|
|
||||||
async def get_credential_by_id(self, credential_id: bytes) -> StoredCredential:
|
async def get_credential_by_id(self, credential_id: bytes) -> Credential:
|
||||||
"""Get credential by credential ID."""
|
"""Get credential by credential ID."""
|
||||||
stmt = select(CredentialModel).where(
|
stmt = select(CredentialModel).where(
|
||||||
CredentialModel.credential_id == credential_id
|
CredentialModel.credential_id == credential_id
|
||||||
@ -167,28 +159,29 @@ class DB:
|
|||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
credential_model = result.scalar_one_or_none()
|
credential_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
if credential_model:
|
if not credential_model:
|
||||||
return StoredCredential(
|
raise ValueError("Credential not registered")
|
||||||
credential_id=credential_model.credential_id,
|
return Credential(
|
||||||
user_id=UUID(bytes=credential_model.user_id),
|
uuid=UUID(bytes=credential_model.uuid),
|
||||||
aaguid=UUID(bytes=credential_model.aaguid),
|
credential_id=credential_model.credential_id,
|
||||||
public_key=credential_model.public_key,
|
user_uuid=UUID(bytes=credential_model.user_uuid),
|
||||||
sign_count=credential_model.sign_count,
|
aaguid=UUID(bytes=credential_model.aaguid),
|
||||||
created_at=credential_model.created_at,
|
public_key=credential_model.public_key,
|
||||||
last_used=credential_model.last_used,
|
sign_count=credential_model.sign_count,
|
||||||
last_verified=credential_model.last_verified,
|
created_at=credential_model.created_at,
|
||||||
)
|
last_used=credential_model.last_used,
|
||||||
raise ValueError("Credential not registered")
|
last_verified=credential_model.last_verified,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_credentials_by_user_id(self, user_id: UUID) -> list[bytes]:
|
async def get_credentials_by_user_uuid(self, user_uuid: UUID) -> list[bytes]:
|
||||||
"""Get all credential IDs for a user."""
|
"""Get all credential IDs for a user."""
|
||||||
stmt = select(CredentialModel.credential_id).where(
|
stmt = select(CredentialModel.credential_id).where(
|
||||||
CredentialModel.user_id == user_id.bytes
|
CredentialModel.user_uuid == user_uuid.bytes
|
||||||
)
|
)
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
return [row[0] for row in result.fetchall()]
|
return [row[0] for row in result.fetchall()]
|
||||||
|
|
||||||
async def update_credential(self, credential: StoredCredential) -> None:
|
async def update_credential(self, credential: Credential) -> None:
|
||||||
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
"""Update the sign count, created_at, last_used, and last_verified for a credential."""
|
||||||
stmt = (
|
stmt = (
|
||||||
update(CredentialModel)
|
update(CredentialModel)
|
||||||
@ -202,7 +195,7 @@ class DB:
|
|||||||
)
|
)
|
||||||
await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
async def login(self, user_id: UUID, credential: StoredCredential) -> None:
|
async def login(self, user_uuid: UUID, credential: Credential) -> None:
|
||||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||||
async with self.session.begin():
|
async with self.session.begin():
|
||||||
# Update credential
|
# Update credential
|
||||||
@ -211,137 +204,77 @@ class DB:
|
|||||||
# Update user's last_seen and increment visits
|
# Update user's last_seen and increment visits
|
||||||
stmt = (
|
stmt = (
|
||||||
update(UserModel)
|
update(UserModel)
|
||||||
.where(UserModel.user_id == user_id.bytes)
|
.where(UserModel.user_uuid == user_uuid.bytes)
|
||||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
||||||
)
|
)
|
||||||
await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
async def create_new_session(
|
async def delete_credential(self, uuid: UUID, user_uuid: UUID) -> None:
|
||||||
self, user_id: UUID, credential: StoredCredential
|
|
||||||
) -> None:
|
|
||||||
"""Create a new session for a user by incrementing visits and updating last_seen."""
|
|
||||||
async with self.session.begin():
|
|
||||||
# Update credential
|
|
||||||
await self.update_credential(credential)
|
|
||||||
|
|
||||||
# Update user's last_seen and increment visits
|
|
||||||
stmt = (
|
|
||||||
update(UserModel)
|
|
||||||
.where(UserModel.user_id == user_id.bytes)
|
|
||||||
.values(last_seen=credential.last_used, visits=UserModel.visits + 1)
|
|
||||||
)
|
|
||||||
await self.session.execute(stmt)
|
|
||||||
|
|
||||||
async def delete_credential(self, credential_id: bytes) -> None:
|
|
||||||
"""Delete a credential by its ID."""
|
"""Delete a credential by its ID."""
|
||||||
stmt = delete(CredentialModel).where(
|
stmt = (
|
||||||
CredentialModel.credential_id == credential_id
|
delete(CredentialModel)
|
||||||
|
.where(CredentialModel.uuid == uuid.bytes)
|
||||||
|
.where(CredentialModel.user_uuid == user_uuid.bytes)
|
||||||
)
|
)
|
||||||
await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
async def get_user_by_username(self, user_name: str) -> User | None:
|
|
||||||
"""Get user by username."""
|
|
||||||
stmt = select(UserModel).where(UserModel.user_name == user_name)
|
|
||||||
result = await self.session.execute(stmt)
|
|
||||||
user_model = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if user_model:
|
|
||||||
return User(
|
|
||||||
user_id=UUID(bytes=user_model.user_id),
|
|
||||||
user_name=user_model.user_name,
|
|
||||||
created_at=user_model.created_at,
|
|
||||||
last_seen=user_model.last_seen,
|
|
||||||
visits=user_model.visits,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def create_session(
|
async def create_session(
|
||||||
self,
|
self,
|
||||||
user_id: UUID,
|
user_uuid: UUID,
|
||||||
credential_id: int | None = None,
|
key: bytes,
|
||||||
token: str | None = None,
|
expires: datetime,
|
||||||
info: dict | None = None,
|
info: dict,
|
||||||
) -> str:
|
credential_uuid: UUID | None = None,
|
||||||
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
|
) -> bytes:
|
||||||
if token is None:
|
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
|
||||||
token = secrets.token_urlsafe(12)
|
|
||||||
|
|
||||||
session_model = SessionModel(
|
session_model = SessionModel(
|
||||||
token=token,
|
key=key,
|
||||||
user_id=user_id.bytes,
|
user_uuid=user_uuid.bytes,
|
||||||
credential_id=credential_id,
|
credential_uuid=credential_uuid.bytes if credential_uuid else None,
|
||||||
created_at=datetime.now(),
|
expires=expires,
|
||||||
info=info,
|
info=info,
|
||||||
)
|
)
|
||||||
self.session.add(session_model)
|
self.session.add(session_model)
|
||||||
await self.session.flush()
|
await self.session.flush()
|
||||||
return token
|
return key
|
||||||
|
|
||||||
async def create_session_by_credential_id(
|
async def get_session(self, key: bytes) -> Session | None:
|
||||||
self,
|
"""Get session by 16-byte key."""
|
||||||
user_id: UUID,
|
stmt = select(SessionModel).where(SessionModel.key == key)
|
||||||
credential_id: bytes | None,
|
|
||||||
token: str | None = None,
|
|
||||||
info: dict | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
|
|
||||||
if credential_id is None:
|
|
||||||
return await self.create_session(user_id, None, token, info)
|
|
||||||
|
|
||||||
# Get the database ID from the credential
|
|
||||||
stmt = select(CredentialModel.id).where(
|
|
||||||
CredentialModel.credential_id == credential_id
|
|
||||||
)
|
|
||||||
result = await self.session.execute(stmt)
|
result = await self.session.execute(stmt)
|
||||||
db_credential_id = result.scalar_one()
|
session_model = result.scalar_one_or_none()
|
||||||
|
|
||||||
return await self.create_session(user_id, db_credential_id, token, info)
|
if session_model:
|
||||||
|
return Session(
|
||||||
async def get_session(self, token: str) -> SessionModel | None:
|
key=session_model.key,
|
||||||
"""Get session by token string."""
|
user_uuid=UUID(bytes=session_model.user_uuid),
|
||||||
stmt = select(SessionModel).where(SessionModel.token == token)
|
credential_uuid=UUID(bytes=session_model.credential_uuid)
|
||||||
result = await self.session.execute(stmt)
|
if session_model.credential_uuid
|
||||||
session = result.scalar_one_or_none()
|
else None,
|
||||||
|
expires=session_model.expires,
|
||||||
if session:
|
info=session_model.info,
|
||||||
# Check if session is expired (24 hours)
|
)
|
||||||
expiry_time = session.created_at + timedelta(hours=24)
|
|
||||||
if datetime.now() > expiry_time:
|
|
||||||
# Clean up expired session
|
|
||||||
await self.delete_session(token)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return session
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def delete_session(self, token: str) -> None:
|
async def delete_session(self, key: bytes) -> None:
|
||||||
"""Delete a session by token."""
|
"""Delete a session by 16-byte key."""
|
||||||
stmt = delete(SessionModel).where(SessionModel.token == token)
|
await self.session.execute(delete(SessionModel).where(SessionModel.key == key))
|
||||||
await self.session.execute(stmt)
|
|
||||||
|
async def update_session(self, key: bytes, expires: datetime, info: dict) -> None:
|
||||||
|
"""Update session expiration time and/or info."""
|
||||||
|
await self.session.execute(
|
||||||
|
update(SessionModel)
|
||||||
|
.where(SessionModel.key == key)
|
||||||
|
.values(expires=expires, info=info)
|
||||||
|
)
|
||||||
|
|
||||||
async def cleanup_expired_sessions(self) -> None:
|
async def cleanup_expired_sessions(self) -> None:
|
||||||
"""Remove expired sessions (older than 24 hours)."""
|
"""Remove expired sessions."""
|
||||||
expiry_time = datetime.now() - timedelta(hours=24)
|
current_time = datetime.now()
|
||||||
stmt = delete(SessionModel).where(SessionModel.created_at < expiry_time)
|
stmt = delete(SessionModel).where(SessionModel.expires < current_time)
|
||||||
await self.session.execute(stmt)
|
await self.session.execute(stmt)
|
||||||
|
|
||||||
async def refresh_session(self, token: str) -> str | None:
|
|
||||||
"""Refresh a session by updating its created_at timestamp."""
|
|
||||||
session = await self.get_session(token)
|
|
||||||
if not session:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Delete old session
|
|
||||||
await self.delete_session(token)
|
|
||||||
|
|
||||||
# Create new session with same user and credential
|
|
||||||
return await self.create_session(
|
|
||||||
user_id=UUID(bytes=session.user_id),
|
|
||||||
credential_id=session.credential_id,
|
|
||||||
info=session.info,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Standalone functions that handle database connections internally
|
# Standalone functions that handle database connections internally
|
||||||
async def init_database() -> None:
|
async def init_database() -> None:
|
||||||
@ -350,7 +283,7 @@ async def init_database() -> None:
|
|||||||
await db.init_db()
|
await db.init_db()
|
||||||
|
|
||||||
|
|
||||||
async def create_user_and_credential(user: User, credential: StoredCredential) -> None:
|
async def create_user_and_credential(user: User, credential: Credential) -> None:
|
||||||
"""Create a new user and their first credential in a single transaction."""
|
"""Create a new user and their first credential in a single transaction."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
await db.session.begin()
|
await db.session.begin()
|
||||||
@ -360,97 +293,73 @@ async def create_user_and_credential(user: User, credential: StoredCredential) -
|
|||||||
await db.create_credential(credential)
|
await db.create_credential(credential)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_id(user_id: UUID) -> User:
|
async def get_user_by_uuid(user_uuid: UUID) -> User:
|
||||||
"""Get user record by WebAuthn user ID."""
|
"""Get user record by WebAuthn user UUID."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.get_user_by_user_id(user_id)
|
return await db.get_user_by_user_uuid(user_uuid)
|
||||||
|
|
||||||
|
|
||||||
async def create_credential_for_user(credential: StoredCredential) -> None:
|
async def create_credential_for_user(credential: Credential) -> None:
|
||||||
"""Store a credential for an existing user."""
|
"""Store a credential for an existing user."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
await db.create_credential(credential)
|
await db.create_credential(credential)
|
||||||
|
|
||||||
|
|
||||||
async def get_credential_by_id(credential_id: bytes) -> StoredCredential:
|
async def get_credential_by_id(credential_id: bytes) -> Credential:
|
||||||
"""Get credential by credential ID."""
|
"""Get credential by credential ID."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.get_credential_by_id(credential_id)
|
return await db.get_credential_by_id(credential_id)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_credentials(user_id: UUID) -> list[bytes]:
|
async def get_user_credentials(user_uuid: UUID) -> list[bytes]:
|
||||||
"""Get all credential IDs for a user."""
|
"""Get all credential IDs for a user."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.get_credentials_by_user_id(user_id)
|
return await db.get_credentials_by_user_uuid(user_uuid)
|
||||||
|
|
||||||
|
|
||||||
async def login_user(user_id: UUID, credential: StoredCredential) -> None:
|
async def login_user(user_uuid: UUID, credential: Credential) -> None:
|
||||||
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
"""Update the last_seen timestamp for a user and the credential record used for logging in."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
await db.login(user_id, credential)
|
await db.login(user_uuid, credential)
|
||||||
|
|
||||||
|
|
||||||
async def delete_user_credential(credential_id: bytes) -> None:
|
async def delete_credential(uuid: UUID, user_uuid: UUID) -> None:
|
||||||
"""Delete a credential by its ID."""
|
"""Delete a credential by its ID."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
await db.delete_credential(credential_id)
|
await db.delete_credential(uuid, user_uuid)
|
||||||
|
|
||||||
|
|
||||||
async def create_new_session(user_id: UUID, credential: StoredCredential) -> None:
|
|
||||||
"""Create a new session for a user by incrementing visits and updating last_seen."""
|
|
||||||
async with connect() as db:
|
|
||||||
await db.create_new_session(user_id, credential)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_username(user_name: str) -> User | None:
|
|
||||||
"""Get user by username."""
|
|
||||||
async with connect() as db:
|
|
||||||
return await db.get_user_by_username(user_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_session(
|
async def create_session(
|
||||||
user_id: UUID,
|
user_uuid: UUID,
|
||||||
credential_id: int | None = None,
|
key: bytes,
|
||||||
token: str | None = None,
|
expires: datetime,
|
||||||
info: dict | None = None,
|
info: dict,
|
||||||
) -> str:
|
credential_uuid: UUID | None = None,
|
||||||
"""Create a new authentication session for a user. If credential_id is None, creates a session without a specific credential."""
|
) -> bytes:
|
||||||
|
"""Create a new authentication session for a user. If credential_uuid is None, creates a session without a specific credential."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.create_session(user_id, credential_id, token, info)
|
return await db.create_session(user_uuid, key, expires, info, credential_uuid)
|
||||||
|
|
||||||
|
|
||||||
async def create_session_by_credential_id(
|
async def get_session(key: bytes) -> Session | None:
|
||||||
user_id: UUID,
|
"""Get session by 16-byte key."""
|
||||||
credential_id: bytes | None,
|
|
||||||
token: str | None = None,
|
|
||||||
info: dict | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Create a new authentication session for a user using WebAuthn credential ID. If credential_id is None, creates a session without a specific credential."""
|
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.create_session_by_credential_id(
|
return await db.get_session(key)
|
||||||
user_id, credential_id, token, info
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session(token: str) -> SessionModel | None:
|
async def delete_session(key: bytes) -> None:
|
||||||
"""Get session by token string."""
|
"""Delete a session by 16-byte key."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
return await db.get_session(token)
|
await db.delete_session(key)
|
||||||
|
|
||||||
|
|
||||||
async def delete_session(token: str) -> None:
|
async def update_session(key: bytes, expires: datetime, info: dict) -> None:
|
||||||
"""Delete a session by token."""
|
"""Update session expiration time and/or info."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
await db.delete_session(token)
|
await db.update_session(key, expires, info)
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_expired_sessions() -> None:
|
async def cleanup_expired_sessions() -> None:
|
||||||
"""Remove expired sessions (older than 24 hours)."""
|
"""Remove expired sessions."""
|
||||||
async with connect() as db:
|
async with connect() as db:
|
||||||
await db.cleanup_expired_sessions()
|
await db.cleanup_expired_sessions()
|
||||||
|
|
||||||
|
|
||||||
async def refresh_session(token: str) -> str | None:
|
|
||||||
"""Refresh a session by updating its created_at timestamp."""
|
|
||||||
async with connect() as db:
|
|
||||||
return await db.refresh_session(token)
|
|
||||||
|
@ -8,67 +8,67 @@ This module contains all the HTTP API endpoints for:
|
|||||||
- Login/logout functionality
|
- Login/logout functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Cookie, Depends, FastAPI, Request, Response
|
||||||
|
from fastapi.security import HTTPBearer
|
||||||
|
|
||||||
from .. import aaguid
|
from .. import aaguid
|
||||||
from ..db import sql
|
from ..db import sql
|
||||||
from ..util.session import refresh_session_token, validate_session_token
|
from ..util.tokens import session_key
|
||||||
from .session import (
|
from . import session
|
||||||
clear_session_cookie,
|
|
||||||
get_current_user,
|
bearer_auth = HTTPBearer(auto_error=True)
|
||||||
get_session_token_from_bearer,
|
|
||||||
get_session_token_from_cookie,
|
|
||||||
set_session_cookie,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_api_routes(app: FastAPI):
|
def register_api_routes(app: FastAPI):
|
||||||
"""Register all API routes on the FastAPI app."""
|
"""Register all API routes on the FastAPI app."""
|
||||||
|
|
||||||
@app.post("/auth/user-info")
|
@app.post("/auth/validate")
|
||||||
async def api_user_info(request: Request, response: Response):
|
async def validate_token(request: Request, response: Response, auth=Cookie(None)):
|
||||||
"""Get user information and credentials from session cookie."""
|
"""Lightweight token validation endpoint."""
|
||||||
try:
|
try:
|
||||||
user = await get_current_user(request)
|
s = await session.get_session(auth)
|
||||||
if not user:
|
return {
|
||||||
return {"error": "Not authenticated"}
|
"status": "success",
|
||||||
|
"valid": True,
|
||||||
# Get current session credential ID
|
"user_uuid": str(s.user_uuid),
|
||||||
current_credential_id = None
|
}
|
||||||
session_token = get_session_token_from_cookie(request)
|
except ValueError:
|
||||||
if session_token:
|
return {"status": "error", "valid": False}
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if token_data:
|
|
||||||
current_credential_id = token_data.get("credential_id")
|
|
||||||
|
|
||||||
|
@app.post("/auth/user-info")
|
||||||
|
async def api_user_info(request: Request, response: Response, auth=Cookie(None)):
|
||||||
|
"""Get full user information for the authenticated user."""
|
||||||
|
try:
|
||||||
|
s = await session.get_session(auth, reset_allowed=True)
|
||||||
|
u = await sql.get_user_by_uuid(s.user_uuid)
|
||||||
# Get all credentials for the user
|
# Get all credentials for the user
|
||||||
credential_ids = await sql.get_user_credentials(user.user_id)
|
credential_ids = await sql.get_user_credentials(s.user_uuid)
|
||||||
|
|
||||||
credentials = []
|
credentials = []
|
||||||
user_aaguids = set()
|
user_aaguids = set()
|
||||||
|
|
||||||
for cred_id in credential_ids:
|
for cred_id in credential_ids:
|
||||||
stored_cred = await sql.get_credential_by_id(cred_id)
|
c = await sql.get_credential_by_id(cred_id)
|
||||||
|
|
||||||
# Convert AAGUID to string format
|
# Convert AAGUID to string format
|
||||||
aaguid_str = str(stored_cred.aaguid)
|
aaguid_str = str(c.aaguid)
|
||||||
user_aaguids.add(aaguid_str)
|
user_aaguids.add(aaguid_str)
|
||||||
|
|
||||||
# Check if this is the current session credential
|
# Check if this is the current session credential
|
||||||
is_current_session = current_credential_id == stored_cred.credential_id
|
is_current_session = s.credential_uuid == c.uuid
|
||||||
|
|
||||||
credentials.append(
|
credentials.append(
|
||||||
{
|
{
|
||||||
"credential_id": stored_cred.credential_id.hex(),
|
"credential_uuid": str(c.uuid),
|
||||||
"aaguid": aaguid_str,
|
"aaguid": aaguid_str,
|
||||||
"created_at": stored_cred.created_at.isoformat(),
|
"created_at": c.created_at.isoformat(),
|
||||||
"last_used": stored_cred.last_used.isoformat()
|
"last_used": c.last_used.isoformat() if c.last_used else None,
|
||||||
if stored_cred.last_used
|
"last_verified": c.last_verified.isoformat()
|
||||||
|
if c.last_verified
|
||||||
else None,
|
else None,
|
||||||
"last_verified": stored_cred.last_verified.isoformat()
|
"sign_count": c.sign_count,
|
||||||
if stored_cred.last_verified
|
|
||||||
else None,
|
|
||||||
"sign_count": stored_cred.sign_count,
|
|
||||||
"is_current_session": is_current_session,
|
"is_current_session": is_current_session,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -82,13 +82,11 @@ def register_api_routes(app: FastAPI):
|
|||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"user": {
|
"user": {
|
||||||
"user_id": str(user.user_id),
|
"user_uuid": str(u.user_uuid),
|
||||||
"user_name": user.user_name,
|
"user_name": u.user_name,
|
||||||
"created_at": user.created_at.isoformat()
|
"created_at": u.created_at.isoformat() if u.created_at else None,
|
||||||
if user.created_at
|
"last_seen": u.last_seen.isoformat() if u.last_seen else None,
|
||||||
else None,
|
"visits": u.visits,
|
||||||
"last_seen": user.last_seen.isoformat() if user.last_seen else None,
|
|
||||||
"visits": user.visits,
|
|
||||||
},
|
},
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
"aaguid_info": aaguid_info,
|
"aaguid_info": aaguid_info,
|
||||||
@ -97,196 +95,44 @@ def register_api_routes(app: FastAPI):
|
|||||||
return {"error": f"Failed to get user info: {str(e)}"}
|
return {"error": f"Failed to get user info: {str(e)}"}
|
||||||
|
|
||||||
@app.post("/auth/logout")
|
@app.post("/auth/logout")
|
||||||
async def api_logout(request: Request, response: Response):
|
async def api_logout(response: Response, auth=Cookie(None)):
|
||||||
"""Log out the current user by clearing the session cookie and deleting from database."""
|
"""Log out the current user by clearing the session cookie and deleting from database."""
|
||||||
# Get the session token before clearing the cookie
|
if not auth:
|
||||||
session_token = get_session_token_from_cookie(request)
|
return {"status": "success", "message": "Already logged out"}
|
||||||
|
await sql.delete_session(session_key(auth))
|
||||||
# Clear the cookie
|
response.delete_cookie("auth")
|
||||||
clear_session_cookie(response)
|
|
||||||
|
|
||||||
# Delete the session from the database if it exists
|
|
||||||
if session_token:
|
|
||||||
from ..util.session import logout_session
|
|
||||||
|
|
||||||
try:
|
|
||||||
await logout_session(session_token)
|
|
||||||
except Exception:
|
|
||||||
# Continue even if session deletion fails
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {"status": "success", "message": "Logged out successfully"}
|
return {"status": "success", "message": "Logged out successfully"}
|
||||||
|
|
||||||
@app.post("/auth/set-session")
|
@app.post("/auth/set-session")
|
||||||
async def api_set_session(request: Request, response: Response):
|
async def api_set_session(
|
||||||
"""Set session cookie using JWT token from request body or Authorization header."""
|
request: Request, response: Response, auth=Depends(bearer_auth)
|
||||||
|
):
|
||||||
|
"""Set session cookie from Authorization header. Fetched after login by WebSocket."""
|
||||||
try:
|
try:
|
||||||
session_token = await get_session_token_from_bearer(request)
|
user = await session.get_session(auth.credentials)
|
||||||
|
if not user:
|
||||||
if not session_token:
|
raise ValueError("Invalid Authorization header.")
|
||||||
return {"error": "No session token provided"}
|
session.set_session_cookie(response, auth.credentials)
|
||||||
|
|
||||||
# Validate the session token
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data:
|
|
||||||
return {"error": "Invalid or expired session token"}
|
|
||||||
|
|
||||||
# Set the HTTP-only cookie
|
|
||||||
set_session_cookie(response, session_token)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": "Session cookie set successfully",
|
"message": "Session cookie set successfully",
|
||||||
"user_id": str(token_data["user_id"]),
|
"user_uuid": str(user.user_uuid),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return {"error": str(e)}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": f"Failed to set session: {str(e)}"}
|
return {"error": f"Failed to set session: {str(e)}"}
|
||||||
|
|
||||||
@app.post("/auth/delete-credential")
|
@app.delete("/auth/credential/{uuid}")
|
||||||
async def api_delete_credential(request: Request):
|
async def api_delete_credential(uuid: UUID, auth: str = Cookie(None)):
|
||||||
"""Delete a specific credential for the current user."""
|
"""Delete a specific credential for the current user."""
|
||||||
try:
|
try:
|
||||||
user = await get_current_user(request)
|
await session.delete_credential(uuid, auth)
|
||||||
if not user:
|
|
||||||
return {"error": "Not authenticated"}
|
|
||||||
|
|
||||||
# Get the credential ID from the request body
|
|
||||||
try:
|
|
||||||
body = await request.json()
|
|
||||||
credential_id = body.get("credential_id")
|
|
||||||
if not credential_id:
|
|
||||||
return {"error": "credential_id is required"}
|
|
||||||
except Exception:
|
|
||||||
return {"error": "Invalid request body"}
|
|
||||||
|
|
||||||
# Convert credential_id from hex string to bytes
|
|
||||||
try:
|
|
||||||
credential_id_bytes = bytes.fromhex(credential_id)
|
|
||||||
except ValueError:
|
|
||||||
return {"error": "Invalid credential_id format"}
|
|
||||||
|
|
||||||
# First, verify the credential belongs to the current user
|
|
||||||
try:
|
|
||||||
stored_cred = await sql.get_credential_by_id(credential_id_bytes)
|
|
||||||
if stored_cred.user_id != user.user_id:
|
|
||||||
return {"error": "Credential not found or access denied"}
|
|
||||||
except ValueError:
|
|
||||||
return {"error": "Credential not found"}
|
|
||||||
|
|
||||||
# Check if this is the current session credential
|
|
||||||
session_token = get_session_token_from_cookie(request)
|
|
||||||
if session_token:
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if (
|
|
||||||
token_data
|
|
||||||
and token_data.get("credential_id") == credential_id_bytes
|
|
||||||
):
|
|
||||||
return {"error": "Cannot delete current session credential"}
|
|
||||||
|
|
||||||
# Get user's remaining credentials count
|
|
||||||
remaining_credentials = await sql.get_user_credentials(user.user_id)
|
|
||||||
if len(remaining_credentials) <= 1:
|
|
||||||
return {"error": "Cannot delete last remaining credential"}
|
|
||||||
|
|
||||||
# Delete the credential
|
|
||||||
await sql.delete_user_credential(credential_id_bytes)
|
|
||||||
|
|
||||||
return {"status": "success", "message": "Credential deleted successfully"}
|
return {"status": "success", "message": "Credential deleted successfully"}
|
||||||
|
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
return {"error": f"Failed to delete credential: {str(e)}"}
|
return {"error": str(e)}
|
||||||
|
except Exception:
|
||||||
@app.get("/auth/sessions")
|
return {"error": "Failed to delete credential"}
|
||||||
async def api_get_sessions(request: Request):
|
|
||||||
"""Get all active sessions for the current user."""
|
|
||||||
try:
|
|
||||||
user = await get_current_user(request)
|
|
||||||
if not user:
|
|
||||||
return {"error": "Authentication required"}
|
|
||||||
|
|
||||||
# Get all sessions for this user
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from ..db.sql import SessionModel, connect
|
|
||||||
|
|
||||||
async with connect() as db:
|
|
||||||
stmt = select(SessionModel).where(
|
|
||||||
SessionModel.user_id == user.user_id.bytes
|
|
||||||
)
|
|
||||||
result = await db.session.execute(stmt)
|
|
||||||
session_models = result.scalars().all()
|
|
||||||
|
|
||||||
sessions = []
|
|
||||||
current_token = get_session_token_from_cookie(request)
|
|
||||||
|
|
||||||
for session in session_models:
|
|
||||||
# Check if session is expired
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
expiry_time = session.created_at + timedelta(hours=24)
|
|
||||||
is_expired = datetime.now() > expiry_time
|
|
||||||
|
|
||||||
sessions.append(
|
|
||||||
{
|
|
||||||
"token": session.token[:8]
|
|
||||||
+ "...", # Only show first 8 chars for security
|
|
||||||
"created_at": session.created_at.isoformat(),
|
|
||||||
"client_ip": session.info.get("client_ip")
|
|
||||||
if session.info
|
|
||||||
else None,
|
|
||||||
"user_agent": session.info.get("user_agent")
|
|
||||||
if session.info
|
|
||||||
else None,
|
|
||||||
"connection_type": session.info.get(
|
|
||||||
"connection_type", "http"
|
|
||||||
)
|
|
||||||
if session.info
|
|
||||||
else "http",
|
|
||||||
"is_current": session.token == current_token,
|
|
||||||
"is_reset_token": session.credential_id is None,
|
|
||||||
"is_expired": is_expired,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"sessions": sessions,
|
|
||||||
"total_sessions": len(sessions),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": f"Failed to get sessions: {str(e)}"}
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_token(request: Request, response: Response) -> dict:
|
|
||||||
"""Validate a session token and return user info. Also refreshes the token if valid."""
|
|
||||||
try:
|
|
||||||
session_token = get_session_token_from_cookie(request)
|
|
||||||
if not session_token:
|
|
||||||
return {"error": "No session token found"}
|
|
||||||
|
|
||||||
# Validate the session token
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data:
|
|
||||||
clear_session_cookie(response)
|
|
||||||
return {"error": "Invalid or expired session token"}
|
|
||||||
|
|
||||||
# Refresh the token if valid
|
|
||||||
new_token = await refresh_session_token(session_token)
|
|
||||||
if new_token:
|
|
||||||
set_session_cookie(response, new_token)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"valid": True,
|
|
||||||
"refreshed": bool(new_token),
|
|
||||||
"user_id": str(token_data["user_id"]),
|
|
||||||
"credential_id": token_data["credential_id"].hex()
|
|
||||||
if token_data["credential_id"]
|
|
||||||
else None,
|
|
||||||
"created_at": token_data["created_at"].isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": f"Failed to validate token: {str(e)}"}
|
|
||||||
|
@ -9,15 +9,12 @@ This module provides a simple WebAuthn implementation that:
|
|||||||
- Enables true passwordless authentication where users don't need to enter a user_name
|
- Enables true passwordless authentication where users don't need to enter a user_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import Cookie, FastAPI, Request, Response
|
||||||
FastAPI,
|
|
||||||
Request,
|
|
||||||
Response,
|
|
||||||
)
|
|
||||||
from fastapi.responses import (
|
from fastapi.responses import (
|
||||||
FileResponse,
|
FileResponse,
|
||||||
JSONResponse,
|
JSONResponse,
|
||||||
@ -25,12 +22,9 @@ from fastapi.responses import (
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from ..db import sql
|
from ..db import sql
|
||||||
from .api import (
|
from . import session, ws
|
||||||
register_api_routes,
|
from .api import register_api_routes
|
||||||
validate_token,
|
|
||||||
)
|
|
||||||
from .reset import register_reset_routes
|
from .reset import register_reset_routes
|
||||||
from .ws import ws_app
|
|
||||||
|
|
||||||
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
STATIC_DIR = Path(__file__).parent.parent / "frontend-build"
|
||||||
|
|
||||||
@ -44,7 +38,7 @@ async def lifespan(app: FastAPI):
|
|||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
# Mount the WebSocket subapp
|
# Mount the WebSocket subapp
|
||||||
app.mount("/auth/ws", ws_app)
|
app.mount("/auth/ws", ws.app)
|
||||||
|
|
||||||
# Register API routes
|
# Register API routes
|
||||||
register_api_routes(app)
|
register_api_routes(app)
|
||||||
@ -52,24 +46,19 @@ register_reset_routes(app)
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/auth/forward-auth")
|
@app.get("/auth/forward-auth")
|
||||||
async def forward_authentication(request: Request):
|
async def forward_authentication(request: Request, auth=Cookie(None)):
|
||||||
"""A verification endpoint to use with Caddy forward_auth or Nginx auth_request."""
|
"""A validation endpoint to use with Caddy forward_auth or Nginx auth_request."""
|
||||||
# Create a dummy response object for internal validation (we won't use it for cookies)
|
with contextlib.suppress(ValueError):
|
||||||
response = Response()
|
s = await session.get_session(auth)
|
||||||
|
# If authenticated, return a success response
|
||||||
|
if s.info and s.info["type"] == "authenticated":
|
||||||
|
return Response(status_code=204, headers={"x-auth-user": str(s.user_uuid)})
|
||||||
|
|
||||||
result = await validate_token(request, response)
|
# Serve the index.html of the authentication app if not authenticated
|
||||||
if result.get("status") != "success":
|
return FileResponse(
|
||||||
# Serve the index.html of the authentication app if not authenticated
|
STATIC_DIR / "index.html",
|
||||||
return FileResponse(
|
status_code=401,
|
||||||
STATIC_DIR / "index.html",
|
headers={"www-authenticate": "PrivateToken"},
|
||||||
status_code=401,
|
|
||||||
headers={"www-authenticate": "PrivateToken"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# If authenticated, return a success response
|
|
||||||
return Response(
|
|
||||||
status_code=204,
|
|
||||||
headers={"x-auth-user-id": result["user_id"]},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,114 +1,74 @@
|
|||||||
"""
|
import logging
|
||||||
Device addition API handlers for WebAuthn authentication.
|
|
||||||
|
|
||||||
This module provides endpoints for authenticated users to:
|
from fastapi import Cookie, HTTPException, Request
|
||||||
- Generate device addition links with human-readable tokens
|
|
||||||
- Validate device addition tokens
|
|
||||||
- Add new passkeys to existing accounts via tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Path, Request
|
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
from ..db import sql
|
from ..db import sql
|
||||||
from ..util.passphrase import generate
|
from ..util import passphrase, tokens
|
||||||
from ..util.session import get_client_info
|
from . import session
|
||||||
from .session import get_current_user, is_device_addition_session, set_session_cookie
|
|
||||||
|
|
||||||
|
|
||||||
def register_reset_routes(app: FastAPI):
|
def register_reset_routes(app):
|
||||||
"""Register all device addition/reset routes on the FastAPI app."""
|
"""Register all device addition/reset routes on the FastAPI app."""
|
||||||
|
|
||||||
@app.post("/auth/create-device-link")
|
@app.post("/auth/create-link")
|
||||||
async def api_create_device_link(request: Request):
|
async def api_create_link(request: Request, auth=Cookie(None)):
|
||||||
"""Create a device addition link for the authenticated user."""
|
"""Create a device addition link for the authenticated user."""
|
||||||
try:
|
try:
|
||||||
# Require authentication
|
# Require authentication
|
||||||
user = await get_current_user(request)
|
s = await session.get_session(auth)
|
||||||
if not user:
|
|
||||||
return {"error": "Authentication required"}
|
|
||||||
|
|
||||||
# Generate a human-readable token
|
# Generate a human-readable token
|
||||||
token = generate(n=4, sep=".") # e.g., "able-ocean-forest-dawn"
|
token = passphrase.generate() # e.g., "cross.rotate.yin.note.evoke"
|
||||||
|
await sql.create_session(
|
||||||
# Create session token in database with credential_id=None for device addition
|
user_uuid=s.user_uuid,
|
||||||
client_info = get_client_info(request)
|
key=tokens.reset_key(token),
|
||||||
await sql.create_session(user.user_id, None, token, client_info)
|
expires=session.expires(),
|
||||||
|
info=session.infodict(request, "device addition"),
|
||||||
|
)
|
||||||
|
|
||||||
# Generate the device addition link with pretty URL
|
# Generate the device addition link with pretty URL
|
||||||
addition_link = f"{request.headers.get('origin', '')}/auth/{token}"
|
path = request.url.path.removesuffix("create-link") + token
|
||||||
|
url = f"{request.headers['origin']}{path}"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": "Device addition link generated successfully",
|
"message": "Registration link generated successfully",
|
||||||
"addition_link": addition_link,
|
"url": url,
|
||||||
"expires_in_hours": 24,
|
"expires": session.expires().isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return {"error": "Authentication required"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": f"Failed to create device addition link: {str(e)}"}
|
return {"error": f"Failed to create registration link: {str(e)}"}
|
||||||
|
|
||||||
@app.get("/auth/device-session-check")
|
@app.get("/auth/{reset_token}")
|
||||||
async def check_device_session(request: Request):
|
|
||||||
"""Check if the current session is for device addition."""
|
|
||||||
is_device_session = await is_device_addition_session(request)
|
|
||||||
return {"device_addition_session": is_device_session}
|
|
||||||
|
|
||||||
@app.get("/auth/{passphrase}")
|
|
||||||
async def reset_authentication(
|
async def reset_authentication(
|
||||||
request: Request,
|
request: Request,
|
||||||
passphrase: str = Path(pattern=r"^\w+(\.\w+){2,}$"),
|
reset_token: str,
|
||||||
):
|
):
|
||||||
|
"""Verifies the token and redirects to auth app for credential registration."""
|
||||||
|
# This route should only match to exact passphrases
|
||||||
|
print(f"Reset handler called with url: {request.url.path}")
|
||||||
|
if not passphrase.is_well_formed(reset_token):
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
try:
|
try:
|
||||||
# Get session token to validate it exists and get user_id
|
# Get session token to validate it exists and get user_id
|
||||||
session = await sql.get_session(passphrase)
|
key = tokens.reset_key(reset_token)
|
||||||
if not session:
|
sess = await sql.get_session(key)
|
||||||
# Token doesn't exist, redirect to home
|
if not sess:
|
||||||
return RedirectResponse(url="/", status_code=303)
|
raise ValueError("Invalid or expired registration token")
|
||||||
|
|
||||||
# Check if this is a device addition session (credential_id is None)
|
|
||||||
if session.credential_id is not None:
|
|
||||||
# Not a device addition session, redirect to home
|
|
||||||
return RedirectResponse(url="/", status_code=303)
|
|
||||||
|
|
||||||
# Create a device addition session token for the user
|
|
||||||
client_info = get_client_info(request)
|
|
||||||
session_token = await sql.create_session(
|
|
||||||
UUID(bytes=session.user_id), None, None, client_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create response and set session cookie
|
|
||||||
response = RedirectResponse(url="/auth/", status_code=303)
|
response = RedirectResponse(url="/auth/", status_code=303)
|
||||||
set_session_cookie(response, session_token)
|
session.set_session_cookie(response, reset_token)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# On any error, redirect to home
|
# On any error, redirect to auth app
|
||||||
return RedirectResponse(url="/", status_code=303)
|
if isinstance(e, ValueError):
|
||||||
|
msg = str(e)
|
||||||
|
else:
|
||||||
async def use_reset_token(token: str) -> dict:
|
logging.exception("Internal Server Error in reset_authentication")
|
||||||
"""Delete a device addition token after successful use."""
|
msg = "Internal Server Error"
|
||||||
try:
|
return RedirectResponse(url=f"/auth/#{msg}", status_code=303)
|
||||||
# Get session token first to validate it exists and is not expired
|
|
||||||
session = await sql.get_session(token)
|
|
||||||
if not session:
|
|
||||||
return {"error": "Invalid or expired device addition token"}
|
|
||||||
|
|
||||||
# Check if this is a device addition session (credential_id is None)
|
|
||||||
if session.credential_id is not None:
|
|
||||||
return {"error": "Invalid device addition token"}
|
|
||||||
|
|
||||||
# Delete the token (it's now used)
|
|
||||||
await sql.delete_session(token)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"message": "Device addition token used successfully",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": f"Failed to use device addition token: {str(e)}"}
|
|
||||||
|
@ -5,144 +5,85 @@ This module provides session management functionality including:
|
|||||||
- Getting current user from session cookies
|
- Getting current user from session cookies
|
||||||
- Setting and clearing HTTP-only cookies
|
- Setting and clearing HTTP-only cookies
|
||||||
- Session validation and token handling
|
- Session validation and token handling
|
||||||
|
- Device addition token management
|
||||||
|
- Device addition route handlers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
|
|
||||||
from ..db.sql import User, get_user_by_id
|
from ..db import Session, sql
|
||||||
from ..util.session import validate_session_token
|
from ..util import passphrase
|
||||||
|
from ..util.tokens import create_token, reset_key, session_key
|
||||||
|
|
||||||
COOKIE_NAME = "auth"
|
EXPIRES = timedelta(hours=24)
|
||||||
COOKIE_MAX_AGE = 86400 # 24 hours
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> User | None:
|
def expires() -> datetime:
|
||||||
"""Get the current user from the session cookie."""
|
return datetime.now() + EXPIRES
|
||||||
session_token = request.cookies.get(COOKIE_NAME)
|
|
||||||
if not session_token:
|
|
||||||
return None
|
|
||||||
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
user = await get_user_by_id(token_data["user_id"])
|
|
||||||
return user
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def set_session_cookie(response: Response, session_token: str) -> None:
|
def infodict(request: Request, type: str) -> dict:
|
||||||
|
"""Extract client information from request."""
|
||||||
|
return {
|
||||||
|
"ip": request.client.host if request.client else "",
|
||||||
|
"user_agent": request.headers.get("user-agent", "")[:500],
|
||||||
|
"type": type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def create_session(user_uuid: UUID, info: dict, credential_uuid: UUID) -> str:
|
||||||
|
"""Create a new session and return a session token."""
|
||||||
|
token = create_token()
|
||||||
|
await sql.create_session(
|
||||||
|
user_uuid=user_uuid,
|
||||||
|
key=session_key(token),
|
||||||
|
expires=datetime.now() + EXPIRES,
|
||||||
|
info=info,
|
||||||
|
credential_uuid=credential_uuid,
|
||||||
|
)
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session(token: str, reset_allowed=False) -> Session:
|
||||||
|
"""Validate a session token and return session data if valid."""
|
||||||
|
if passphrase.is_well_formed(token):
|
||||||
|
if not reset_allowed:
|
||||||
|
raise ValueError("Reset link is not allowed for this endpoint")
|
||||||
|
key = reset_key(token)
|
||||||
|
else:
|
||||||
|
key = session_key(token)
|
||||||
|
|
||||||
|
session = await sql.get_session(key)
|
||||||
|
if not session:
|
||||||
|
raise ValueError("Invalid or expired session token")
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_session_token(token: str):
|
||||||
|
"""Refresh a session extending its expiry."""
|
||||||
|
# Get the current session
|
||||||
|
s = await sql.update_session(session_key(token), datetime.now() + EXPIRES, {})
|
||||||
|
|
||||||
|
if not s:
|
||||||
|
raise ValueError("Session not found or expired")
|
||||||
|
|
||||||
|
|
||||||
|
def set_session_cookie(response: Response, token: str) -> None:
|
||||||
"""Set the session token as an HTTP-only cookie."""
|
"""Set the session token as an HTTP-only cookie."""
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key=COOKIE_NAME,
|
key="auth",
|
||||||
value=session_token,
|
value=token,
|
||||||
max_age=COOKIE_MAX_AGE,
|
max_age=int(EXPIRES.total_seconds()),
|
||||||
httponly=True,
|
httponly=True,
|
||||||
secure=True,
|
secure=True,
|
||||||
samesite="lax",
|
path="/auth/",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def clear_session_cookie(response: Response) -> None:
|
async def delete_credential(credential_uuid: UUID, auth: str):
|
||||||
"""Clear the session cookie."""
|
"""Delete a specific credential for the current user."""
|
||||||
response.delete_cookie(key=COOKIE_NAME)
|
s = await get_session(auth)
|
||||||
|
await sql.delete_credential(credential_uuid, s.user_uuid)
|
||||||
|
|
||||||
def get_session_token_from_cookie(request: Request) -> str | None:
|
|
||||||
"""Extract session token from request cookies."""
|
|
||||||
return request.cookies.get(COOKIE_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_session_from_request(request: Request) -> dict | None:
|
|
||||||
"""Validate session token from request and return token data."""
|
|
||||||
session_token = get_session_token_from_cookie(request)
|
|
||||||
if not session_token:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return await validate_session_token(session_token)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session_token_from_bearer(request: Request) -> str | None:
|
|
||||||
"""Extract session token from Authorization header or request body."""
|
|
||||||
# Try to get token from Authorization header first
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
if auth_header and auth_header.startswith("Bearer "):
|
|
||||||
return auth_header.removeprefix("Bearer ")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_from_cookie_string(cookie_header: str) -> UUID | None:
|
|
||||||
"""Parse cookie header and return user ID if valid session exists."""
|
|
||||||
if not cookie_header:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Parse cookies from header (simple implementation)
|
|
||||||
cookies = {}
|
|
||||||
for cookie in cookie_header.split(";"):
|
|
||||||
cookie = cookie.strip()
|
|
||||||
if "=" in cookie:
|
|
||||||
name, value = cookie.split("=", 1)
|
|
||||||
cookies[name] = value
|
|
||||||
|
|
||||||
session_token = cookies.get(COOKIE_NAME)
|
|
||||||
if not session_token:
|
|
||||||
return None
|
|
||||||
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return token_data["user_id"]
|
|
||||||
|
|
||||||
|
|
||||||
async def is_device_addition_session(request: Request) -> bool:
|
|
||||||
"""Check if the current session is for device addition."""
|
|
||||||
session_token = request.cookies.get(COOKIE_NAME)
|
|
||||||
if not session_token:
|
|
||||||
return False
|
|
||||||
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return token_data.get("device_addition", False)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_device_addition_user_id(request: Request) -> UUID | None:
|
|
||||||
"""Get user ID from device addition session."""
|
|
||||||
session_token = request.cookies.get(COOKIE_NAME)
|
|
||||||
if not session_token:
|
|
||||||
return None
|
|
||||||
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data or not token_data.get("device_addition"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
return token_data.get("user_id")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_device_addition_user_id_from_cookie(cookie_header: str) -> UUID | None:
|
|
||||||
"""Parse cookie header and return user ID if valid device addition session exists."""
|
|
||||||
if not cookie_header:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Parse cookies from header (simple implementation)
|
|
||||||
cookies = {}
|
|
||||||
for cookie in cookie_header.split(";"):
|
|
||||||
cookie = cookie.strip()
|
|
||||||
if "=" in cookie:
|
|
||||||
name, value = cookie.split("=", 1)
|
|
||||||
cookies[name] = value
|
|
||||||
|
|
||||||
session_token = cookies.get(COOKIE_NAME)
|
|
||||||
if not session_token:
|
|
||||||
return None
|
|
||||||
|
|
||||||
token_data = await validate_session_token(session_token)
|
|
||||||
if not token_data or not token_data.get("device_addition"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
return token_data["user_id"]
|
|
||||||
|
@ -13,17 +13,18 @@ from datetime import datetime
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import uuid7
|
import uuid7
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from fastapi import Cookie, FastAPI, Query, Request, WebSocket, WebSocketDisconnect
|
||||||
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
from webauthn.helpers.exceptions import InvalidAuthenticationResponse
|
||||||
|
|
||||||
from ..db import sql
|
from passkey.fastapi import session
|
||||||
from ..db.sql import User
|
|
||||||
|
from ..db import User, sql
|
||||||
from ..sansio import Passkey
|
from ..sansio import Passkey
|
||||||
from ..util.session import create_session_token, get_client_info_from_websocket
|
from ..util.tokens import create_token, reset_key, session_key
|
||||||
from .session import get_user_from_cookie_string
|
from .session import create_session, infodict
|
||||||
|
|
||||||
# Create a FastAPI subapp for WebSocket endpoints
|
# Create a FastAPI subapp for WebSocket endpoints
|
||||||
ws_app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Initialize the passkey instance
|
# Initialize the passkey instance
|
||||||
passkey = Passkey(
|
passkey = Passkey(
|
||||||
@ -34,51 +35,55 @@ passkey = Passkey(
|
|||||||
|
|
||||||
async def register_chat(
|
async def register_chat(
|
||||||
ws: WebSocket,
|
ws: WebSocket,
|
||||||
user_id: UUID,
|
user_uuid: UUID,
|
||||||
user_name: str,
|
user_name: str,
|
||||||
credential_ids: list[bytes] | None = None,
|
credential_ids: list[bytes] | None = None,
|
||||||
origin: str | None = None,
|
origin: str | None = None,
|
||||||
):
|
):
|
||||||
"""Generate registration options and send them to the client."""
|
"""Generate registration options and send them to the client."""
|
||||||
options, challenge = passkey.reg_generate_options(
|
options, challenge = passkey.reg_generate_options(
|
||||||
user_id=user_id,
|
user_id=user_uuid,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
credential_ids=credential_ids,
|
credential_ids=credential_ids,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
await ws.send_json(options)
|
await ws.send_json(options)
|
||||||
response = await ws.receive_json()
|
response = await ws.receive_json()
|
||||||
return passkey.reg_verify(response, challenge, user_id, origin=origin)
|
return passkey.reg_verify(response, challenge, user_uuid, origin=origin)
|
||||||
|
|
||||||
|
|
||||||
@ws_app.websocket("/register_new")
|
@app.websocket("/register")
|
||||||
async def websocket_register_new(ws: WebSocket, user_name: str):
|
async def websocket_register_new(
|
||||||
|
request: Request, ws: WebSocket, user_name: str = Query(""), auth=Cookie(None)
|
||||||
|
):
|
||||||
"""Register a new user and with a new passkey credential."""
|
"""Register a new user and with a new passkey credential."""
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
origin = ws.headers.get("origin")
|
origin = ws.headers.get("origin")
|
||||||
try:
|
try:
|
||||||
user_id = uuid7.create()
|
user_uuid = uuid7.create()
|
||||||
|
|
||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(ws, user_id, user_name, origin=origin)
|
credential = await register_chat(ws, user_uuid, user_name, origin=origin)
|
||||||
|
|
||||||
# Store the user and credential in the database
|
# Store the user and credential in the database
|
||||||
await sql.create_user_and_credential(
|
await sql.create_user_and_credential(
|
||||||
User(user_id, user_name, created_at=datetime.now()),
|
User(user_uuid, user_name, created_at=datetime.now()),
|
||||||
credential,
|
credential,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a session token for the new user
|
# Create a session token for the new user
|
||||||
client_info = get_client_info_from_websocket(ws)
|
token = create_token()
|
||||||
session_token = await create_session_token(
|
await sql.create_session(
|
||||||
user_id, credential.credential_id, client_info
|
user_uuid=user_uuid,
|
||||||
|
key=session_key(token),
|
||||||
|
expires=datetime.now() + session.EXPIRES,
|
||||||
|
info=infodict(request, "authenticated"),
|
||||||
|
credential_uuid=credential.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"user_id": str(user_id),
|
"user_uuid": str(user_uuid),
|
||||||
"session_token": session_token,
|
"session_token": token,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -90,28 +95,31 @@ async def websocket_register_new(ws: WebSocket, user_name: str):
|
|||||||
await ws.send_json({"error": "Internal Server Error"})
|
await ws.send_json({"error": "Internal Server Error"})
|
||||||
|
|
||||||
|
|
||||||
@ws_app.websocket("/add_credential")
|
@app.websocket("/add_credential")
|
||||||
async def websocket_register_add(ws: WebSocket):
|
async def websocket_register_add(ws: WebSocket, token: str | None = None):
|
||||||
"""Register a new credential for an existing user."""
|
"""Register a new credential for an existing user."""
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
origin = ws.headers.get("origin")
|
origin = ws.headers.get("origin")
|
||||||
try:
|
try:
|
||||||
# Authenticate user via cookie
|
if not token:
|
||||||
cookie_header = ws.headers.get("cookie", "")
|
await ws.send_json({"error": "Token is required"})
|
||||||
user_id = await get_user_from_cookie_string(cookie_header)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
await ws.send_json({"error": "Authentication required"})
|
|
||||||
return
|
return
|
||||||
|
# If a token is provided, use it to look up the session
|
||||||
|
key = reset_key(token)
|
||||||
|
s = await sql.get_session(key)
|
||||||
|
if not s:
|
||||||
|
await ws.send_json({"error": "Invalid or expired token"})
|
||||||
|
return
|
||||||
|
user_uuid = s.user_uuid
|
||||||
|
|
||||||
# Get user information to get the user_name
|
# Get user information to get the user_name
|
||||||
user = await sql.get_user_by_id(user_id)
|
user = await sql.get_user_by_uuid(user_uuid)
|
||||||
user_name = user.user_name
|
user_name = user.user_name
|
||||||
challenge_ids = await sql.get_user_credentials(user_id)
|
challenge_ids = await sql.get_user_credentials(user_uuid)
|
||||||
|
|
||||||
# WebAuthn registration
|
# WebAuthn registration
|
||||||
credential = await register_chat(
|
credential = await register_chat(
|
||||||
ws, user_id, user_name, challenge_ids, origin=origin
|
ws, user_uuid, user_name, challenge_ids, origin=origin
|
||||||
)
|
)
|
||||||
# Store the new credential in the database
|
# Store the new credential in the database
|
||||||
await sql.create_credential_for_user(credential)
|
await sql.create_credential_for_user(credential)
|
||||||
@ -119,7 +127,7 @@ async def websocket_register_add(ws: WebSocket):
|
|||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"user_id": str(user_id),
|
"user_uuid": str(user_uuid),
|
||||||
"credential_id": credential.credential_id.hex(),
|
"credential_id": credential.credential_id.hex(),
|
||||||
"message": "New credential added successfully",
|
"message": "New credential added successfully",
|
||||||
}
|
}
|
||||||
@ -133,103 +141,8 @@ async def websocket_register_add(ws: WebSocket):
|
|||||||
await ws.send_json({"error": "Internal Server Error"})
|
await ws.send_json({"error": "Internal Server Error"})
|
||||||
|
|
||||||
|
|
||||||
@ws_app.websocket("/add_device_credential")
|
@app.websocket("/authenticate")
|
||||||
async def websocket_add_device_credential(ws: WebSocket, token: str):
|
async def websocket_authenticate(request: Request, ws: WebSocket):
|
||||||
"""Add a new credential for an existing user via device addition token."""
|
|
||||||
await ws.accept()
|
|
||||||
origin = ws.headers.get("origin")
|
|
||||||
try:
|
|
||||||
reset_token = await sql.get_session(token)
|
|
||||||
if not reset_token:
|
|
||||||
await ws.send_json({"error": "Invalid or expired device addition token"})
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get user information
|
|
||||||
user = await sql.get_user_by_id(reset_token.user_id)
|
|
||||||
|
|
||||||
# WebAuthn registration
|
|
||||||
# Fetch challenge IDs for the user
|
|
||||||
challenge_ids = await sql.get_user_credentials(reset_token.user_id)
|
|
||||||
|
|
||||||
credential = await register_chat(
|
|
||||||
ws, reset_token.user_id, user.user_name, challenge_ids, origin=origin
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store the new credential in the database
|
|
||||||
await sql.create_credential_for_user(credential)
|
|
||||||
|
|
||||||
# Delete the device addition token (it's now used)
|
|
||||||
await sql.delete_reset_token(token)
|
|
||||||
|
|
||||||
await ws.send_json(
|
|
||||||
{
|
|
||||||
"status": "success",
|
|
||||||
"user_id": str(reset_token.user_id),
|
|
||||||
"credential_id": credential.credential_id.hex(),
|
|
||||||
"message": "New credential added successfully via device addition token",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
await ws.send_json({"error": str(e)})
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
logging.exception("Internal Server Error")
|
|
||||||
await ws.send_json({"error": "Internal Server Error"})
|
|
||||||
|
|
||||||
|
|
||||||
@ws_app.websocket("/add_device_credential_session")
|
|
||||||
async def websocket_add_device_credential_session(ws: WebSocket):
|
|
||||||
"""Add a new credential for an existing user via device addition session."""
|
|
||||||
await ws.accept()
|
|
||||||
origin = ws.headers.get("origin")
|
|
||||||
try:
|
|
||||||
# Get device addition user ID from session cookie
|
|
||||||
cookie_header = ws.headers.get("cookie", "")
|
|
||||||
from .session import get_device_addition_user_id_from_cookie
|
|
||||||
|
|
||||||
user_id = await get_device_addition_user_id_from_cookie(cookie_header)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
await ws.send_json({"error": "No valid device addition session found"})
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get user information
|
|
||||||
user = await sql.get_user_by_id(user_id)
|
|
||||||
if not user:
|
|
||||||
await ws.send_json({"error": "User not found"})
|
|
||||||
return
|
|
||||||
|
|
||||||
# WebAuthn registration
|
|
||||||
# Fetch challenge IDs for the user
|
|
||||||
challenge_ids = await sql.get_user_credentials(user_id)
|
|
||||||
|
|
||||||
credential = await register_chat(
|
|
||||||
ws, user_id, user.user_name, challenge_ids, origin=origin
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store the new credential in the database
|
|
||||||
await sql.create_credential_for_user(credential)
|
|
||||||
|
|
||||||
await ws.send_json(
|
|
||||||
{
|
|
||||||
"status": "success",
|
|
||||||
"user_id": str(user_id),
|
|
||||||
"credential_id": credential.credential_id.hex(),
|
|
||||||
"message": "New credential added successfully via device addition session",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
await ws.send_json({"error": str(e)})
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
logging.exception("Internal Server Error")
|
|
||||||
await ws.send_json({"error": "Internal Server Error"})
|
|
||||||
|
|
||||||
|
|
||||||
@ws_app.websocket("/authenticate")
|
|
||||||
async def websocket_authenticate(ws: WebSocket):
|
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
origin = ws.headers.get("origin")
|
origin = ws.headers.get("origin")
|
||||||
try:
|
try:
|
||||||
@ -242,19 +155,21 @@ async def websocket_authenticate(ws: WebSocket):
|
|||||||
# Verify the credential matches the stored data
|
# Verify the credential matches the stored data
|
||||||
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
passkey.auth_verify(credential, challenge, stored_cred, origin=origin)
|
||||||
# Update both credential and user's last_seen timestamp
|
# Update both credential and user's last_seen timestamp
|
||||||
await sql.login_user(stored_cred.user_id, stored_cred)
|
await sql.login_user(stored_cred.user_uuid, stored_cred)
|
||||||
|
|
||||||
# Create a session token for the authenticated user
|
# Create a session token for the authenticated user
|
||||||
client_info = get_client_info_from_websocket(ws)
|
assert stored_cred.uuid is not None
|
||||||
session_token = await create_session_token(
|
token = await create_session(
|
||||||
stored_cred.user_id, stored_cred.credential_id, client_info
|
user_uuid=stored_cred.user_uuid,
|
||||||
|
info=infodict(request, "auth"),
|
||||||
|
credential_uuid=stored_cred.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
await ws.send_json(
|
await ws.send_json(
|
||||||
{
|
{
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"user_id": str(stored_cred.user_id),
|
"user_uuid": str(stored_cred.user_uuid),
|
||||||
"session_token": session_token,
|
"session_token": token,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except (ValueError, InvalidAuthenticationResponse) as e:
|
except (ValueError, InvalidAuthenticationResponse) as e:
|
||||||
|
@ -8,7 +8,6 @@ This module provides a unified interface for WebAuthn operations including:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -36,21 +35,7 @@ from webauthn.helpers.structs import (
|
|||||||
UserVerificationRequirement,
|
UserVerificationRequirement,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .db import Credential
|
||||||
@dataclass
|
|
||||||
class StoredCredential:
|
|
||||||
"""Credential data stored in the database."""
|
|
||||||
|
|
||||||
# Fields set only at registration time
|
|
||||||
credential_id: bytes
|
|
||||||
user_id: UUID
|
|
||||||
aaguid: UUID
|
|
||||||
public_key: bytes
|
|
||||||
# Mutable fields that may be updated during authentication
|
|
||||||
sign_count: int
|
|
||||||
created_at: datetime
|
|
||||||
last_used: datetime | None = None
|
|
||||||
last_verified: datetime | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Passkey:
|
class Passkey:
|
||||||
@ -129,7 +114,7 @@ class Passkey:
|
|||||||
expected_challenge: bytes,
|
expected_challenge: bytes,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
origin: str | None = None,
|
origin: str | None = None,
|
||||||
) -> StoredCredential:
|
) -> Credential:
|
||||||
"""
|
"""
|
||||||
Verify registration response.
|
Verify registration response.
|
||||||
|
|
||||||
@ -147,7 +132,7 @@ class Passkey:
|
|||||||
expected_origin=origin or self.origin,
|
expected_origin=origin or self.origin,
|
||||||
expected_rp_id=self.rp_id,
|
expected_rp_id=self.rp_id,
|
||||||
)
|
)
|
||||||
return StoredCredential(
|
return Credential(
|
||||||
credential_id=credential.raw_id,
|
credential_id=credential.raw_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
aaguid=UUID(registration.aaguid),
|
aaguid=UUID(registration.aaguid),
|
||||||
@ -195,7 +180,7 @@ class Passkey:
|
|||||||
self,
|
self,
|
||||||
credential: AuthenticationCredential,
|
credential: AuthenticationCredential,
|
||||||
expected_challenge: bytes,
|
expected_challenge: bytes,
|
||||||
stored_cred: StoredCredential,
|
stored_cred: Credential,
|
||||||
origin: str | None = None,
|
origin: str | None = None,
|
||||||
) -> VerifiedAuthentication:
|
) -> VerifiedAuthentication:
|
||||||
"""
|
"""
|
||||||
|
@ -2,8 +2,18 @@ import secrets
|
|||||||
|
|
||||||
from .wordlist import words
|
from .wordlist import words
|
||||||
|
|
||||||
|
N_WORDS = 5
|
||||||
|
|
||||||
def generate(n=4, sep="."):
|
wset = set(words)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(n=N_WORDS, sep="."):
|
||||||
"""Generate a password of random words without repeating any word."""
|
"""Generate a password of random words without repeating any word."""
|
||||||
wl = list(words)
|
wl = words.copy()
|
||||||
return sep.join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
|
return sep.join(wl.pop(secrets.randbelow(len(wl))) for i in range(n))
|
||||||
|
|
||||||
|
|
||||||
|
def is_well_formed(passphrase: str, n=N_WORDS, sep=".") -> bool:
|
||||||
|
"""Check if the passphrase is well-formed according to the regex pattern."""
|
||||||
|
p = passphrase.split(sep)
|
||||||
|
return len(p) == n and all(w in wset for w in passphrase.split("."))
|
||||||
|
@ -1,88 +0,0 @@
|
|||||||
"""
|
|
||||||
Database session management for WebAuthn authentication.
|
|
||||||
|
|
||||||
This module provides session management using database tokens instead of JWT tokens.
|
|
||||||
Session tokens are stored in the database and validated on each request.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
|
|
||||||
from ..db import sql
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_info(request: Request) -> dict:
|
|
||||||
"""Extract client information from FastAPI request and return as dict."""
|
|
||||||
# Get client IP (handle X-Forwarded-For for proxies)
|
|
||||||
# Get user agent
|
|
||||||
return {
|
|
||||||
"client_ip": request.client.host if request.client else "",
|
|
||||||
"user_agent": request.headers.get("user-agent", "")[:500],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_info_from_websocket(ws) -> dict:
|
|
||||||
"""Extract client information from WebSocket connection and return as dict."""
|
|
||||||
# Get client IP from WebSocket
|
|
||||||
client_ip = None
|
|
||||||
if hasattr(ws, "client") and ws.client:
|
|
||||||
client_ip = ws.client.host
|
|
||||||
|
|
||||||
# Check for forwarded headers
|
|
||||||
if hasattr(ws, "headers"):
|
|
||||||
forwarded_for = ws.headers.get("x-forwarded-for")
|
|
||||||
if forwarded_for:
|
|
||||||
client_ip = forwarded_for.split(",")[0].strip()
|
|
||||||
|
|
||||||
# Get user agent from WebSocket headers
|
|
||||||
user_agent = None
|
|
||||||
if hasattr(ws, "headers"):
|
|
||||||
user_agent = ws.headers.get("user-agent")
|
|
||||||
# Truncate user agent if too long
|
|
||||||
if user_agent and len(user_agent) > 500: # Keep some margin
|
|
||||||
user_agent = user_agent[:500]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"client_ip": client_ip,
|
|
||||||
"user_agent": user_agent,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"connection_type": "websocket",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def create_session_token(
|
|
||||||
user_id: UUID, credential_id: bytes, info: dict | None = None
|
|
||||||
) -> str:
|
|
||||||
"""Create a session token for a user."""
|
|
||||||
return await sql.create_session_by_credential_id(user_id, credential_id, None, info)
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_session_token(token: str) -> Optional[dict]:
|
|
||||||
"""Validate a session token."""
|
|
||||||
session_data = await sql.get_session(token)
|
|
||||||
if not session_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"user_id": session_data["user_id"],
|
|
||||||
"credential_id": session_data["credential_id"],
|
|
||||||
"created_at": session_data["created_at"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def refresh_session_token(token: str) -> Optional[str]:
|
|
||||||
"""Refresh a session token."""
|
|
||||||
return await sql.refresh_session(token)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_session_token(token: str) -> None:
|
|
||||||
"""Delete a session token."""
|
|
||||||
await sql.delete_session(token)
|
|
||||||
|
|
||||||
|
|
||||||
async def logout_session(token: str) -> None:
|
|
||||||
"""Log out a user by deleting their session token."""
|
|
||||||
await sql.delete_session(token)
|
|
17
passkey/util/tokens.py
Normal file
17
passkey/util/tokens.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
|
||||||
|
def create_token() -> str:
|
||||||
|
return secrets.token_urlsafe(12) # 16 characters Base64
|
||||||
|
|
||||||
|
|
||||||
|
def session_key(token: str) -> bytes:
|
||||||
|
if len(token) != 16:
|
||||||
|
raise ValueError("Session token must be exactly 16 characters long")
|
||||||
|
return b"sess" + base64.urlsafe_b64decode(token)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_key(passphrase: str) -> bytes:
|
||||||
|
return b"rset" + hashlib.sha512(passphrase.encode()).digest()[:12]
|
Loading…
x
Reference in New Issue
Block a user