a548266461
- Add backend/ai/utils.py — parse_classification, parse_suggestions, strip_code_fences shared by all AI providers; removes duplicated private functions from anthropic_provider.py and openai_provider.py - Add backend/deps/utils.py — get_client_ip, parse_uuid request-parsing helpers; removes local _ip() variants from admin.py, auth.py, shares.py, folders.py - Add backend/storage/exceptions.py — canonical CloudConnectionError definition; all routers and backends import from here instead of redefining - Move validate_password_strength to backend/services/auth.py; removes duplicated _validate_password_strength from admin.py and auth.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
267 lines
9.8 KiB
Python
267 lines
9.8 KiB
Python
"""
|
|
Unit tests for AI provider JSON parsing robustness and classifier orchestration.
|
|
Uses a mock provider — no real AI calls made.
|
|
"""
|
|
import json
|
|
import pytest
|
|
from ai.utils import parse_classification as _parse_classification, parse_suggestions as _parse_suggestions, strip_code_fences as _strip_code_fences
|
|
from ai.base import ClassificationResult
|
|
|
|
|
|
def test_parse_clean_json():
    """Bare JSON with no wrapping text yields the assigned topics and no suggestions."""
    payload = '{"assigned_topics": ["finance", "invoices"], "new_topic_suggestions": []}'
    parsed = _parse_classification(payload)

    expected_topics = ["finance", "invoices"]
    assert parsed.topics == expected_topics
    assert parsed.suggested_new_topics == []
|
|
|
|
|
|
def test_parse_with_code_fence():
    """JSON wrapped in a ```json fence is still parsed into topics and suggestions."""
    fenced_payload = '```json\n{"assigned_topics": ["legal"], "new_topic_suggestions": ["contracts"]}\n```'
    parsed = _parse_classification(fenced_payload)

    expected_topics = ["legal"]
    expected_suggestions = ["contracts"]
    assert parsed.topics == expected_topics
    assert parsed.suggested_new_topics == expected_suggestions
|
|
|
|
|
|
def test_parse_with_preamble():
    """Prose before and after the JSON object does not prevent topic extraction."""
    chatty_payload = 'Here is the classification:\n{"assigned_topics": ["hr"], "new_topic_suggestions": []}\nDone.'
    parsed = _parse_classification(chatty_payload)

    expected_topics = ["hr"]
    assert parsed.topics == expected_topics
|
|
|
|
|
|
def test_parse_malformed_returns_empty():
    """A reply containing no JSON at all produces empty topics and empty suggestions."""
    refusal_text = "I cannot classify this document."
    parsed = _parse_classification(refusal_text)

    # Both result lists are expected to come back empty for unparseable input.
    assert parsed.topics == []
    assert parsed.suggested_new_topics == []
|
|
|
|
|
|
def test_strip_code_fences():
    """Stripping a ```json fence leaves only the fenced body."""
    fenced = "```json\n{}\n```"
    stripped = _strip_code_fences(fenced)
    assert stripped == "{}"
|
|
|
|
|
|
def test_parse_suggestions_clean():
    """Every name listed under "suggested_topics" appears in the parsed result."""
    payload = '{"suggested_topics": ["Human Resources", "Onboarding"]}'
    suggestions = _parse_suggestions(payload)

    # Membership checks only, in the same order as the payload lists them.
    for expected_name in ("Human Resources", "Onboarding"):
        assert expected_name in suggestions
|
|
|
|
|
|
def test_parse_suggestions_with_fence():
    """A suggestions payload inside an unlabeled code fence parses to the exact list."""
    fenced_payload = "```\n{\"suggested_topics\": [\"Finance\"]}\n```"
    suggestions = _parse_suggestions(fenced_payload)

    expected = ["Finance"]
    assert suggestions == expected
|
|
|
|
|
|
def test_parse_suggestions_malformed():
    """Text without any JSON yields an empty suggestion list."""
    non_json_reply = "No suggestions available."
    suggestions = _parse_suggestions(non_json_reply)
    assert suggestions == []
|
|
|
|
|
|
@pytest.mark.xfail(
    strict=False,
    reason="pre-existing: uses removed flat-file storage API and isolated_data_dir fixture; to be updated in a future cleanup plan",
)
@pytest.mark.asyncio
async def test_classifier_with_mock_provider(isolated_data_dir):
    """Run classify_document against a mocked provider and check the stored outcome.

    Marked xfail (non-strict); see the decorator's reason string.
    """
    from unittest.mock import AsyncMock, patch
    from ai.base import ClassificationResult
    import services.storage as storage_api

    # Seed a single, not-yet-classified document record.
    doc_id = "test-doc-1"
    document_record = {
        "id": doc_id,
        "original_name": "test.txt",
        "filename": "test-doc-1.txt",
        "mime_type": "text/plain",
        "size_bytes": 50,
        "extracted_text": "Invoice for services rendered in March 2026.",
        "topics": [],
        "created_at": "2026-01-01T00:00:00Z",
        "classified_at": None,
    }
    storage_api.save_metadata(document_record)

    # Pre-existing topics available to the classifier.
    for topic_name in ("Finance", "Legal"):
        storage_api.create_topic(topic_name)

    canned_result = ClassificationResult(
        topics=["Finance"],
        suggested_new_topics=["Invoices"],
        reasoning="Document is about financial invoicing.",
    )

    with patch("services.classifier.get_provider") as get_provider_mock:
        provider_stub = AsyncMock()
        provider_stub.classify = AsyncMock(return_value=canned_result)
        get_provider_mock.return_value = provider_stub

        # Imported while the patch is active, as in the original test.
        from services.classifier import classify_document

        topics = await classify_document(doc_id)

    # Both the assigned topic and the suggested one are returned.
    assert "Finance" in topics
    assert "Invoices" in topics

    # The suggested topic must now exist in the topic store.
    assert any(entry["name"] == "Invoices" for entry in storage_api.load_topics())

    # The document record must carry the assigned topic.
    stored_record = storage_api.get_metadata(doc_id)
    assert "Finance" in stored_record["topics"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Per-user AI provider resolution tests — Plan 03-04 (D-14, D-15, DOC-03, DOC-05)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_per_user_provider(db_session):
    """Explicit ai_provider/ai_model arguments reach get_provider as settings.

    Calls classify_document with ai_provider='openai' and ai_model='gpt-4o' and
    checks that the settings handed to get_provider have
    active_provider == 'openai' and providers['openai']['model'] == 'gpt-4o'.

    DOC-03: the AI provider/model comes from the user's DB record (passed
    through from _run), not from global config or the retired load_settings()
    flat file (D-14).
    """
    from unittest.mock import AsyncMock, patch, MagicMock
    from ai.base import ClassificationResult
    from services.classifier import classify_document
    import uuid

    doc_id = str(uuid.uuid4())
    owner_id = uuid.uuid4()

    fake_meta = {"extracted_text": "Sample document text for testing."}
    fake_doc = MagicMock(user_id=owner_id)

    captured_settings = {}

    def record_settings(settings):
        # Remember what the classifier passed in, then return a provider
        # whose classify() resolves to an empty result.
        captured_settings.update(settings)
        empty_result = ClassificationResult(topics=[], suggested_new_topics=[], reasoning="")
        provider_stub = MagicMock()
        provider_stub.classify = AsyncMock(return_value=empty_result)
        return provider_stub

    fake_session = AsyncMock()
    fake_session.get = AsyncMock(return_value=fake_doc)

    # Patchers are built up front and entered in the same order as before.
    patch_metadata = patch("services.classifier.storage.get_metadata", AsyncMock(return_value=fake_meta))
    patch_user_topics = patch("services.classifier.storage.load_topics_for_user", AsyncMock(return_value=[]))
    patch_topics = patch("services.classifier.storage.load_topics", AsyncMock(return_value=[]))
    patch_update = patch("services.classifier.storage.update_document_topics", AsyncMock(return_value=None))
    patch_provider = patch("services.classifier.get_provider", side_effect=record_settings)

    with patch_metadata, patch_user_topics, patch_topics, patch_update, patch_provider:
        await classify_document(fake_session, doc_id, ai_provider="openai", ai_model="gpt-4o")

    assert captured_settings.get("active_provider") == "openai"
    assert "openai" in captured_settings.get("providers", {})
    assert captured_settings["providers"]["openai"]["model"] == "gpt-4o"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_celery_task_uses_user_provider(db_session):
    """_run(document_id) forwards the owning user's AI provider and model.

    The mocked session returns a Document first and then a User whose
    ai_provider is 'anthropic'; the test checks that the classifier was called
    exactly once with ai_provider='anthropic' and ai_model='claude-sonnet-4-6'.

    DOC-05: the Celery extract_and_classify task resolves per-user AI config
    via a second DB lookup (doc.user_id -> user.ai_provider/ai_model) and
    passes it to the classifier (CONTEXT.md D-14).

    Note: deferred imports inside _run are patched at their module paths.
    """
    from unittest.mock import AsyncMock, patch, MagicMock
    import uuid

    doc_id = str(uuid.uuid4())
    owner_id = uuid.uuid4()

    fake_doc = MagicMock(
        user_id=owner_id,
        object_key=f"{owner_id}/{doc_id}/file.txt",
        content_type="text/plain",
        extracted_text="",
        status="uploaded",
        storage_backend="minio",
    )
    fake_user = MagicMock(ai_provider="anthropic", ai_model="claude-sonnet-4-6")

    classify_calls = []

    async def record_classify(session, document_id, ai_provider=None, ai_model=None):
        # Stand-in for classify_document: log the AI config it was given.
        classify_calls.append({"ai_provider": ai_provider, "ai_model": ai_model})
        return []

    fake_session = AsyncMock()
    # session.get is called twice: first for the Document, then for the User.
    fake_session.get = AsyncMock(side_effect=[fake_doc, fake_user])
    fake_session.commit = AsyncMock()

    fake_backend = AsyncMock()
    fake_backend.get_object = AsyncMock(return_value=b"file bytes")

    # Async context manager that hands out the fake session.
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=fake_session)
    session_cm.__aexit__ = AsyncMock(return_value=None)

    # Patch at the source module paths since _run uses deferred imports.
    patch_session_factory = patch("db.session.AsyncSessionLocal", return_value=session_cm)
    patch_extractor = patch("services.extractor.extract_text_from_bytes", return_value="document text")
    patch_classifier = patch("services.classifier.classify_document", record_classify)
    patch_backend = patch("storage.get_storage_backend", return_value=fake_backend)

    with patch_session_factory, patch_extractor, patch_classifier, patch_backend:
        from tasks.document_tasks import _run

        await _run(doc_id)

    assert len(classify_calls) == 1
    recorded = classify_calls[0]
    assert recorded["ai_provider"] == "anthropic"
    assert recorded["ai_model"] == "claude-sonnet-4-6"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_default_provider_fallback(db_session):
    """With ai_provider=None the classifier falls back to the default provider.

    The test passes ai_provider=None / ai_model=None and expects the settings
    handed to get_provider to have active_provider == 'ollama'.

    D-15: fallback chain is user.ai_provider -> DEFAULT_AI_PROVIDER env var ->
    code default 'ollama' (CONTEXT.md D-15).
    """
    from unittest.mock import AsyncMock, patch, MagicMock
    from ai.base import ClassificationResult
    from services.classifier import classify_document
    import uuid

    doc_id = str(uuid.uuid4())
    owner_id = uuid.uuid4()

    fake_meta = {"extracted_text": "Sample document text."}
    fake_doc = MagicMock(user_id=owner_id)

    captured_settings = {}

    def record_settings(settings):
        # Remember what the classifier passed in, then return a provider
        # whose classify() resolves to an empty result.
        captured_settings.update(settings)
        empty_result = ClassificationResult(topics=[], suggested_new_topics=[], reasoning="")
        provider_stub = MagicMock()
        provider_stub.classify = AsyncMock(return_value=empty_result)
        return provider_stub

    # Patchers are built up front and entered in the same order as before.
    patch_metadata = patch("services.classifier.storage.get_metadata", AsyncMock(return_value=fake_meta))
    patch_user_topics = patch("services.classifier.storage.load_topics_for_user", AsyncMock(return_value=[]))
    patch_topics = patch("services.classifier.storage.load_topics", AsyncMock(return_value=[]))
    patch_update = patch("services.classifier.storage.update_document_topics", AsyncMock(return_value=None))
    patch_provider = patch("services.classifier.get_provider", side_effect=record_settings)

    with patch_metadata, patch_user_topics, patch_topics, patch_update, patch_provider:
        fake_session = AsyncMock()
        fake_session.get = AsyncMock(return_value=fake_doc)
        # Pass ai_provider=None to trigger the default fallback (D-15).
        await classify_document(fake_session, doc_id, ai_provider=None, ai_model=None)

    # Should fall back to app_settings.default_ai_provider = "ollama".
    assert captured_settings.get("active_provider") == "ollama"
|