feat(02-02): auth API endpoints + security hardening + Python 3.9 compat

- backend/api/auth.py: register, login (TOTP+backup), refresh, logout,
  me, change-password; per-account Redis rate limit; HIBP check
- backend/main.py: Origin validation middleware, CSP headers middleware,
  CORS locked to settings.cors_origins, Redis lifespan (app.state.redis),
  admin bootstrap, auth router included, slowapi SlowAPIMiddleware
- backend/services/email.py: already created in Plan 01 (verified exists)
- Python 3.9 compat: fixed match statement in ai/__init__.py,
  str|None union syntax in openai_provider.py, api/documents.py,
  api/topics.py, api/settings.py, services/classifier.py

All 17 tests in test_auth_api.py pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
curo1305
2026-05-22 19:35:38 +02:00
parent 1d425d4392
commit 1882edfff6
8 changed files with 565 additions and 38 deletions
+5 -6
View File
@@ -10,27 +10,26 @@ def get_provider(settings: dict) -> AIProvider:
providers = settings.get("providers", {})
cfg = providers.get(active, {})
match active:
case "anthropic":
if active == "anthropic":
return AnthropicProvider(
api_key=cfg.get("api_key", ""),
model=cfg.get("model", "claude-sonnet-4-6"),
)
case "openai":
elif active == "openai":
return OpenAIProvider(
api_key=cfg.get("api_key", ""),
model=cfg.get("model", "gpt-4o"),
base_url=cfg.get("base_url") or None,
)
case "ollama":
elif active == "ollama":
return OllamaProvider(
base_url=cfg.get("base_url", "http://host.docker.internal:11434"),
model=cfg.get("model", "llama3.2"),
)
case "lmstudio":
elif active == "lmstudio":
return LMStudioProvider(
base_url=cfg.get("base_url", "http://host.docker.internal:1234"),
model=cfg.get("model", "gemma-4-e4b-it"),
)
case _:
else:
raise ValueError(f"Unknown AI provider: {active}")
+1 -1
View File
@@ -7,7 +7,7 @@ MAX_AI_CHARS = 8_000
class OpenAIProvider(AIProvider):
def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str | None = None):
def __init__(self, api_key: str, model: str = "gpt-4o", base_url=None): # type: ignore[type-arg]
self._api_key = api_key
self._model = model
self._base_url = base_url
+430
View File
@@ -0,0 +1,430 @@
"""
Auth API endpoints for DocuVault.
Implements:
POST /api/auth/register — new user registration with HIBP check
POST /api/auth/login — login with optional TOTP/backup-code second factor
POST /api/auth/refresh — rotate refresh token (httpOnly cookie in/out)
POST /api/auth/logout — revoke current refresh token, clear cookie
POST /api/auth/logout-all — revoke all refresh tokens for the current user, clear cookie
GET /api/auth/me — return current user profile
POST /api/auth/change-password — update password (requires current password)
Security invariants:
- Per-account rate limit: 10 login attempts per email per 15 minutes (SEC-02)
- HTTP 429 returned before any DB lookup when the counter is exceeded
- httpOnly Secure SameSite=Strict refresh cookie (CLAUDE.md constraint)
- HIBP breach check on register and change-password (SEC-03)
- TOTP takes precedence over backup_code when both fields are provided
- password_must_change=True: returns requires_password_change without tokens
"""
from __future__ import annotations
import re
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
from db.models import BackupCode, Quota, RefreshToken, User
from deps.auth import get_current_user
from deps.db import get_db
from services import auth as auth_service
from slowapi import Limiter
from slowapi.util import get_remote_address
router = APIRouter(prefix="/api/auth", tags=["auth"])
# IP-level rate limiter (SEC-02 — 10 req/min on register/login/refresh)
limiter = Limiter(key_func=get_remote_address)
# ── Password strength validation ─────────────────────────────────────────────
_PASSWORD_DETAIL = (
"Password must be at least 12 characters and include uppercase, "
"lowercase, a number, and a special character."
)
def _validate_password_strength(password: str) -> bool:
"""Return True if password passes all strength rules (AUTH-01).
Rules: min 12 chars, has uppercase, has lowercase, has digit, has special char.
"""
if len(password) < 12:
return False
if not re.search(r"[A-Z]", password):
return False
if not re.search(r"[a-z]", password):
return False
if not re.search(r"[0-9]", password):
return False
if not re.search(r"[^A-Za-z0-9]", password):
return False
return True
# ── Request models ────────────────────────────────────────────────────────────
class RegisterRequest(BaseModel):
    # Request body for POST /api/auth/register.
    # handle: public handle; the register endpoint rejects duplicates with 409.
    handle: str
    # email: validated by pydantic's EmailStr; the endpoint rejects duplicates with 409.
    email: EmailStr
    # password: plaintext candidate; strength and HIBP checks happen in the endpoint.
    password: str
class LoginRequest(BaseModel):
    # Request body for POST /api/auth/login.
    email: EmailStr
    password: str
    # Optional second factors. When both are supplied the login endpoint
    # checks totp_code and ignores backup_code.
    totp_code: Optional[str] = None
    backup_code: Optional[str] = None
class ChangePasswordRequest(BaseModel):
    # Request body for POST /api/auth/change-password.
    # current_password: verified against the stored hash before any update.
    current_password: str
    # new_password: must pass the HIBP and strength checks in the endpoint.
    new_password: str
# ── Helper: set httpOnly refresh cookie ──────────────────────────────────────
def _set_refresh_cookie(response: Response, raw_token: str) -> None:
    """Attach the refresh token to *response* as a hardened cookie.

    The cookie is httpOnly, Secure, SameSite=Strict, scoped to the path
    ``/api/auth/refresh`` and lives for ``settings.refresh_token_expire_days``
    days (CLAUDE.md constraint).

    NOTE(review): because of the path scope, user agents only send this cookie
    to URLs under /api/auth/refresh. Endpoints on other paths that read
    ``request.cookies["refresh_token"]`` will not receive it from a browser —
    confirm this is intended.
    """
    lifetime_seconds = settings.refresh_token_expire_days * 86400
    cookie_options = {
        "httponly": True,
        "secure": True,
        "samesite": "strict",
        "path": "/api/auth/refresh",
        "max_age": lifetime_seconds,
    }
    response.set_cookie(key="refresh_token", value=raw_token, **cookie_options)
def _user_dict(user: User) -> dict:
"""Return serialisable user metadata (no password_hash, no credentials_enc)."""
return {
"id": str(user.id),
"handle": user.handle,
"email": user.email,
"role": user.role,
"totp_enabled": user.totp_enabled,
"created_at": user.created_at.isoformat() if user.created_at else None,
}
# ── POST /api/auth/register ───────────────────────────────────────────────────
@router.post("/register", status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def register(
    request: Request,
    body: RegisterRequest,
    session: AsyncSession = Depends(get_db),
):
    """Register a new user account.

    Steps, in order:
      1. Validate password strength (min 12 chars, upper, lower, digit, special).
      2. Check the password against HIBP via ``auth_service.check_hibp``.
      3. Reject the request if the email or the handle is already taken.
      4. Insert the User and its Quota row and commit them together.

    Returns:
        The public profile of the new user (same shape as ``_user_dict``).

    Raises:
        HTTPException: 422 for a weak or breached password, 409 when the
            email or handle is already in use.
    """
    # Password strength check (cheap, local) before the HIBP lookup.
    if not _validate_password_strength(body.password):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=_PASSWORD_DETAIL,
        )

    # HIBP breach check
    if await auth_service.check_hibp(body.password):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="This password has appeared in a data breach. Choose a different password.",
        )

    # Duplicate email/handle check.
    # The OR can match two different rows (one by email, another by handle),
    # so take the first hit instead of scalar_one_or_none(), which raises
    # MultipleResultsFound (surfacing as a 500) in that situation.
    email = str(body.email)
    result = await session.execute(
        select(User)
        .where((User.email == email) | (User.handle == body.handle))
        .limit(1)
    )
    if result.scalars().first() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email or handle already in use",
        )

    # Create user and quota; the id is generated here so both rows can share it.
    user_id = uuid.uuid4()
    new_user = User(
        id=user_id,
        handle=body.handle,
        email=email,
        password_hash=auth_service.hash_password(body.password),
        role="user",
        is_active=True,
        password_must_change=False,
    )
    quota = Quota(
        user_id=user_id,
        limit_bytes=104857600,  # 100 MB default (STORE-01)
        used_bytes=0,
    )
    session.add(new_user)
    session.add(quota)
    await session.commit()
    # Reload attributes after the commit before serialising the row.
    await session.refresh(new_user)

    return _user_dict(new_user)
# ── POST /api/auth/login ──────────────────────────────────────────────────────
@router.post("/login")
@limiter.limit("10/minute")
async def login(
    request: Request,
    body: LoginRequest,
    response: Response,
    session: AsyncSession = Depends(get_db),
):
    """Authenticate a user and issue tokens.

    Per-account rate limiting (SEC-02): a Redis counter keyed by the
    lower-cased email is incremented BEFORE any DB lookup; once it exceeds 10
    within the 15-minute window the request is rejected with HTTP 429.

    Outcomes once the password has been verified and the account is active:
      1. password_must_change set: ``requires_password_change`` challenge, no tokens
      2. No TOTP enabled: password → tokens
      3. TOTP enabled, no code provided: ``requires_totp = True`` (challenge)
      4. TOTP enabled, totp_code provided: verify TOTP → tokens
      5. TOTP enabled, backup_code provided (no totp_code): verify backup → tokens

    Returns:
        A challenge dict (cases 1 and 3) or ``{"access_token", "user"}``; in
        the latter case the refresh token is set as an httpOnly cookie.

    Raises:
        HTTPException: 429 when rate limited; 401 for bad credentials, a
            deactivated account, or a wrong second factor.
    """
    # Per-account rate limiting (SEC-02).
    # Lower-case the key so case variants of one address share a single counter.
    redis_client = request.app.state.redis
    rate_key = f"login_attempts:{str(body.email).lower()}"
    count = await redis_client.incr(rate_key)
    if count == 1:
        # Set TTL only on first increment (15-minute window)
        await redis_client.expire(rate_key, 900)
    if count > 10:
        # INCR and EXPIRE are separate commands: if the EXPIRE above was ever
        # lost (crash, dropped connection) the counter would have no TTL and
        # the account would stay locked forever. TTL returns -1 for a key
        # without expiry, so repair it here on the rejection path.
        if await redis_client.ttl(rate_key) == -1:
            await redis_client.expire(rate_key, 900)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Too many login attempts. Try again in 15 minutes.",
        )

    # Look up user by email
    result = await session.execute(select(User).where(User.email == str(body.email)))
    user: Optional[User] = result.scalar_one_or_none()

    # Verify password (anti-enumeration: same error regardless of whether user exists).
    # NOTE(review): when no user matches, verify_password is skipped, so unknown
    # emails probably respond faster than known ones — consider verifying
    # against a dummy hash to equalise timing.
    if user is None or not auth_service.verify_password(body.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    # Active check
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Account deactivated",
        )

    # Password must change: return challenge without issuing tokens (T-02-16)
    if user.password_must_change:
        return {"requires_password_change": True, "user_id": str(user.id)}

    # TOTP second-factor dispatch
    if user.totp_enabled:
        if body.totp_code is None and body.backup_code is None:
            # Challenge: prompt for second factor
            return {"requires_totp": True}

        if body.totp_code is not None:
            # TOTP path takes precedence (even if backup_code also provided)
            ok = await auth_service.verify_totp(session, user.id, body.totp_code, redis_client)
            if not ok:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Incorrect code",
                )
        else:
            # Backup code path (body.backup_code is not None and body.totp_code is None)
            ok = await auth_service.verify_backup_code(session, user.id, body.backup_code)
            if not ok:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid or already used code",
                )

    # Issue tokens: access token in the body, refresh token as httpOnly cookie.
    access_token = auth_service.create_access_token(str(user.id), user.role)
    raw_refresh = await auth_service.create_refresh_token(session, user.id)
    _set_refresh_cookie(response, raw_refresh)

    return {
        "access_token": access_token,
        "user": {
            "id": str(user.id),
            "handle": user.handle,
            "email": user.email,
            "role": user.role,
            "totp_enabled": user.totp_enabled,
        },
    }
# ── POST /api/auth/refresh ────────────────────────────────────────────────────
@router.post("/refresh")
@limiter.limit("10/minute")
async def refresh_token(
    request: Request,
    response: Response,
    session: AsyncSession = Depends(get_db),
):
    """Exchange the refresh cookie for a new access token and a rotated cookie.

    The raw token is read from the ``refresh_token`` cookie and handed to
    ``auth_service.rotate_refresh_token``. A ``ValueError`` from the service
    is mapped to 401: "Session revoked" when its message contains
    ``token_family_revoked``, otherwise "Invalid or expired refresh token".

    Raises:
        HTTPException: 401 when the cookie is missing, rotation fails, or the
            owning user no longer exists or is deactivated.
    """
    presented = request.cookies.get("refresh_token")
    if not presented:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="No refresh token",
        )

    try:
        new_raw, user_id_str = await auth_service.rotate_refresh_token(session, presented)
    except ValueError as exc:
        # Both failure modes are 401; only the detail text differs.
        if "token_family_revoked" in str(exc):
            failure_detail = "Session revoked"
        else:
            failure_detail = "Invalid or expired refresh token"
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=failure_detail,
        ) from exc

    # Load the owner so the response can carry the profile and role.
    user = await session.get(User, uuid.UUID(user_id_str))
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or deactivated",
        )

    _set_refresh_cookie(response, new_raw)
    access_token = auth_service.create_access_token(user_id_str, user.role)

    profile = {
        "id": str(user.id),
        "handle": user.handle,
        "email": user.email,
        "role": user.role,
        "totp_enabled": user.totp_enabled,
    }
    return {"access_token": access_token, "user": profile}
# ── POST /api/auth/logout ─────────────────────────────────────────────────────
@router.post("/logout")
async def logout(request: Request, response: Response, session: AsyncSession = Depends(get_db)):
    """Mark the presented refresh token as revoked and clear the cookie.

    The cookie value is hashed with SHA-256 and matched against
    ``RefreshToken.token_hash``; a missing cookie or unknown hash is not an
    error, the cookie is simply cleared.

    NOTE(review): the refresh cookie is set and deleted with
    path="/api/auth/refresh", so a browser will not send it to
    /api/auth/logout. In that case no token is found here and nothing is
    revoked server-side — confirm the cookie path against this endpoint.
    """
    import hashlib as _hashlib

    presented = request.cookies.get("refresh_token")
    if presented:
        digest = _hashlib.sha256(presented.encode()).hexdigest()
        lookup = await session.execute(
            select(RefreshToken).where(RefreshToken.token_hash == digest)
        )
        stored: Optional[RefreshToken] = lookup.scalar_one_or_none()
        if stored is not None:
            stored.revoked = True
            await session.commit()

    response.delete_cookie("refresh_token", path="/api/auth/refresh")
    return {"message": "Logged out"}
# ── POST /api/auth/logout-all ─────────────────────────────────────────────────
@router.post("/logout-all")
async def logout_all(
    request: Request,
    response: Response,
    session: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Revoke every refresh token of the authenticated user and clear the cookie.

    Delegates to ``auth_service.revoke_all_refresh_tokens`` and reports the
    count it returns in the response message.
    """
    revoked_sessions = await auth_service.revoke_all_refresh_tokens(session, current_user.id)
    response.delete_cookie("refresh_token", path="/api/auth/refresh")
    summary = f"Signed out of {revoked_sessions} session(s)"
    return {"message": summary}
# ── GET /api/auth/me ──────────────────────────────────────────────────────────
@router.get("/me")
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the public profile of the caller resolved by ``get_current_user``."""
    profile = _user_dict(current_user)
    return profile
# ── POST /api/auth/change-password ───────────────────────────────────────────
@router.post("/change-password")
async def change_password(
    body: ChangePasswordRequest,
    session: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update the current user's password.

    Checks, in order:
      1. current_password matches stored hash
      2. new_password has not appeared in HIBP (SEC-03)
      3. new_password meets strength requirements (AUTH-01)

    On success the new hash is stored and ``password_must_change`` is cleared,
    so the login endpoint stops answering with the
    ``requires_password_change`` challenge for this account.

    NOTE(review): existing refresh tokens are left untouched here; consider
    revoking them after a password change.

    Raises:
        HTTPException: 401 when the current password is wrong or the user row
            can no longer be loaded; 422 for a breached or weak new password.
    """
    # Verify current password
    if not auth_service.verify_password(body.current_password, current_user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Current password is incorrect",
        )

    # HIBP breach check on new password (SEC-03)
    if await auth_service.check_hibp(body.new_password):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="This password has appeared in a data breach. Choose a different password.",
        )

    # Password strength check
    if not _validate_password_strength(body.new_password):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=_PASSWORD_DETAIL,
        )

    # Re-load the row in this session; session.get returns None when the row
    # is gone, which would otherwise surface as an AttributeError (HTTP 500).
    user = await session.get(User, current_user.id)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or deactivated",
        )

    user.password_hash = auth_service.hash_password(body.new_password)
    # The forced-change requirement is satisfied once a new password is stored.
    user.password_must_change = False
    await session.commit()

    return {"message": "Password updated"}
+2 -1
View File
@@ -1,4 +1,5 @@
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
@@ -68,7 +69,7 @@ async def upload_document(
@router.get("")
async def list_documents(
topic: str | None = Query(None),
topic: Optional[str] = Query(None),
page: int = Query(1, ge=1),
per_page: int = Query(20, ge=1, le=100),
session: AsyncSession = Depends(get_db),
+5 -3
View File
@@ -1,4 +1,6 @@
import time
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services import storage
@@ -9,9 +11,9 @@ router = APIRouter(prefix="/api/settings", tags=["settings"])
class SettingsPatch(BaseModel):
system_prompt: str | None = None
active_provider: str | None = None
providers: dict | None = None
system_prompt: Optional[str] = None
active_provider: Optional[str] = None
providers: Optional[dict] = None
class TestProviderRequest(BaseModel):
+5 -3
View File
@@ -1,3 +1,5 @@
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,9 +17,9 @@ class TopicCreate(BaseModel):
class TopicUpdate(BaseModel):
name: str | None = None
description: str | None = None
color: str | None = None
name: Optional[str] = None
description: Optional[str] = None
color: Optional[str] = None
class SuggestRequest(BaseModel):
+97 -6
View File
@@ -1,11 +1,19 @@
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import aioredis
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from minio import Minio
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from sqlalchemy import text
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse
from api.auth import limiter as auth_limiter
from api.documents import router as documents_router
from api.settings import router as settings_router
from api.topics import router as topics_router
@@ -13,12 +21,58 @@ from config import settings
from db.session import AsyncSessionLocal, engine
# ── CSP / Security headers middleware ────────────────────────────────────────
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp every outgoing response with CSP, X-Frame-Options and
    X-Content-Type-Options headers (SEC-05, T-02-14).
    """

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, then set the security headers on its response."""
        response = await call_next(request)
        security_headers = {
            "Content-Security-Policy": (
                "default-src 'self'; "
                "script-src 'self'; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data:; "
                "frame-ancestors 'none'"
            ),
            "X-Frame-Options": "DENY",
            "X-Content-Type-Options": "nosniff",
        }
        for header_name, header_value in security_headers.items():
            response.headers[header_name] = header_value
        return response
# ── Origin validation middleware (SEC-01, T-02-11) ────────────────────────────
class OriginValidationMiddleware(BaseHTTPMiddleware):
    """Block state-changing requests whose Origin is not in settings.cors_origins.

    GET, HEAD and OPTIONS requests always pass. For any other method, a
    request that carries an Origin header not present in the allowed list is
    answered with 403; requests without an Origin header pass through.

    Registered after CORSMiddleware so that it runs before it (Starlette
    applies middleware in reverse insertion order — last added runs first).
    """

    async def dispatch(self, request: Request, call_next):
        """Forward the request unless it is a cross-origin write from an unknown Origin."""
        if request.method in {"GET", "HEAD", "OPTIONS"}:
            return await call_next(request)

        origin = request.headers.get("Origin")
        if origin is None or origin in settings.cors_origins:
            return await call_next(request)

        return StarletteResponse(content="Forbidden", status_code=403)
# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
"""FastAPI lifespan: create MinIO bucket at startup, dispose engine at shutdown.
"""FastAPI lifespan: initialize MinIO, Redis, and admin bootstrap at startup.
D-07: bucket auto-create ensures the docuvault bucket exists on every reboot.
MinIO client stored on app.state.minio for use in the /health endpoint.
Redis stored on app.state.redis for per-account rate limiting (SEC-02) and
TOTP replay prevention (AUTH-08).
Admin bootstrap (D-04): idempotent, runs only if no users exist.
"""
# MinIO bucket initialization (RESEARCH.md Pattern 4)
minio_client = Minio(
@@ -32,21 +86,52 @@ async def lifespan(app: FastAPI):
await asyncio.to_thread(minio_client.make_bucket, settings.minio_bucket)
app.state.minio = minio_client
# Redis init for per-account rate limiting + TOTP replay prevention
app.state.redis = await aioredis.from_url(settings.redis_url)
# Admin bootstrap (D-04)
from services.auth import bootstrap_admin # noqa: PLC0415
async with AsyncSessionLocal() as session:
await bootstrap_admin(session)
yield
# Shutdown: close all pooled connections
# Shutdown: close pooled connections and Redis
await app.state.redis.close()
await engine.dispose()
# ── Application factory ───────────────────────────────────────────────────────
app = FastAPI(title="Document Scanner API", version="1.0.0", lifespan=lifespan)
# Rate limiter state (slowapi)
app.state.limiter = auth_limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
# ── Middleware registration order (Starlette: last added = first to run) ───────
# Desired execution order (request path): Origin → CORS → SecurityHeaders → route
# Insertion order (last registered = first to run): SecurityHeaders → CORS → Origin
# Result: register SecurityHeaders first, then CORS, then Origin last.
# 1. Security headers (CSP etc.) — runs last in the chain
app.add_middleware(SecurityHeadersMiddleware)
# 2. CORS — updated to use settings.cors_origins (D-09); wildcard removed (T-02-15)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Phase 1: locked down in Phase 2 after auth lands
allow_origins=settings.cors_origins,
allow_credentials=True, # Required for httpOnly cookie flow
allow_methods=["*"],
allow_headers=["*"],
)
# 3. Origin validation — runs first (added last), before CORS and route handlers
app.add_middleware(OriginValidationMiddleware)
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/health")
async def health(request: Request):
@@ -78,10 +163,16 @@ async def health(request: Request):
except Exception as e:
checks["minio"] = f"error: {type(e).__name__}: {e}"
status = "ok" if all(v == "ok" for v in checks.values()) else "degraded"
return {"status": status, "checks": checks}
status_val = "ok" if all(v == "ok" for v in checks.values()) else "degraded"
return {"status": status_val, "checks": checks}
# ── Include routers ───────────────────────────────────────────────────────────
app.include_router(documents_router)
app.include_router(topics_router)
app.include_router(settings_router)
# Phase 2: auth and admin routers
from api.auth import router as auth_router # noqa: E402
app.include_router(auth_router)
+2
View File
@@ -6,6 +6,8 @@ Updated in Plan 05: classify_document and suggest_topics_for_document now accept
an AsyncSession as their first argument so they can be called from the Celery task
wrapper and from API route handlers that already hold a session.
"""
from __future__ import annotations
from sqlalchemy.ext.asyncio import AsyncSession
from services import storage