"""FastAPI application module: middleware, lifespan, health probe, routers."""

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import aioredis
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from minio import Minio
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from sqlalchemy import text
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse

from api.auth import limiter as auth_limiter
from api.documents import router as documents_router
from api.topics import router as topics_router
from config import settings
from db.session import AsyncSessionLocal, engine


# ── CSP / Security headers middleware ────────────────────────────────────────

class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach security headers to responses passing through (SEC-05, T-02-14).

    Sets Content-Security-Policy, X-Frame-Options, and
    X-Content-Type-Options on the response returned by the inner application.

    NOTE(review): only responses that actually travel through this middleware
    are decorated. A response produced by a middleware registered *after* this
    one (i.e. running outside it), such as the 403 returned by
    OriginValidationMiddleware, never reaches this code.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Call the inner application, then add the security headers."""
        response = await call_next(request)
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "script-src 'self'; "
            "style-src 'self' 'unsafe-inline'; "
            "img-src 'self' data:; "
            "frame-ancestors 'none'"
        )
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-Content-Type-Options"] = "nosniff"
        return response


# ── Origin validation middleware (SEC-01, T-02-11) ────────────────────────────

# Methods that are never subjected to the Origin check.
_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})


class OriginValidationMiddleware(BaseHTTPMiddleware):
    """Reject state-changing requests from Origins not in settings.cors_origins.

    For any method other than GET/HEAD/OPTIONS: if the Origin header is
    present and is not in the allowed list, respond with 403. Requests that
    carry no Origin header are passed through unchanged.

    Registered AFTER CORSMiddleware so that it runs first (Starlette applies
    middleware in reverse insertion order — last added runs first).
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Return 403 for a disallowed Origin, otherwise forward the request."""
        if request.method not in _SAFE_METHODS:
            origin = request.headers.get("Origin")
            if origin is not None and origin not in settings.cors_origins:
                return StarletteResponse(content="Forbidden", status_code=403)
        return await call_next(request)


# ── Lifespan ──────────────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan: initialize MinIO, Redis, and admin bootstrap at startup.

    D-07: bucket auto-create ensures the configured bucket exists on every
    reboot. The MinIO client is stored on app.state.minio for use in the
    /health endpoint. Redis is stored on app.state.redis for per-account rate
    limiting (SEC-02) and TOTP replay prevention (AUTH-08).
    Admin bootstrap (D-04) runs once Redis is available.

    On exit the Redis client is closed and the SQLAlchemy engine is disposed.
    Cleanup runs in ``finally`` blocks, so it also happens when admin
    bootstrap fails or when an exception is raised into the context at
    ``yield``.
    """
    # MinIO bucket initialization (RESEARCH.md Pattern 4).
    # The MinIO client is synchronous, so its calls run in a worker thread.
    minio_client = Minio(
        settings.minio_endpoint,
        access_key=settings.minio_access_key,
        secret_key=settings.minio_secret_key,
        secure=False,
    )
    exists = await asyncio.to_thread(
        minio_client.bucket_exists, settings.minio_bucket
    )
    if not exists:
        await asyncio.to_thread(minio_client.make_bucket, settings.minio_bucket)
    app.state.minio = minio_client

    # Redis init for per-account rate limiting + TOTP replay prevention
    app.state.redis = await aioredis.from_url(settings.redis_url)

    try:
        # Admin bootstrap (D-04); imported here rather than at module top.
        from services.auth import bootstrap_admin  # noqa: PLC0415

        async with AsyncSessionLocal() as session:
            await bootstrap_admin(session)

        yield
    finally:
        # Shutdown: close Redis and pooled DB connections. The nested
        # ``finally`` guarantees the engine is disposed even if closing
        # Redis raises.
        try:
            await app.state.redis.close()
        finally:
            await engine.dispose()


# ── Application factory ───────────────────────────────────────────────────────

app = FastAPI(title="Document Scanner API", version="1.0.0", lifespan=lifespan)

# Rate limiter state (slowapi)
app.state.limiter = auth_limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Registered before the three middlewares below, so by the same
# "last added = first to run" rule it runs after all of them.
app.add_middleware(SlowAPIMiddleware)

# ── Middleware registration order (Starlette: last added = first to run) ───────
# Desired execution order (request path): Origin → CORS → SecurityHeaders → route
# Registration order is therefore the reverse: SecurityHeaders → CORS → Origin
# Result: register SecurityHeaders first, then CORS, then Origin last.

# 1. Security headers (CSP etc.) — runs last of these three on the request path
app.add_middleware(SecurityHeadersMiddleware)

# 2. CORS — uses settings.cors_origins (D-09); wildcard removed (T-02-15)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,  # Required for httpOnly cookie flow
    allow_methods=["*"],
    allow_headers=["*"],
)

# 3. Origin validation — runs first (added last), before CORS and route handlers
app.add_middleware(OriginValidationMiddleware)


# ── Routes ────────────────────────────────────────────────────────────────────

@app.get("/health")
async def health(request: Request):
    """Extended health probe: reports PostgreSQL and MinIO connectivity (D-07).

    Always returns HTTP 200 — a 'degraded' status signals a partial outage
    without causing load-balancer retries.

    Returns a dict of the form
    ``{"status": "ok" | "degraded", "checks": {"postgres": ..., "minio": ...}}``
    where each check is ``"ok"`` or a string starting with ``"error: "``.

    Note (T-01-05-03): error strings contain the Python exception class name
    AND the exception message. This was accepted for an internal/dev endpoint
    in Phase 1; the plan was to trim it to 'error' or 'unhealthy' once the
    endpoint is internet-facing.
    NOTE(review): exception messages from database/storage drivers may
    include connection details — confirm before exposing this endpoint.
    """
    checks: dict[str, str] = {}

    # PostgreSQL probe: a trivial round-trip query.
    # Broad ``except`` is deliberate: any failure must be reported as a
    # degraded check rather than turned into a 500.
    try:
        async with AsyncSessionLocal() as session:
            await session.execute(text("SELECT 1"))
        checks["postgres"] = "ok"
    except Exception as e:
        checks["postgres"] = f"error: {type(e).__name__}: {e}"

    # MinIO probe: the client is synchronous, so run it in a worker thread.
    try:
        ok = await asyncio.to_thread(
            request.app.state.minio.bucket_exists, settings.minio_bucket
        )
        checks["minio"] = "ok" if ok else "error: bucket missing"
    except Exception as e:
        checks["minio"] = f"error: {type(e).__name__}: {e}"

    status_val = "ok" if all(v == "ok" for v in checks.values()) else "degraded"
    return {"status": status_val, "checks": checks}


# ── Include routers ───────────────────────────────────────────────────────────

app.include_router(documents_router)
app.include_router(topics_router)

# Phase 2: auth and admin routers
from api.auth import router as auth_router  # noqa: E402
from api.admin import router as admin_router  # noqa: E402

app.include_router(auth_router)
app.include_router(admin_router)