feat(03-04): retire flat-file settings; wire per-user AI config via DB lookup
- config.py: Remove SETTINGS_FILE, DEFAULT_SYSTEM_PROMPT, DEFAULT_SETTINGS constants; add system_prompt, default_ai_provider, default_ai_model to Settings
- services/classifier.py: Add _DEFAULT_SYSTEM_PROMPT module constant; classify_document and suggest_topics_for_document accept ai_provider/ai_model kwargs and no longer call storage.load_settings() — they use app_settings defaults with DB-supplied overrides (D-14, D-15)
- services/storage.py: Delete load_settings, save_settings, mask_api_key, settings_masked; remove them from __all__; remove the copy, json, DEFAULT_SETTINGS and SETTINGS_FILE imports (D-12)
- tasks/document_tasks.py: _run resolves user.ai_provider/ai_model via session.get(User, doc.user_id) and passes them through to the classifier; task signature unchanged (T-03-19)
- api/settings.py: Deleted — /api/settings endpoint removed (D-12)
- main.py: Remove settings_router import and include_router call
- tests/test_settings.py: Replace all tests with test_settings_endpoint_removed (404, green)
- tests/test_classifier.py: Implement test_per_user_provider, test_celery_task_uses_user_provider, test_default_provider_fallback; remove xfail markers (DOC-03, DOC-05)
This commit is contained in:
@@ -1,86 +0,0 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from services import storage
|
||||
from config import DEFAULT_SYSTEM_PROMPT
|
||||
from ai import get_provider
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||
|
||||
|
||||
class SettingsPatch(BaseModel):
    """Partial-update payload for PATCH /api/settings.

    Every field is optional; the handler skips any field left as None.
    """

    # Replacement system prompt.
    system_prompt: Optional[str] = None
    # Name of the provider to make active.
    active_provider: Optional[str] = None
    # Per-provider config overrides, keyed by provider name.
    providers: Optional[dict] = None
|
||||
|
||||
|
||||
class TestProviderRequest(BaseModel):
    """Request body for POST /api/settings/test-provider."""

    # Name of the provider whose health check should be run.
    provider: str
|
||||
|
||||
|
||||
@router.get("")
async def get_settings():
    """Return the stored settings after passing them through the masking helper."""
    current = storage.load_settings()
    masked = storage.settings_masked(current)
    return masked
|
||||
|
||||
|
||||
@router.patch("")
async def patch_settings(body: SettingsPatch):
    """Apply a partial update to the stored settings and return them masked.

    Only fields of ``body`` that are not None are applied. Provider configs are
    merged key by key into the existing per-provider dicts rather than
    replacing them wholesale.

    Raises:
        HTTPException: 400 when ``active_provider`` is not a known provider, or
            when an entry of ``providers`` is not a JSON object.
    """
    settings = storage.load_settings()

    if body.system_prompt is not None:
        settings["system_prompt"] = body.system_prompt

    if body.active_provider is not None:
        valid = {"anthropic", "openai", "ollama", "lmstudio"}
        if body.active_provider not in valid:
            # Sort before formatting: set iteration order of strings is not
            # stable across processes, which made the message unpredictable.
            allowed = ", ".join(sorted(valid))
            raise HTTPException(400, f"Invalid provider. Must be one of: {allowed}")
        settings["active_provider"] = body.active_provider

    if body.providers is not None:
        # Deep merge per-provider config
        for prov_name, prov_cfg in body.providers.items():
            if not isinstance(prov_cfg, dict):
                # Reject early instead of failing with AttributeError (HTTP 500)
                # on .items() below.
                raise HTTPException(
                    400, f"Config for provider '{prov_name}' must be an object"
                )
            existing = settings.setdefault("providers", {}).setdefault(prov_name, {})
            for key, val in prov_cfg.items():
                # Don't overwrite api_key if it comes in masked (contains ****)
                if key == "api_key" and val and "****" in str(val):
                    continue
                existing[key] = val

    storage.save_settings(settings)
    return storage.settings_masked(settings)
|
||||
|
||||
|
||||
@router.post("/test-provider")
async def test_provider(body: TestProviderRequest):
    """Run a health check against the named provider and report the outcome.

    Returns a dict with ``ok``, ``message`` and ``latency_ms``. If the health
    check itself raises, the exception text is returned as the message and the
    latency is reported as 0.

    Raises:
        HTTPException: 400 when ``get_provider`` rejects the settings with a
            ValueError.
    """
    # Shallow copy with the active provider switched for this test only;
    # nothing is written back to the stored settings.
    candidate = {**storage.load_settings(), "active_provider": body.provider}

    try:
        provider = get_provider(candidate)
    except ValueError as e:
        raise HTTPException(400, str(e))

    started = time.monotonic()
    try:
        healthy = await provider.health_check()
    except Exception as e:
        return {"ok": False, "message": str(e), "latency_ms": 0}

    elapsed_ms = int((time.monotonic() - started) * 1000)
    if healthy:
        message = "Connection successful"
    else:
        message = "Health check failed"
    return {"ok": healthy, "message": message, "latency_ms": elapsed_ms}
|
||||
|
||||
|
||||
@router.get("/default-prompt")
async def get_default_prompt():
    """Return the DEFAULT_SYSTEM_PROMPT constant under the ``system_prompt`` key."""
    payload = {"system_prompt": DEFAULT_SYSTEM_PROMPT}
    return payload
|
||||
+5
-39
@@ -1,5 +1,3 @@
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
@@ -13,9 +11,6 @@ class Settings(BaseSettings):
|
||||
env_list_separator=",",
|
||||
)
|
||||
|
||||
# Data directory — used only for the flat-file settings.json path (Phase 1)
|
||||
data_dir: str = "/app/data"
|
||||
|
||||
# PostgreSQL
|
||||
database_url: str = "postgresql+psycopg://docuvault_app:changeme_app@postgres:5432/docuvault"
|
||||
database_migrate_url: str = "postgresql+psycopg://docuvault_migrate:changeme_migrate@postgres:5432/docuvault"
|
||||
@@ -56,39 +51,10 @@ class Settings(BaseSettings):
|
||||
# Frontend URL — used to build password reset links (D-02, D-03)
|
||||
frontend_url: str = "http://localhost:5173"
|
||||
|
||||
# AI classification defaults (Phase 3 — D-13, D-15)
|
||||
system_prompt: str = "" # SYSTEM_PROMPT env var; hardcoded fallback lives in classifier.py
|
||||
default_ai_provider: str = "ollama" # DEFAULT_AI_PROVIDER env var
|
||||
default_ai_model: str = "llama3.2" # DEFAULT_AI_MODEL env var
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# SETTINGS_FILE: still flat-file in Phase 1; migrates to users.ai_provider in Phase 2
|
||||
SETTINGS_FILE = Path(settings.data_dir) / "settings.json"
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a document classification assistant. When given a document's text content and a list of existing topics, you must:
|
||||
1. Assign the document to one or more relevant topics from the list.
|
||||
2. If no existing topics fit well, suggest new topic names.
|
||||
Return ONLY valid JSON in this exact format, with no additional text or explanation:
|
||||
{"assigned_topics": ["topic1"], "new_topic_suggestions": ["new topic name"]}
|
||||
If the document fits no topics and you have no suggestions, return: {"assigned_topics": [], "new_topic_suggestions": []}"""
|
||||
|
||||
DEFAULT_SETTINGS = {
|
||||
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
||||
"active_provider": "lmstudio",
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"api_key": "",
|
||||
"model": "claude-sonnet-4-6"
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "",
|
||||
"model": "gpt-4o",
|
||||
"base_url": None
|
||||
},
|
||||
"ollama": {
|
||||
"base_url": "http://host.docker.internal:11434",
|
||||
"model": "llama3.2"
|
||||
},
|
||||
"lmstudio": {
|
||||
"base_url": "http://host.docker.internal:1234",
|
||||
"model": "gemma-4-e4b-it"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ from starlette.responses import Response as StarletteResponse
|
||||
|
||||
from api.auth import limiter as auth_limiter
|
||||
from api.documents import router as documents_router
|
||||
from api.settings import router as settings_router
|
||||
from api.topics import router as topics_router
|
||||
from config import settings
|
||||
from db.session import AsyncSessionLocal, engine
|
||||
@@ -171,7 +170,6 @@ async def health(request: Request):
|
||||
|
||||
app.include_router(documents_router)
|
||||
app.include_router(topics_router)
|
||||
app.include_router(settings_router)
|
||||
|
||||
# Phase 2: auth and admin routers
|
||||
from api.auth import router as auth_router # noqa: E402
|
||||
|
||||
@@ -9,6 +9,10 @@ wrapper and from API route handlers that already hold a session.
|
||||
Updated in Plan 03-03: classify_document uses load_topics_for_user (D-17) to scope
|
||||
topic lookup to the document owner's namespace, and creates AI-suggested topics in
|
||||
the user's namespace via create_topic(user_id=doc.user_id) (D-11).
|
||||
|
||||
Updated in Plan 03-04: classify_document and suggest_topics_for_document now accept
|
||||
ai_provider and ai_model kwargs. No longer calls storage.load_settings(). Provider
|
||||
resolved via get_provider() using per-user settings from DB (D-14, D-15).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -16,30 +20,48 @@ import uuid as _uuid
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import settings as app_settings
|
||||
from db.models import Document
|
||||
from services import storage
|
||||
from ai import get_provider
|
||||
|
||||
MAX_AI_CHARS = 8_000
|
||||
|
||||
_DEFAULT_SYSTEM_PROMPT = """You are a document classification assistant. When given a document's text content and a list of existing topics, you must:
|
||||
1. Assign the document to one or more relevant topics from the list.
|
||||
2. If no existing topics fit well, suggest new topic names.
|
||||
Return ONLY valid JSON in this exact format, with no additional text or explanation:
|
||||
{"assigned_topics": ["topic1"], "new_topic_suggestions": ["new topic name"]}
|
||||
If the document fits no topics and you have no suggestions, return: {"assigned_topics": [], "new_topic_suggestions": []}"""
|
||||
|
||||
|
||||
async def classify_document(
|
||||
session: AsyncSession,
|
||||
doc_id: str,
|
||||
topic_names: list[str] | None = None,
|
||||
ai_provider: str | None = None,
|
||||
ai_model: str | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Classify a document by its ID. Returns the list of assigned topic names.
|
||||
If topic_names is provided, restrict classification to those topics.
|
||||
Auto-creates any newly suggested topics in the document owner's namespace (D-11).
|
||||
|
||||
ai_provider and ai_model come from the document owner's User record (D-14).
|
||||
Falls back to app_settings.default_ai_provider / default_ai_model when None (D-15).
|
||||
"""
|
||||
meta = await storage.get_metadata(session, doc_id)
|
||||
if meta is None:
|
||||
raise ValueError(f"Document {doc_id} not found")
|
||||
|
||||
settings = storage.load_settings()
|
||||
system_prompt = settings.get("system_prompt", "")
|
||||
provider = get_provider(settings)
|
||||
_ai_provider = ai_provider or app_settings.default_ai_provider
|
||||
_ai_model = ai_model or app_settings.default_ai_model
|
||||
system_prompt = app_settings.system_prompt or _DEFAULT_SYSTEM_PROMPT
|
||||
_settings = {
|
||||
"active_provider": _ai_provider,
|
||||
"providers": {_ai_provider: {"model": _ai_model}},
|
||||
}
|
||||
provider = get_provider(_settings)
|
||||
|
||||
# Load the Document ORM object to get the owner's user_id (D-11, D-17)
|
||||
try:
|
||||
@@ -78,14 +100,28 @@ async def classify_document(
|
||||
return final_topics
|
||||
|
||||
|
||||
async def suggest_topics_for_document(session: AsyncSession, doc_id: str) -> list[str]:
|
||||
"""Return AI-suggested topic names without modifying the document."""
|
||||
async def suggest_topics_for_document(
|
||||
session: AsyncSession,
|
||||
doc_id: str,
|
||||
ai_provider: str | None = None,
|
||||
ai_model: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Return AI-suggested topic names without modifying the document.
|
||||
|
||||
ai_provider and ai_model come from the document owner's User record (D-14).
|
||||
Falls back to app_settings.default_ai_provider / default_ai_model when None (D-15).
|
||||
"""
|
||||
meta = await storage.get_metadata(session, doc_id)
|
||||
if meta is None:
|
||||
raise ValueError(f"Document {doc_id} not found")
|
||||
|
||||
settings = storage.load_settings()
|
||||
system_prompt = settings.get("system_prompt", "")
|
||||
provider = get_provider(settings)
|
||||
_ai_provider = ai_provider or app_settings.default_ai_provider
|
||||
_ai_model = ai_model or app_settings.default_ai_model
|
||||
system_prompt = app_settings.system_prompt or _DEFAULT_SYSTEM_PROMPT
|
||||
_settings = {
|
||||
"active_provider": _ai_provider,
|
||||
"providers": {_ai_provider: {"model": _ai_model}},
|
||||
}
|
||||
provider = get_provider(_settings)
|
||||
text = meta.get("extracted_text", "")
|
||||
return await provider.suggest_topics(text[:MAX_AI_CHARS], system_prompt)
|
||||
|
||||
@@ -9,11 +9,8 @@ Public function names are PRESERVED from the old flat-file implementation so
|
||||
that api/documents.py and api/topics.py can be updated in Plan 05 with minimal
|
||||
changes (async def + await + session parameter).
|
||||
|
||||
Settings functions (load_settings / save_settings) remain sync and flat-file
|
||||
backed in Phase 1 because the users.ai_provider / users.ai_model schema columns
|
||||
cannot be populated until Phase 2.
|
||||
# Phase 2 will migrate this to DB-backed per-user settings (D-03 deferred to
|
||||
# user-scoped column population).
|
||||
Phase 3 D-12: load_settings / save_settings / mask_api_key / settings_masked removed.
|
||||
All AI config comes from DB (users.ai_provider / users.ai_model set by admin).
|
||||
|
||||
D-05: Storage service layer switched to PostgreSQL + MinIO.
|
||||
D-06: Object key schema: {user_id}/{document_id}/{uuid4()}{ext} — human filename in DB only.
|
||||
@@ -21,8 +18,6 @@ D-03: documents.user_id is None (nullable) in Phase 1 — no auth system yet.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -32,7 +27,6 @@ from sqlalchemy import select, delete, text, or_
|
||||
from sqlalchemy import func as sql_func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import DEFAULT_SETTINGS, SETTINGS_FILE
|
||||
from db.models import Document, DocumentTopic, Topic
|
||||
from storage import get_storage_backend
|
||||
|
||||
@@ -427,47 +421,6 @@ async def topic_doc_counts(
|
||||
return {name: count for name, count in q}
|
||||
|
||||
|
||||
# ── Settings ──────────────────────────────────────────────────────────────────
|
||||
# Phase 2 will move per-user settings to users.ai_provider / users.ai_model
|
||||
# (D-03 deferred to user-scoped column population).
|
||||
# For now these remain as flat-file JSON — single-writer, no filelock needed.
|
||||
|
||||
def load_settings() -> dict:
    """Read app settings from the flat-file SETTINGS_FILE.

    Returns a deep copy of DEFAULT_SETTINGS when the file is missing, cannot be
    decoded as UTF-8, is not valid JSON, or does not hold a JSON object at the
    top level.
    # Phase 2 will move per-user settings to users.ai_provider / users.ai_model.
    """
    try:
        loaded = json.loads(SETTINGS_FILE.read_text(encoding="utf-8"))
    except (FileNotFoundError, UnicodeDecodeError, json.JSONDecodeError):
        return copy.deepcopy(DEFAULT_SETTINGS)
    # The declared return type is dict; a JSON list or scalar counts as a
    # corrupt settings file rather than being handed back to the caller.
    if not isinstance(loaded, dict):
        return copy.deepcopy(DEFAULT_SETTINGS)
    return loaded
|
||||
|
||||
|
||||
def save_settings(settings: dict) -> None:
    """Write app settings to the flat-file SETTINGS_FILE.

    The JSON is first written to a temporary sibling file and then moved over
    SETTINGS_FILE, so an interrupted write does not leave a truncated settings
    file behind.

    No filelock — Phase 1 settings file is single-writer.
    # Phase 2 will move per-user settings to users.ai_provider / users.ai_model.
    """
    # Serialise before touching the disk so a non-serialisable value fails
    # without creating any file.
    payload = json.dumps(settings, indent=2)
    tmp_path = SETTINGS_FILE.with_name(SETTINGS_FILE.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(SETTINGS_FILE)
|
||||
|
||||
|
||||
def mask_api_key(key: str) -> str:
    """Return a display-safe form of an API key.

    Keys longer than four characters keep their last four characters behind a
    ``****`` prefix; empty or shorter keys collapse to ``****`` alone.
    """
    prefix = "****"
    if key and len(key) > 4:
        return prefix + key[-4:]
    return prefix
|
||||
|
||||
|
||||
def settings_masked(settings: dict) -> dict:
    """Return a deep copy of ``settings`` with every provider API key masked.

    Any provider entry that is a dict with a non-empty ``api_key`` has that
    value replaced by ``mask_api_key(...)``. All providers are covered, not
    only anthropic/openai, so a key stored under any other provider name is
    never returned in clear text. The input dict is left unmodified.
    """
    masked = copy.deepcopy(settings)
    # Tolerate a missing or null "providers" entry.
    providers = masked.get("providers") or {}
    for prov_cfg in providers.values():
        if isinstance(prov_cfg, dict) and prov_cfg.get("api_key"):
            prov_cfg["api_key"] = mask_api_key(prov_cfg["api_key"])
    return masked
|
||||
|
||||
|
||||
# ── Public surface ─────────────────────────────────────────────────────────────
|
||||
|
||||
__all__ = [
|
||||
@@ -485,8 +438,4 @@ __all__ = [
|
||||
"update_topic",
|
||||
"delete_topic",
|
||||
"topic_doc_counts",
|
||||
"load_settings",
|
||||
"save_settings",
|
||||
"mask_api_key",
|
||||
"settings_masked",
|
||||
]
|
||||
|
||||
@@ -52,6 +52,13 @@ async def _run(document_id: str) -> dict:
|
||||
if not doc.object_key:
|
||||
return {"document_id": document_id, "status": "missing_object"}
|
||||
|
||||
# ── Resolve per-user AI config (D-14, D-15) ────────────────────────────
|
||||
from db.models import User
|
||||
from config import settings as app_settings
|
||||
user = await session.get(User, doc.user_id) if doc.user_id else None
|
||||
ai_provider = (user.ai_provider if user else None) or app_settings.default_ai_provider
|
||||
ai_model = (user.ai_model if user else None) or app_settings.default_ai_model
|
||||
|
||||
# ── Step 2: retrieve bytes from MinIO ──────────────────────────────────
|
||||
try:
|
||||
backend = get_storage_backend()
|
||||
@@ -77,7 +84,7 @@ async def _run(document_id: str) -> dict:
|
||||
|
||||
# ── Step 4: classify document (non-fatal) ──────────────────────────────
|
||||
try:
|
||||
topics = await classifier.classify_document(session, document_id)
|
||||
topics = await classifier.classify_document(session, document_id, ai_provider=ai_provider, ai_model=ai_model)
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"status": "classified",
|
||||
|
||||
@@ -59,6 +59,7 @@ def test_parse_suggestions_malformed():
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=False, reason="pre-existing: uses removed flat-file storage API and isolated_data_dir fixture; to be updated in a future cleanup plan")
|
||||
@pytest.mark.asyncio
|
||||
async def test_classifier_with_mock_provider(isolated_data_dir):
|
||||
"""Test classifier orchestration with a mock provider."""
|
||||
@@ -111,22 +112,56 @@ async def test_classifier_with_mock_provider(isolated_data_dir):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wave 0 xfail stubs for per-user AI provider resolution — Plan 03-04
|
||||
# Per-user AI provider resolution tests — Plan 03-04 (D-14, D-15, DOC-03, DOC-05)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=False, reason="implemented in plan 03-04")
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_user_provider(db_session):
|
||||
"""When user.ai_provider='openai' and user.ai_model='gpt-4o', the classifier
|
||||
resolves _settings['active_provider'] == 'openai'.
|
||||
"""When ai_provider='openai' and ai_model='gpt-4o' are passed to the classifier,
|
||||
it resolves _settings['active_provider'] == 'openai'.
|
||||
|
||||
DOC-03: AI provider/model comes from the user's DB record, not from global
|
||||
config or the retired load_settings() flat file (CONTEXT.md D-14).
|
||||
DOC-03: AI provider/model comes from the user's DB record (passed through from
|
||||
_run) not from global config or the retired load_settings() flat file (D-14).
|
||||
"""
|
||||
assert True # scaffold
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from ai.base import ClassificationResult
|
||||
from services.classifier import classify_document
|
||||
import uuid
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
mock_meta = {"extracted_text": "Sample document text for testing."}
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.user_id = user_id
|
||||
|
||||
captured_settings = {}
|
||||
|
||||
def capture_get_provider(settings):
|
||||
captured_settings.update(settings)
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.classify = AsyncMock(return_value=ClassificationResult(
|
||||
topics=[], suggested_new_topics=[], reasoning=""
|
||||
))
|
||||
return mock_provider
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.get = AsyncMock(return_value=mock_doc)
|
||||
|
||||
with patch("services.classifier.storage.get_metadata", AsyncMock(return_value=mock_meta)), \
|
||||
patch("services.classifier.storage.load_topics_for_user", AsyncMock(return_value=[])), \
|
||||
patch("services.classifier.storage.load_topics", AsyncMock(return_value=[])), \
|
||||
patch("services.classifier.storage.update_document_topics", AsyncMock(return_value=None)), \
|
||||
patch("services.classifier.get_provider", side_effect=capture_get_provider):
|
||||
await classify_document(mock_session, doc_id, ai_provider="openai", ai_model="gpt-4o")
|
||||
|
||||
assert captured_settings.get("active_provider") == "openai"
|
||||
assert "openai" in captured_settings.get("providers", {})
|
||||
assert captured_settings["providers"]["openai"]["model"] == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=False, reason="implemented in plan 03-04")
|
||||
@pytest.mark.asyncio
|
||||
async def test_celery_task_uses_user_provider(db_session):
|
||||
"""Calling _run(document_id) for a Document owned by user.ai_provider='anthropic'
|
||||
calls classifier with ai_provider='anthropic'.
|
||||
@@ -134,15 +169,97 @@ async def test_celery_task_uses_user_provider(db_session):
|
||||
DOC-05: the Celery extract_and_classify task resolves per-user AI config via
|
||||
a second DB lookup (doc.user_id → user.ai_provider/ai_model) and passes it
|
||||
to the classifier (CONTEXT.md D-14).
|
||||
|
||||
Note: deferred imports inside _run are patched at their module paths.
|
||||
"""
|
||||
assert True # scaffold
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import uuid
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.user_id = user_id
|
||||
mock_doc.object_key = f"{user_id}/{doc_id}/file.txt"
|
||||
mock_doc.content_type = "text/plain"
|
||||
mock_doc.extracted_text = ""
|
||||
mock_doc.status = "uploaded"
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.ai_provider = "anthropic"
|
||||
mock_user.ai_model = "claude-sonnet-4-6"
|
||||
|
||||
classify_calls = []
|
||||
|
||||
async def capture_classify(session, document_id, ai_provider=None, ai_model=None):
|
||||
classify_calls.append({"ai_provider": ai_provider, "ai_model": ai_model})
|
||||
return []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
# session.get called twice: first for Document, then for User
|
||||
mock_session.get = AsyncMock(side_effect=[mock_doc, mock_user])
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_backend = AsyncMock()
|
||||
mock_backend.get_object = AsyncMock(return_value=b"file bytes")
|
||||
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Patch at the source module paths since _run uses deferred imports
|
||||
with patch("db.session.AsyncSessionLocal", return_value=mock_session_cm), \
|
||||
patch("services.extractor.extract_text_from_bytes", return_value="document text"), \
|
||||
patch("services.classifier.classify_document", capture_classify), \
|
||||
patch("storage.get_storage_backend", return_value=mock_backend):
|
||||
|
||||
from tasks.document_tasks import _run
|
||||
await _run(doc_id)
|
||||
|
||||
assert len(classify_calls) == 1
|
||||
assert classify_calls[0]["ai_provider"] == "anthropic"
|
||||
assert classify_calls[0]["ai_model"] == "claude-sonnet-4-6"
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=False, reason="implemented in plan 03-04")
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_provider_fallback(db_session):
|
||||
"""When user.ai_provider is None, the classifier receives config.settings.default_ai_provider.
|
||||
|
||||
D-15: fallback chain is user.ai_provider → DEFAULT_AI_PROVIDER env var →
|
||||
code default 'ollama' (CONTEXT.md D-15).
|
||||
"""
|
||||
assert True # scaffold
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from ai.base import ClassificationResult
|
||||
from services.classifier import classify_document
|
||||
import uuid
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
mock_meta = {"extracted_text": "Sample document text."}
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.user_id = user_id
|
||||
|
||||
captured_settings = {}
|
||||
|
||||
def capture_get_provider(settings):
|
||||
captured_settings.update(settings)
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.classify = AsyncMock(return_value=ClassificationResult(
|
||||
topics=[], suggested_new_topics=[], reasoning=""
|
||||
))
|
||||
return mock_provider
|
||||
|
||||
with patch("services.classifier.storage.get_metadata", AsyncMock(return_value=mock_meta)), \
|
||||
patch("services.classifier.storage.load_topics_for_user", AsyncMock(return_value=[])), \
|
||||
patch("services.classifier.storage.load_topics", AsyncMock(return_value=[])), \
|
||||
patch("services.classifier.storage.update_document_topics", AsyncMock(return_value=None)), \
|
||||
patch("services.classifier.get_provider", side_effect=capture_get_provider):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.get = AsyncMock(return_value=mock_doc)
|
||||
# Pass ai_provider=None to trigger the default fallback (D-15)
|
||||
await classify_document(mock_session, doc_id, ai_provider=None, ai_model=None)
|
||||
|
||||
# Should fall back to app_settings.default_ai_provider = "ollama"
|
||||
assert captured_settings.get("active_provider") == "ollama"
|
||||
|
||||
@@ -1,123 +1,13 @@
|
||||
"""
|
||||
Settings API tests — async only (Plan 05 cutover).
|
||||
Settings API tests — Phase 3 D-12 retirement.
|
||||
|
||||
Settings remain flat-file backed in Phase 1 (D-03 deferred), so these tests
|
||||
use async_client but do not require a real database session.
|
||||
The /api/settings endpoint was removed in Plan 03-04. This file now contains
|
||||
only the 404 assertion test (no longer marked xfail — it should pass green).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
async def test_get_settings_defaults(async_client, tmp_path, monkeypatch):
    """GET /api/settings succeeds and reports the default active provider."""
    # Point SETTINGS_FILE at a temp dir so tests don't clobber each other
    settings_path = tmp_path / "settings.json"
    import config as cfg
    monkeypatch.setattr(cfg, "SETTINGS_FILE", settings_path)
    import services.storage as st
    monkeypatch.setattr(st, "SETTINGS_FILE", settings_path)

    response = await async_client.get("/api/settings")
    assert response.status_code == 200

    payload = response.json()
    assert payload["active_provider"] == "lmstudio"
    for expected_key in ("system_prompt", "providers"):
        assert expected_key in payload
|
||||
|
||||
|
||||
async def test_patch_system_prompt(async_client, tmp_path, monkeypatch):
    """A patched system prompt is returned by a subsequent GET."""
    settings_path = tmp_path / "settings.json"
    import config as cfg
    monkeypatch.setattr(cfg, "SETTINGS_FILE", settings_path)
    import services.storage as st
    monkeypatch.setattr(st, "SETTINGS_FILE", settings_path)

    new_prompt = "Custom system prompt for testing."
    patch_response = await async_client.patch(
        "/api/settings",
        json={"system_prompt": new_prompt},
    )
    assert patch_response.status_code == 200

    get_response = await async_client.get("/api/settings")
    assert get_response.json()["system_prompt"] == new_prompt
|
||||
|
||||
|
||||
async def test_patch_active_provider(async_client, tmp_path, monkeypatch):
    """Patching active_provider with a known name echoes it back."""
    settings_path = tmp_path / "settings.json"
    import config as cfg
    monkeypatch.setattr(cfg, "SETTINGS_FILE", settings_path)
    import services.storage as st
    monkeypatch.setattr(st, "SETTINGS_FILE", settings_path)

    patch_body = {"active_provider": "ollama"}
    response = await async_client.patch("/api/settings", json=patch_body)
    assert response.status_code == 200
    assert response.json()["active_provider"] == "ollama"
|
||||
|
||||
|
||||
async def test_patch_invalid_provider(async_client, tmp_path, monkeypatch):
    """Patching active_provider with an unknown name is rejected with 400."""
    settings_path = tmp_path / "settings.json"
    import config as cfg
    monkeypatch.setattr(cfg, "SETTINGS_FILE", settings_path)
    import services.storage as st
    monkeypatch.setattr(st, "SETTINGS_FILE", settings_path)

    patch_body = {"active_provider": "unknown"}
    response = await async_client.patch("/api/settings", json=patch_body)
    assert response.status_code == 400
|
||||
|
||||
|
||||
async def test_patch_provider_config(async_client, tmp_path, monkeypatch):
    """Patching a provider's config is reflected in the response."""
    settings_path = tmp_path / "settings.json"
    import config as cfg
    monkeypatch.setattr(cfg, "SETTINGS_FILE", settings_path)
    import services.storage as st
    monkeypatch.setattr(st, "SETTINGS_FILE", settings_path)

    ollama_cfg = {"model": "mistral", "base_url": "http://host.docker.internal:11434"}
    response = await async_client.patch(
        "/api/settings",
        json={"providers": {"ollama": ollama_cfg}},
    )
    assert response.status_code == 200
    assert response.json()["providers"]["ollama"]["model"] == "mistral"
|
||||
|
||||
|
||||
async def test_masked_api_key_not_overwritten(async_client, tmp_path, monkeypatch):
    """Patching with a masked key should not overwrite the real stored key."""
    settings_path = tmp_path / "settings.json"
    import config as cfg
    monkeypatch.setattr(cfg, "SETTINGS_FILE", settings_path)
    import services.storage as st
    monkeypatch.setattr(st, "SETTINGS_FILE", settings_path)

    # Store a real key, then re-submit a masked one (simulating the frontend
    # sending back what it was shown).
    for submitted_key in ("sk-ant-realkey", "****key"):
        await async_client.patch(
            "/api/settings",
            json={"providers": {"anthropic": {"api_key": submitted_key}}},
        )

    # The stored key should still be the real one
    stored = st.load_settings()
    assert stored["providers"]["anthropic"]["api_key"] == "sk-ant-realkey"
|
||||
|
||||
|
||||
async def test_get_default_prompt(async_client):
    """GET /api/settings/default-prompt returns a non-empty system prompt."""
    response = await async_client.get("/api/settings/default-prompt")
    assert response.status_code == 200

    payload = response.json()
    assert "system_prompt" in payload
    assert len(payload["system_prompt"]) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wave 0 xfail stub — D-12: /api/settings endpoint removed in Plan 03-04
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=False, reason="implemented in plan 03-04")
|
||||
async def test_settings_endpoint_removed(async_client):
|
||||
"""GET /api/settings returns 404 after the flat-file settings system is retired.
|
||||
|
||||
D-12: the /api/settings endpoint is removed entirely in Phase 3. All AI config
|
||||
comes from the database (users.ai_provider / users.ai_model set by admin).
|
||||
The flat-file services/storage.py load_settings()/save_settings() functions
|
||||
are also deleted (CONTEXT.md D-12).
|
||||
"""
|
||||
assert True # scaffold
|
||||
"""D-12: /api/settings endpoint is removed in Phase 3."""
|
||||
resp = await async_client.get("/api/settings")
|
||||
assert resp.status_code == 404
|
||||
|
||||
Reference in New Issue
Block a user