"""
Unit tests for AI provider JSON parsing robustness and classifier orchestration.

Uses a mock provider — no real AI calls made.
"""
import json
import uuid
from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from ai.utils import (
    parse_classification as _parse_classification,
    parse_suggestions as _parse_suggestions,
    strip_code_fences as _strip_code_fences,
)
from ai.base import ClassificationResult


def test_parse_clean_json():
    raw = '{"assigned_topics": ["finance", "invoices"], "new_topic_suggestions": []}'
    result = _parse_classification(raw)
    assert result.topics == ["finance", "invoices"]
    assert result.suggested_new_topics == []


def test_parse_with_code_fence():
    raw = '```json\n{"assigned_topics": ["legal"], "new_topic_suggestions": ["contracts"]}\n```'
    result = _parse_classification(raw)
    assert result.topics == ["legal"]
    assert result.suggested_new_topics == ["contracts"]


def test_parse_with_preamble():
    raw = 'Here is the classification:\n{"assigned_topics": ["hr"], "new_topic_suggestions": []}\nDone.'
    result = _parse_classification(raw)
    assert result.topics == ["hr"]


def test_parse_malformed_returns_empty():
    raw = "I cannot classify this document."
    result = _parse_classification(raw)
    assert result.topics == []
    assert result.suggested_new_topics == []


def test_strip_code_fences():
    raw = "```json\n{}\n```"
    assert _strip_code_fences(raw) == "{}"


def test_parse_suggestions_clean():
    raw = '{"suggested_topics": ["Human Resources", "Onboarding"]}'
    result = _parse_suggestions(raw)
    assert "Human Resources" in result
    assert "Onboarding" in result


def test_parse_suggestions_with_fence():
    raw = "```\n{\"suggested_topics\": [\"Finance\"]}\n```"
    result = _parse_suggestions(raw)
    assert result == ["Finance"]


def test_parse_suggestions_malformed():
    raw = "No suggestions available."
    result = _parse_suggestions(raw)
    assert result == []


@pytest.mark.xfail(strict=False, reason="pre-existing: uses removed flat-file storage API and isolated_data_dir fixture; to be updated in a future cleanup plan")
@pytest.mark.asyncio
async def test_classifier_with_mock_provider(isolated_data_dir):
    """Test classifier orchestration with a mock provider."""
    import services.storage as st

    # Create a document
    doc_id = "test-doc-1"
    st.save_metadata({
        "id": doc_id,
        "original_name": "test.txt",
        "filename": "test-doc-1.txt",
        "mime_type": "text/plain",
        "size_bytes": 50,
        "extracted_text": "Invoice for services rendered in March 2026.",
        "topics": [],
        "created_at": "2026-01-01T00:00:00Z",
        "classified_at": None,
    })

    # Create some topics
    st.create_topic("Finance")
    st.create_topic("Legal")

    mock_result = ClassificationResult(
        topics=["Finance"],
        suggested_new_topics=["Invoices"],
        reasoning="Document is about financial invoicing.",
    )

    with patch("services.classifier.get_provider") as mock_get_provider:
        mock_provider = AsyncMock()
        mock_provider.classify = AsyncMock(return_value=mock_result)
        mock_get_provider.return_value = mock_provider

        from services.classifier import classify_document
        topics = await classify_document(doc_id)

    assert "Finance" in topics
    assert "Invoices" in topics

    # Verify new topic was auto-created
    all_topics = st.load_topics()
    assert any(t["name"] == "Invoices" for t in all_topics)

    # Verify document was updated
    meta = st.get_metadata(doc_id)
    assert "Finance" in meta["topics"]


# ---------------------------------------------------------------------------
# Per-user AI provider resolution tests — Plan 03-04 (D-14, D-15, DOC-03, DOC-05)
# ---------------------------------------------------------------------------


def _make_capturing_get_provider(captured_settings: dict):
    """Build a stand-in for ``services.classifier.get_provider``.

    The returned callable copies the ``settings`` mapping it receives into
    *captured_settings*, so a test can inspect what the classifier resolved.
    It then returns a mock provider whose ``classify`` coroutine yields an
    empty ``ClassificationResult``.
    """
    def capture_get_provider(settings):
        captured_settings.update(settings)
        mock_provider = MagicMock()
        mock_provider.classify = AsyncMock(return_value=ClassificationResult(
            topics=[], suggested_new_topics=[], reasoning=""
        ))
        return mock_provider

    return capture_get_provider


@contextmanager
def _patched_classifier_dependencies(meta: dict, captured_settings: dict):
    """Patch the storage helpers and provider factory used by ``classify_document``.

    Targets are attributes looked up on the ``services.classifier`` module:
      * ``storage.get_metadata`` returns *meta*.
      * ``storage.load_topics_for_user`` and ``storage.load_topics`` return no topics.
      * ``storage.update_document_topics`` returns None.
      * ``get_provider`` records its settings into *captured_settings*.

    All patches are active for the duration of the ``with`` block.
    """
    with patch("services.classifier.storage.get_metadata", AsyncMock(return_value=meta)), \
         patch("services.classifier.storage.load_topics_for_user", AsyncMock(return_value=[])), \
         patch("services.classifier.storage.load_topics", AsyncMock(return_value=[])), \
         patch("services.classifier.storage.update_document_topics", AsyncMock(return_value=None)), \
         patch("services.classifier.get_provider",
               side_effect=_make_capturing_get_provider(captured_settings)):
        yield


@pytest.mark.asyncio
async def test_per_user_provider(db_session):
    """Explicit ai_provider/ai_model arguments drive the classifier's provider settings.

    When ai_provider='openai' and ai_model='gpt-4o' are passed to the
    classifier, the settings handed to get_provider have
    ``active_provider == 'openai'``, with ``providers['openai']['model'] == 'gpt-4o'``.

    DOC-03: AI provider/model comes from the user's DB record (passed through
    from _run), not from global config or the retired load_settings() flat
    file (D-14).
    """
    # NOTE(review): db_session is requested but unused here — every DB call is
    # mocked. Confirm whether the fixture is still needed.
    from services.classifier import classify_document

    doc_id = str(uuid.uuid4())
    user_id = uuid.uuid4()

    mock_meta = {"extracted_text": "Sample document text for testing."}
    mock_doc = MagicMock()
    mock_doc.user_id = user_id

    captured_settings = {}

    mock_session = AsyncMock()
    mock_session.get = AsyncMock(return_value=mock_doc)

    with _patched_classifier_dependencies(mock_meta, captured_settings):
        await classify_document(mock_session, doc_id, ai_provider="openai", ai_model="gpt-4o")

    assert captured_settings.get("active_provider") == "openai"
    assert "openai" in captured_settings.get("providers", {})
    assert captured_settings["providers"]["openai"]["model"] == "gpt-4o"


@pytest.mark.asyncio
async def test_celery_task_uses_user_provider(db_session):
    """_run(document_id) forwards the owning user's AI provider/model to the classifier.

    For a Document whose owner has ai_provider='anthropic', the classifier
    must be called with ai_provider='anthropic' (and that user's ai_model).

    DOC-05: the Celery extract_and_classify task resolves per-user AI config
    via a second DB lookup (doc.user_id → user.ai_provider/ai_model) and passes
    it to the classifier (CONTEXT.md D-14).

    Note: deferred imports inside _run are patched at their module paths.
    """
    # NOTE(review): db_session is requested but unused here — every DB call is
    # mocked. Confirm whether the fixture is still needed.
    doc_id = str(uuid.uuid4())
    user_id = uuid.uuid4()

    mock_doc = MagicMock()
    mock_doc.user_id = user_id
    mock_doc.object_key = f"{user_id}/{doc_id}/file.txt"
    mock_doc.content_type = "text/plain"
    mock_doc.extracted_text = ""
    mock_doc.status = "uploaded"
    mock_doc.storage_backend = "minio"

    mock_user = MagicMock()
    mock_user.ai_provider = "anthropic"
    mock_user.ai_model = "claude-sonnet-4-6"

    classify_calls = []

    async def capture_classify(session, document_id, ai_provider=None, ai_model=None):
        classify_calls.append({"ai_provider": ai_provider, "ai_model": ai_model})
        return []

    mock_session = AsyncMock()
    # session.get called twice: first for Document, then for User
    mock_session.get = AsyncMock(side_effect=[mock_doc, mock_user])
    mock_session.commit = AsyncMock()

    mock_backend = AsyncMock()
    mock_backend.get_object = AsyncMock(return_value=b"file bytes")

    # AsyncSessionLocal() is replaced by a factory returning this async context
    # manager; entering it yields mock_session.
    mock_session_cm = MagicMock()
    mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session_cm.__aexit__ = AsyncMock(return_value=None)

    # Patch at the source module paths since _run uses deferred imports
    with patch("db.session.AsyncSessionLocal", return_value=mock_session_cm), \
         patch("services.extractor.extract_text_from_bytes", return_value="document text"), \
         patch("services.classifier.classify_document", capture_classify), \
         patch("storage.get_storage_backend", return_value=mock_backend):
        from tasks.document_tasks import _run
        await _run(doc_id)

    assert len(classify_calls) == 1
    assert classify_calls[0]["ai_provider"] == "anthropic"
    assert classify_calls[0]["ai_model"] == "claude-sonnet-4-6"


@pytest.mark.asyncio
async def test_default_provider_fallback(db_session):
    """When user.ai_provider is None, the classifier receives config.settings.default_ai_provider.

    D-15: fallback chain is user.ai_provider → DEFAULT_AI_PROVIDER env var →
    code default 'ollama' (CONTEXT.md D-15).
    """
    # NOTE(review): db_session is requested but unused here — every DB call is
    # mocked. Confirm whether the fixture is still needed.
    from services.classifier import classify_document

    doc_id = str(uuid.uuid4())
    user_id = uuid.uuid4()

    mock_meta = {"extracted_text": "Sample document text."}
    mock_doc = MagicMock()
    mock_doc.user_id = user_id

    captured_settings = {}

    with _patched_classifier_dependencies(mock_meta, captured_settings):
        mock_session = AsyncMock()
        mock_session.get = AsyncMock(return_value=mock_doc)

        # Pass ai_provider=None to trigger the default fallback (D-15)
        await classify_document(mock_session, doc_id, ai_provider=None, ai_model=None)

    # Should fall back to app_settings.default_ai_provider = "ollama".
    # NOTE(review): this hard-coded expectation assumes DEFAULT_AI_PROVIDER is not
    # overriding the code default in the test environment (per the D-15 chain above).
    assert captured_settings.get("active_provider") == "ollama"