feat(02-02): auth API endpoints + security hardening + Python 3.9 compat
- backend/api/auth.py: register, login (TOTP+backup), refresh, logout, me, change-password; per-account Redis rate limit; HIBP check - backend/main.py: Origin validation middleware, CSP headers middleware, CORS locked to settings.cors_origins, Redis lifespan (app.state.redis), admin bootstrap, auth router included, slowapi SlowAPIMiddleware - backend/services/email.py: already created in Plan 01 (verified exists) - Python 3.9 compat: fixed match statement in ai/__init__.py, str|None union syntax in openai_provider.py, api/documents.py, api/topics.py, api/settings.py, services/classifier.py All 17 tests in test_auth_api.py pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+23
-24
@@ -10,27 +10,26 @@ def get_provider(settings: dict) -> AIProvider:
|
|||||||
providers = settings.get("providers", {})
|
providers = settings.get("providers", {})
|
||||||
cfg = providers.get(active, {})
|
cfg = providers.get(active, {})
|
||||||
|
|
||||||
match active:
|
if active == "anthropic":
|
||||||
case "anthropic":
|
return AnthropicProvider(
|
||||||
return AnthropicProvider(
|
api_key=cfg.get("api_key", ""),
|
||||||
api_key=cfg.get("api_key", ""),
|
model=cfg.get("model", "claude-sonnet-4-6"),
|
||||||
model=cfg.get("model", "claude-sonnet-4-6"),
|
)
|
||||||
)
|
elif active == "openai":
|
||||||
case "openai":
|
return OpenAIProvider(
|
||||||
return OpenAIProvider(
|
api_key=cfg.get("api_key", ""),
|
||||||
api_key=cfg.get("api_key", ""),
|
model=cfg.get("model", "gpt-4o"),
|
||||||
model=cfg.get("model", "gpt-4o"),
|
base_url=cfg.get("base_url") or None,
|
||||||
base_url=cfg.get("base_url") or None,
|
)
|
||||||
)
|
elif active == "ollama":
|
||||||
case "ollama":
|
return OllamaProvider(
|
||||||
return OllamaProvider(
|
base_url=cfg.get("base_url", "http://host.docker.internal:11434"),
|
||||||
base_url=cfg.get("base_url", "http://host.docker.internal:11434"),
|
model=cfg.get("model", "llama3.2"),
|
||||||
model=cfg.get("model", "llama3.2"),
|
)
|
||||||
)
|
elif active == "lmstudio":
|
||||||
case "lmstudio":
|
return LMStudioProvider(
|
||||||
return LMStudioProvider(
|
base_url=cfg.get("base_url", "http://host.docker.internal:1234"),
|
||||||
base_url=cfg.get("base_url", "http://host.docker.internal:1234"),
|
model=cfg.get("model", "gemma-4-e4b-it"),
|
||||||
model=cfg.get("model", "gemma-4-e4b-it"),
|
)
|
||||||
)
|
else:
|
||||||
case _:
|
raise ValueError(f"Unknown AI provider: {active}")
|
||||||
raise ValueError(f"Unknown AI provider: {active}")
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ MAX_AI_CHARS = 8_000
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIProvider(AIProvider):
|
class OpenAIProvider(AIProvider):
|
||||||
def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str | None = None):
|
def __init__(self, api_key: str, model: str = "gpt-4o", base_url=None): # type: ignore[type-arg]
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._model = model
|
self._model = model
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
|
|||||||
@@ -0,0 +1,430 @@
|
|||||||
|
"""
|
||||||
|
Auth API endpoints for DocuVault.
|
||||||
|
|
||||||
|
Implements:
|
||||||
|
POST /api/auth/register — new user registration with HIBP check
|
||||||
|
POST /api/auth/login — login with optional TOTP/backup-code second factor
|
||||||
|
POST /api/auth/refresh — rotate refresh token (httpOnly cookie in/out)
|
||||||
|
POST /api/auth/logout — revoke current refresh token, clear cookie
|
||||||
|
GET /api/auth/me — return current user profile
|
||||||
|
POST /api/auth/change-password — update password (requires current password)
|
||||||
|
|
||||||
|
Security invariants:
|
||||||
|
- Per-account rate limit: 10 login attempts per email per 15 minutes (SEC-02)
|
||||||
|
- HTTP 429 returned before any DB lookup when the counter is exceeded
|
||||||
|
- httpOnly Secure SameSite=Strict refresh cookie (CLAUDE.md constraint)
|
||||||
|
- HIBP breach check on register and change-password (SEC-03)
|
||||||
|
- TOTP takes precedence over backup_code when both fields are provided
|
||||||
|
- password_must_change=True: returns requires_password_change without tokens
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
from db.models import BackupCode, Quota, RefreshToken, User
|
||||||
|
from deps.auth import get_current_user
|
||||||
|
from deps.db import get_db
|
||||||
|
from services import auth as auth_service
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh)
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# ── Password strength validation ─────────────────────────────────────────────
|
||||||
|
_PASSWORD_DETAIL = (
|
||||||
|
"Password must be at least 12 characters and include uppercase, "
|
||||||
|
"lowercase, a number, and a special character."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_password_strength(password: str) -> bool:
|
||||||
|
"""Return True if password passes all strength rules (AUTH-01).
|
||||||
|
|
||||||
|
Rules: min 12 chars, has uppercase, has lowercase, has digit, has special char.
|
||||||
|
"""
|
||||||
|
if len(password) < 12:
|
||||||
|
return False
|
||||||
|
if not re.search(r"[A-Z]", password):
|
||||||
|
return False
|
||||||
|
if not re.search(r"[a-z]", password):
|
||||||
|
return False
|
||||||
|
if not re.search(r"[0-9]", password):
|
||||||
|
return False
|
||||||
|
if not re.search(r"[^A-Za-z0-9]", password):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request models ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
handle: str
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
totp_code: Optional[str] = None
|
||||||
|
backup_code: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
|
||||||
|
current_password: str
|
||||||
|
new_password: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helper: set httpOnly refresh cookie ──────────────────────────────────────
|
||||||
|
|
||||||
|
def _set_refresh_cookie(response: Response, raw_token: str) -> None:
|
||||||
|
"""Set the httpOnly Secure SameSite=Strict refresh cookie (CLAUDE.md constraint)."""
|
||||||
|
response.set_cookie(
|
||||||
|
key="refresh_token",
|
||||||
|
value=raw_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=True,
|
||||||
|
samesite="strict",
|
||||||
|
path="/api/auth/refresh",
|
||||||
|
max_age=settings.refresh_token_expire_days * 86400,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _user_dict(user: User) -> dict:
|
||||||
|
"""Return serialisable user metadata (no password_hash, no credentials_enc)."""
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"handle": user.handle,
|
||||||
|
"email": user.email,
|
||||||
|
"role": user.role,
|
||||||
|
"totp_enabled": user.totp_enabled,
|
||||||
|
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /api/auth/register ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/register", status_code=status.HTTP_201_CREATED)
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def register(
|
||||||
|
request: Request,
|
||||||
|
body: RegisterRequest,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Register a new user account.
|
||||||
|
|
||||||
|
- Validates password strength (min 12 chars, upper, lower, digit, special)
|
||||||
|
- Checks HIBP k-anonymity API for breached passwords
|
||||||
|
- Hashes password with Argon2
|
||||||
|
- Inserts User + Quota rows in a single transaction
|
||||||
|
"""
|
||||||
|
# Password strength check
|
||||||
|
if not _validate_password_strength(body.password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail=_PASSWORD_DETAIL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# HIBP breach check
|
||||||
|
if await auth_service.check_hibp(body.password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="This password has appeared in a data breach. Choose a different password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Duplicate email/handle check
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(
|
||||||
|
(User.email == str(body.email)) | (User.handle == body.handle)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if result.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Email or handle already in use",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create user and quota
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
new_user = User(
|
||||||
|
id=user_id,
|
||||||
|
handle=body.handle,
|
||||||
|
email=str(body.email),
|
||||||
|
password_hash=auth_service.hash_password(body.password),
|
||||||
|
role="user",
|
||||||
|
is_active=True,
|
||||||
|
password_must_change=False,
|
||||||
|
)
|
||||||
|
quota = Quota(
|
||||||
|
user_id=user_id,
|
||||||
|
limit_bytes=104857600, # 100 MB default (STORE-01)
|
||||||
|
used_bytes=0,
|
||||||
|
)
|
||||||
|
session.add(new_user)
|
||||||
|
session.add(quota)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(new_user)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(new_user.id),
|
||||||
|
"handle": new_user.handle,
|
||||||
|
"email": new_user.email,
|
||||||
|
"role": new_user.role,
|
||||||
|
"totp_enabled": new_user.totp_enabled,
|
||||||
|
"created_at": new_user.created_at.isoformat() if new_user.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /api/auth/login ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def login(
|
||||||
|
request: Request,
|
||||||
|
body: LoginRequest,
|
||||||
|
response: Response,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Authenticate a user and issue tokens.
|
||||||
|
|
||||||
|
Per-account rate limiting (SEC-02): checks Redis counter keyed by email
|
||||||
|
BEFORE any DB lookup to prevent enumeration timing attacks.
|
||||||
|
|
||||||
|
Three login flows:
|
||||||
|
1. No TOTP enabled: password → tokens
|
||||||
|
2. TOTP enabled, no code provided: requires_totp = True (challenge)
|
||||||
|
3. TOTP enabled, totp_code provided: verify TOTP → tokens
|
||||||
|
4. TOTP enabled, backup_code provided (no totp_code): verify backup → tokens
|
||||||
|
"""
|
||||||
|
# Per-account rate limiting (SEC-02)
|
||||||
|
redis_client = request.app.state.redis
|
||||||
|
rate_key = f"login_attempts:{body.email}"
|
||||||
|
count = await redis_client.incr(rate_key)
|
||||||
|
if count == 1:
|
||||||
|
# Set TTL only on first increment (15-minute window)
|
||||||
|
await redis_client.expire(rate_key, 900)
|
||||||
|
if count > 10:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Too many login attempts. Try again in 15 minutes.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look up user by email
|
||||||
|
result = await session.execute(select(User).where(User.email == str(body.email)))
|
||||||
|
user: Optional[User] = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# Verify password (anti-enumeration: same error regardless of whether user exists)
|
||||||
|
if user is None or not auth_service.verify_password(body.password, user.password_hash):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect email or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Active check
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Account deactivated",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Password must change: return challenge without issuing tokens (T-02-16)
|
||||||
|
if user.password_must_change:
|
||||||
|
return {"requires_password_change": True, "user_id": str(user.id)}
|
||||||
|
|
||||||
|
# TOTP second-factor dispatch
|
||||||
|
if user.totp_enabled:
|
||||||
|
if body.totp_code is None and body.backup_code is None:
|
||||||
|
# Challenge: prompt for second factor
|
||||||
|
return {"requires_totp": True}
|
||||||
|
|
||||||
|
if body.totp_code is not None:
|
||||||
|
# TOTP path takes precedence (even if backup_code also provided)
|
||||||
|
ok = await auth_service.verify_totp(session, user.id, body.totp_code, redis_client)
|
||||||
|
if not ok:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect code",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Backup code path (body.backup_code is not None and body.totp_code is None)
|
||||||
|
ok = await auth_service.verify_backup_code(session, user.id, body.backup_code)
|
||||||
|
if not ok:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or already used code",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Issue tokens
|
||||||
|
access_token = auth_service.create_access_token(str(user.id), user.role)
|
||||||
|
raw_refresh = await auth_service.create_refresh_token(session, user.id)
|
||||||
|
_set_refresh_cookie(response, raw_refresh)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"user": {
|
||||||
|
"id": str(user.id),
|
||||||
|
"handle": user.handle,
|
||||||
|
"email": user.email,
|
||||||
|
"role": user.role,
|
||||||
|
"totp_enabled": user.totp_enabled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /api/auth/refresh ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/refresh")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def refresh_token(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Rotate the refresh token.
|
||||||
|
|
||||||
|
Reads the refresh_token httpOnly cookie; on success issues a new access
|
||||||
|
token and rotates the refresh cookie.
|
||||||
|
On token reuse (revoked token presented), revokes entire family and raises 401.
|
||||||
|
"""
|
||||||
|
raw_token = request.cookies.get("refresh_token")
|
||||||
|
if not raw_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="No refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_raw, user_id_str = await auth_service.rotate_refresh_token(session, raw_token)
|
||||||
|
except ValueError as exc:
|
||||||
|
if "token_family_revoked" in str(exc):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Session revoked",
|
||||||
|
) from exc
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired refresh token",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Look up user for response body
|
||||||
|
user = await session.get(User, uuid.UUID(user_id_str))
|
||||||
|
if user is None or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found or deactivated",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set new refresh cookie
|
||||||
|
_set_refresh_cookie(response, new_raw)
|
||||||
|
|
||||||
|
access_token = auth_service.create_access_token(user_id_str, user.role)
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"user": {
|
||||||
|
"id": str(user.id),
|
||||||
|
"handle": user.handle,
|
||||||
|
"email": user.email,
|
||||||
|
"role": user.role,
|
||||||
|
"totp_enabled": user.totp_enabled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /api/auth/logout ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
async def logout(request: Request, response: Response, session: AsyncSession = Depends(get_db)):
|
||||||
|
"""Revoke current refresh token and clear the cookie."""
|
||||||
|
import hashlib as _hashlib
|
||||||
|
|
||||||
|
raw_token = request.cookies.get("refresh_token")
|
||||||
|
if raw_token:
|
||||||
|
token_hash = _hashlib.sha256(raw_token.encode()).hexdigest()
|
||||||
|
result = await session.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
row: Optional[RefreshToken] = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.revoked = True
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
response.delete_cookie("refresh_token", path="/api/auth/refresh")
|
||||||
|
return {"message": "Logged out"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /api/auth/logout-all ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/logout-all")
|
||||||
|
async def logout_all(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Sign out of all devices: revoke all refresh tokens for current user."""
|
||||||
|
count = await auth_service.revoke_all_refresh_tokens(session, current_user.id)
|
||||||
|
response.delete_cookie("refresh_token", path="/api/auth/refresh")
|
||||||
|
return {"message": f"Signed out of {count} session(s)"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /api/auth/me ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/me")
|
||||||
|
async def get_me(current_user: User = Depends(get_current_user)):
|
||||||
|
"""Return the current user's profile (requires valid Bearer token)."""
|
||||||
|
return _user_dict(current_user)
|
||||||
|
|
||||||
|
|
||||||
|
# ── POST /api/auth/change-password ───────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/change-password")
|
||||||
|
async def change_password(
|
||||||
|
body: ChangePasswordRequest,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Update the current user's password.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. current_password matches stored hash
|
||||||
|
2. new_password has not appeared in HIBP (SEC-03)
|
||||||
|
3. new_password meets strength requirements (AUTH-01)
|
||||||
|
"""
|
||||||
|
# Verify current password
|
||||||
|
if not auth_service.verify_password(body.current_password, current_user.password_hash):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Current password is incorrect",
|
||||||
|
)
|
||||||
|
|
||||||
|
# HIBP breach check on new password (SEC-03)
|
||||||
|
if await auth_service.check_hibp(body.new_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="This password has appeared in a data breach. Choose a different password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Password strength check
|
||||||
|
if not _validate_password_strength(body.new_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail=_PASSWORD_DETAIL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update password
|
||||||
|
user = await session.get(User, current_user.id)
|
||||||
|
user.password_hash = auth_service.hash_password(body.new_password)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return {"message": "Password updated"}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -68,7 +69,7 @@ async def upload_document(
|
|||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def list_documents(
|
async def list_documents(
|
||||||
topic: str | None = Query(None),
|
topic: Optional[str] = Query(None),
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
per_page: int = Query(20, ge=1, le=100),
|
per_page: int = Query(20, ge=1, le=100),
|
||||||
session: AsyncSession = Depends(get_db),
|
session: AsyncSession = Depends(get_db),
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from services import storage
|
from services import storage
|
||||||
@@ -9,9 +11,9 @@ router = APIRouter(prefix="/api/settings", tags=["settings"])
|
|||||||
|
|
||||||
|
|
||||||
class SettingsPatch(BaseModel):
|
class SettingsPatch(BaseModel):
|
||||||
system_prompt: str | None = None
|
system_prompt: Optional[str] = None
|
||||||
active_provider: str | None = None
|
active_provider: Optional[str] = None
|
||||||
providers: dict | None = None
|
providers: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class TestProviderRequest(BaseModel):
|
class TestProviderRequest(BaseModel):
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -15,9 +17,9 @@ class TopicCreate(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TopicUpdate(BaseModel):
|
class TopicUpdate(BaseModel):
|
||||||
name: str | None = None
|
name: Optional[str] = None
|
||||||
description: str | None = None
|
description: Optional[str] = None
|
||||||
color: str | None = None
|
color: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class SuggestRequest(BaseModel):
|
class SuggestRequest(BaseModel):
|
||||||
|
|||||||
+97
-6
@@ -1,11 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
import aioredis
|
||||||
|
from fastapi import FastAPI, Request, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
from slowapi import _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.middleware import SlowAPIMiddleware
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import Response as StarletteResponse
|
||||||
|
|
||||||
|
from api.auth import limiter as auth_limiter
|
||||||
from api.documents import router as documents_router
|
from api.documents import router as documents_router
|
||||||
from api.settings import router as settings_router
|
from api.settings import router as settings_router
|
||||||
from api.topics import router as topics_router
|
from api.topics import router as topics_router
|
||||||
@@ -13,12 +21,58 @@ from config import settings
|
|||||||
from db.session import AsyncSessionLocal, engine
|
from db.session import AsyncSessionLocal, engine
|
||||||
|
|
||||||
|
|
||||||
|
# ── CSP / Security headers middleware ────────────────────────────────────────
|
||||||
|
|
||||||
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Add Content-Security-Policy, X-Frame-Options, and X-Content-Type-Options
|
||||||
|
to every response (SEC-05, T-02-14).
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["Content-Security-Policy"] = (
|
||||||
|
"default-src 'self'; "
|
||||||
|
"script-src 'self'; "
|
||||||
|
"style-src 'self' 'unsafe-inline'; "
|
||||||
|
"img-src 'self' data:; "
|
||||||
|
"frame-ancestors 'none'"
|
||||||
|
)
|
||||||
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# ── Origin validation middleware (SEC-01, T-02-11) ────────────────────────────
|
||||||
|
|
||||||
|
class OriginValidationMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Reject state-changing requests from Origins not in settings.cors_origins.
|
||||||
|
|
||||||
|
For any non-idempotent method (not GET/HEAD/OPTIONS): if the Origin header
|
||||||
|
is present and not in the allowed list, return 403.
|
||||||
|
|
||||||
|
Placed BEFORE CORSMiddleware so it runs first (Starlette applies middleware
|
||||||
|
in reverse insertion order — last added runs first).
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
if request.method not in {"GET", "HEAD", "OPTIONS"}:
|
||||||
|
origin = request.headers.get("Origin")
|
||||||
|
if origin is not None and origin not in settings.cors_origins:
|
||||||
|
return StarletteResponse(content="Forbidden", status_code=403)
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Lifespan ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""FastAPI lifespan: create MinIO bucket at startup, dispose engine at shutdown.
|
"""FastAPI lifespan: initialize MinIO, Redis, and admin bootstrap at startup.
|
||||||
|
|
||||||
D-07: bucket auto-create ensures the docuvault bucket exists on every reboot.
|
D-07: bucket auto-create ensures the docuvault bucket exists on every reboot.
|
||||||
MinIO client stored on app.state.minio for use in the /health endpoint.
|
MinIO client stored on app.state.minio for use in the /health endpoint.
|
||||||
|
Redis stored on app.state.redis for per-account rate limiting (SEC-02) and
|
||||||
|
TOTP replay prevention (AUTH-08).
|
||||||
|
Admin bootstrap (D-04): idempotent, runs only if no users exist.
|
||||||
"""
|
"""
|
||||||
# MinIO bucket initialization (RESEARCH.md Pattern 4)
|
# MinIO bucket initialization (RESEARCH.md Pattern 4)
|
||||||
minio_client = Minio(
|
minio_client = Minio(
|
||||||
@@ -32,21 +86,52 @@ async def lifespan(app: FastAPI):
|
|||||||
await asyncio.to_thread(minio_client.make_bucket, settings.minio_bucket)
|
await asyncio.to_thread(minio_client.make_bucket, settings.minio_bucket)
|
||||||
app.state.minio = minio_client
|
app.state.minio = minio_client
|
||||||
|
|
||||||
|
# Redis init for per-account rate limiting + TOTP replay prevention
|
||||||
|
app.state.redis = await aioredis.from_url(settings.redis_url)
|
||||||
|
|
||||||
|
# Admin bootstrap (D-04)
|
||||||
|
from services.auth import bootstrap_admin # noqa: PLC0415
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
await bootstrap_admin(session)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: close all pooled connections
|
# Shutdown: close pooled connections and Redis
|
||||||
|
await app.state.redis.close()
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Application factory ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
app = FastAPI(title="Document Scanner API", version="1.0.0", lifespan=lifespan)
|
app = FastAPI(title="Document Scanner API", version="1.0.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
# Rate limiter state (slowapi)
|
||||||
|
app.state.limiter = auth_limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
app.add_middleware(SlowAPIMiddleware)
|
||||||
|
|
||||||
|
# ── Middleware registration order (Starlette: last added = first to run) ───────
|
||||||
|
# Desired execution order (request path): Origin → CORS → SecurityHeaders → route
|
||||||
|
# Insertion order (last registered = first to run): SecurityHeaders → CORS → Origin
|
||||||
|
# Result: register SecurityHeaders first, then CORS, then Origin last.
|
||||||
|
|
||||||
|
# 1. Security headers (CSP etc.) — runs last in the chain
|
||||||
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
|
||||||
|
# 2. CORS — updated to use settings.cors_origins (D-09); wildcard removed (T-02-15)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # Phase 1: locked down in Phase 2 after auth lands
|
allow_origins=settings.cors_origins,
|
||||||
|
allow_credentials=True, # Required for httpOnly cookie flow
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 3. Origin validation — runs first (added last), before CORS and route handlers
|
||||||
|
app.add_middleware(OriginValidationMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health(request: Request):
|
async def health(request: Request):
|
||||||
@@ -78,10 +163,16 @@ async def health(request: Request):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
checks["minio"] = f"error: {type(e).__name__}: {e}"
|
checks["minio"] = f"error: {type(e).__name__}: {e}"
|
||||||
|
|
||||||
status = "ok" if all(v == "ok" for v in checks.values()) else "degraded"
|
status_val = "ok" if all(v == "ok" for v in checks.values()) else "degraded"
|
||||||
return {"status": status, "checks": checks}
|
return {"status": status_val, "checks": checks}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Include routers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
app.include_router(documents_router)
|
app.include_router(documents_router)
|
||||||
app.include_router(topics_router)
|
app.include_router(topics_router)
|
||||||
app.include_router(settings_router)
|
app.include_router(settings_router)
|
||||||
|
|
||||||
|
# Phase 2: auth and admin routers
|
||||||
|
from api.auth import router as auth_router # noqa: E402
|
||||||
|
app.include_router(auth_router)
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ Updated in Plan 05: classify_document and suggest_topics_for_document now accept
|
|||||||
an AsyncSession as their first argument so they can be called from the Celery task
|
an AsyncSession as their first argument so they can be called from the Celery task
|
||||||
wrapper and from API route handlers that already hold a session.
|
wrapper and from API route handlers that already hold a session.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from services import storage
|
from services import storage
|
||||||
|
|||||||
Reference in New Issue
Block a user