chore: initial commit — existing single-user document scanner codebase
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
pytest configuration: isolate each test with a temporary data directory.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_data_dir(monkeypatch, tmp_path):
|
||||
"""Each test gets its own clean data directory."""
|
||||
data_dir = tmp_path / "data"
|
||||
(data_dir / "uploads").mkdir(parents=True)
|
||||
(data_dir / "metadata").mkdir(parents=True)
|
||||
(data_dir / "topics.json").write_text(json.dumps({"topics": []}))
|
||||
|
||||
from config import DEFAULT_SETTINGS
|
||||
(data_dir / "settings.json").write_text(json.dumps(DEFAULT_SETTINGS))
|
||||
|
||||
monkeypatch.setenv("DATA_DIR", str(data_dir))
|
||||
|
||||
# Patch the module-level path constants so the running app sees the temp dir
|
||||
import config
|
||||
monkeypatch.setattr(config, "DATA_DIR", data_dir)
|
||||
monkeypatch.setattr(config, "UPLOADS_DIR", data_dir / "uploads")
|
||||
monkeypatch.setattr(config, "METADATA_DIR", data_dir / "metadata")
|
||||
monkeypatch.setattr(config, "TOPICS_FILE", data_dir / "topics.json")
|
||||
monkeypatch.setattr(config, "SETTINGS_FILE", data_dir / "settings.json")
|
||||
|
||||
import services.storage as st
|
||||
from filelock import FileLock
|
||||
monkeypatch.setattr(st, "UPLOADS_DIR", data_dir / "uploads")
|
||||
monkeypatch.setattr(st, "METADATA_DIR", data_dir / "metadata")
|
||||
monkeypatch.setattr(st, "TOPICS_FILE", data_dir / "topics.json")
|
||||
monkeypatch.setattr(st, "SETTINGS_FILE", data_dir / "settings.json")
|
||||
monkeypatch.setattr(st, "_topics_lock", FileLock(str(data_dir / "topics.json") + ".lock"))
|
||||
monkeypatch.setattr(st, "_settings_lock", FileLock(str(data_dir / "settings.json") + ".lock"))
|
||||
|
||||
yield data_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(isolated_data_dir):
|
||||
from main import app
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_txt(tmp_path):
|
||||
p = tmp_path / "sample.txt"
|
||||
p.write_text("This is a test document about invoices and finance.")
|
||||
return p
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pdf(tmp_path):
|
||||
"""Create a minimal valid PDF for testing."""
|
||||
import fitz
|
||||
doc = fitz.open()
|
||||
page = doc.new_page()
|
||||
page.insert_text((50, 50), "Test PDF document about contracts and legal matters.")
|
||||
pdf_path = tmp_path / "sample.pdf"
|
||||
doc.save(str(pdf_path))
|
||||
doc.close()
|
||||
return pdf_path
|
||||
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Unit tests for AI provider JSON parsing robustness and classifier orchestration.
|
||||
Uses a mock provider — no real AI calls made.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from ai.openai_provider import _parse_classification, _parse_suggestions, _strip_code_fences
|
||||
from ai.base import ClassificationResult
|
||||
|
||||
|
||||
def test_parse_clean_json():
|
||||
raw = '{"assigned_topics": ["finance", "invoices"], "new_topic_suggestions": []}'
|
||||
result = _parse_classification(raw)
|
||||
assert result.topics == ["finance", "invoices"]
|
||||
assert result.suggested_new_topics == []
|
||||
|
||||
|
||||
def test_parse_with_code_fence():
|
||||
raw = '```json\n{"assigned_topics": ["legal"], "new_topic_suggestions": ["contracts"]}\n```'
|
||||
result = _parse_classification(raw)
|
||||
assert result.topics == ["legal"]
|
||||
assert result.suggested_new_topics == ["contracts"]
|
||||
|
||||
|
||||
def test_parse_with_preamble():
|
||||
raw = 'Here is the classification:\n{"assigned_topics": ["hr"], "new_topic_suggestions": []}\nDone.'
|
||||
result = _parse_classification(raw)
|
||||
assert result.topics == ["hr"]
|
||||
|
||||
|
||||
def test_parse_malformed_returns_empty():
|
||||
raw = "I cannot classify this document."
|
||||
result = _parse_classification(raw)
|
||||
assert result.topics == []
|
||||
assert result.suggested_new_topics == []
|
||||
|
||||
|
||||
def test_strip_code_fences():
|
||||
raw = "```json\n{}\n```"
|
||||
assert _strip_code_fences(raw) == "{}"
|
||||
|
||||
|
||||
def test_parse_suggestions_clean():
|
||||
raw = '{"suggested_topics": ["Human Resources", "Onboarding"]}'
|
||||
result = _parse_suggestions(raw)
|
||||
assert "Human Resources" in result
|
||||
assert "Onboarding" in result
|
||||
|
||||
|
||||
def test_parse_suggestions_with_fence():
|
||||
raw = "```\n{\"suggested_topics\": [\"Finance\"]}\n```"
|
||||
result = _parse_suggestions(raw)
|
||||
assert result == ["Finance"]
|
||||
|
||||
|
||||
def test_parse_suggestions_malformed():
|
||||
raw = "No suggestions available."
|
||||
result = _parse_suggestions(raw)
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classifier_with_mock_provider(isolated_data_dir):
|
||||
"""Test classifier orchestration with a mock provider."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from ai.base import ClassificationResult
|
||||
import services.storage as st
|
||||
|
||||
# Create a document
|
||||
doc_id = "test-doc-1"
|
||||
st.save_metadata({
|
||||
"id": doc_id,
|
||||
"original_name": "test.txt",
|
||||
"filename": "test-doc-1.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 50,
|
||||
"extracted_text": "Invoice for services rendered in March 2026.",
|
||||
"topics": [],
|
||||
"created_at": "2026-01-01T00:00:00Z",
|
||||
"classified_at": None,
|
||||
})
|
||||
|
||||
# Create some topics
|
||||
st.create_topic("Finance")
|
||||
st.create_topic("Legal")
|
||||
|
||||
mock_result = ClassificationResult(
|
||||
topics=["Finance"],
|
||||
suggested_new_topics=["Invoices"],
|
||||
reasoning="Document is about financial invoicing.",
|
||||
)
|
||||
|
||||
with patch("services.classifier.get_provider") as mock_get_provider:
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.classify = AsyncMock(return_value=mock_result)
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
from services.classifier import classify_document
|
||||
topics = await classify_document(doc_id)
|
||||
|
||||
assert "Finance" in topics
|
||||
assert "Invoices" in topics
|
||||
|
||||
# Verify new topic was auto-created
|
||||
all_topics = st.load_topics()
|
||||
assert any(t["name"] == "Invoices" for t in all_topics)
|
||||
|
||||
# Verify document was updated
|
||||
meta = st.get_metadata(doc_id)
|
||||
assert "Finance" in meta["topics"]
|
||||
@@ -0,0 +1,107 @@
|
||||
def test_upload_txt_no_classify(client, sample_txt):
|
||||
with open(sample_txt, "rb") as f:
|
||||
resp = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("sample.txt", f, "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["original_name"] == "sample.txt"
|
||||
assert "extracted_text" in data
|
||||
assert "invoices" in data["extracted_text"].lower() or len(data["extracted_text"]) > 0
|
||||
assert data["topics"] == []
|
||||
assert "id" in data
|
||||
|
||||
|
||||
def test_upload_pdf_no_classify(client, sample_pdf):
|
||||
with open(sample_pdf, "rb") as f:
|
||||
resp = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("sample.pdf", f, "application/pdf")},
|
||||
data={"auto_classify": "false"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["mime_type"] == "application/pdf"
|
||||
assert len(data["extracted_text"]) > 0
|
||||
|
||||
|
||||
def test_list_documents(client, sample_txt):
|
||||
with open(sample_txt, "rb") as f:
|
||||
client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("a.txt", f, "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
)
|
||||
resp = client.get("/api/documents")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 1
|
||||
assert len(data["items"]) == 1
|
||||
|
||||
|
||||
def test_list_documents_filter_by_topic(client, sample_txt):
|
||||
with open(sample_txt, "rb") as f:
|
||||
upload = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("a.txt", f, "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
).json()
|
||||
|
||||
import services.storage as st
|
||||
st.update_document_topics(upload["id"], ["finance"])
|
||||
|
||||
resp = client.get("/api/documents?topic=finance")
|
||||
assert resp.json()["total"] == 1
|
||||
|
||||
resp2 = client.get("/api/documents?topic=legal")
|
||||
assert resp2.json()["total"] == 0
|
||||
|
||||
|
||||
def test_get_document(client, sample_txt):
|
||||
with open(sample_txt, "rb") as f:
|
||||
upload = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("a.txt", f, "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
).json()
|
||||
|
||||
resp = client.get(f"/api/documents/{upload['id']}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == upload["id"]
|
||||
|
||||
|
||||
def test_get_document_not_found(client):
|
||||
resp = client.get("/api/documents/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_document(client, sample_txt):
|
||||
with open(sample_txt, "rb") as f:
|
||||
upload = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("a.txt", f, "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
).json()
|
||||
|
||||
resp = client.delete(f"/api/documents/{upload['id']}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["success"] is True
|
||||
|
||||
resp2 = client.get(f"/api/documents/{upload['id']}")
|
||||
assert resp2.status_code == 404
|
||||
|
||||
|
||||
def test_delete_document_not_found(client):
|
||||
resp = client.delete("/api/documents/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_upload_empty_file(client):
|
||||
resp = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("empty.txt", b"", "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from services.extractor import extract_text
|
||||
|
||||
|
||||
def test_extract_txt(tmp_path):
|
||||
p = tmp_path / "test.txt"
|
||||
p.write_text("Hello world this is a test document.", encoding="utf-8")
|
||||
text = extract_text(str(p), "text/plain")
|
||||
assert "Hello world" in text
|
||||
|
||||
|
||||
def test_extract_pdf(tmp_path):
|
||||
import fitz
|
||||
doc = fitz.open()
|
||||
page = doc.new_page()
|
||||
page.insert_text((50, 50), "PDF content about legal contracts.")
|
||||
pdf_path = tmp_path / "test.pdf"
|
||||
doc.save(str(pdf_path))
|
||||
doc.close()
|
||||
|
||||
text = extract_text(str(pdf_path), "application/pdf")
|
||||
assert "PDF content" in text
|
||||
|
||||
|
||||
def test_extract_docx(tmp_path):
|
||||
from docx import Document
|
||||
doc = Document()
|
||||
doc.add_paragraph("DOCX paragraph about financial reports.")
|
||||
docx_path = tmp_path / "test.docx"
|
||||
doc.save(str(docx_path))
|
||||
|
||||
text = extract_text(
|
||||
str(docx_path),
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
)
|
||||
assert "DOCX paragraph" in text
|
||||
|
||||
|
||||
def test_extract_unknown_falls_back_to_text(tmp_path):
|
||||
p = tmp_path / "test.csv"
|
||||
p.write_text("col1,col2\nval1,val2", encoding="utf-8")
|
||||
text = extract_text(str(p), "text/csv")
|
||||
assert "col1" in text
|
||||
|
||||
|
||||
def test_extract_truncation(tmp_path):
|
||||
p = tmp_path / "big.txt"
|
||||
p.write_text("A" * 60_000, encoding="utf-8")
|
||||
text = extract_text(str(p), "text/plain")
|
||||
assert len(text) <= 50_100 # 50k + truncation marker
|
||||
assert "truncated" in text
|
||||
@@ -0,0 +1,4 @@
|
||||
def test_health(client):
|
||||
resp = client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Integration test against a live LM Studio instance.
|
||||
Skipped automatically if LM Studio is not reachable.
|
||||
"""
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
|
||||
def lmstudio_available() -> bool:
|
||||
try:
|
||||
r = httpx.get("http://host.docker.internal:1234/v1/models", timeout=3)
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not lmstudio_available(), reason="LM Studio not reachable at host.docker.internal:1234")
|
||||
@pytest.mark.asyncio
|
||||
async def test_lmstudio_health_check():
|
||||
from ai.lmstudio_provider import LMStudioProvider
|
||||
provider = LMStudioProvider(
|
||||
base_url="http://host.docker.internal:1234",
|
||||
model="gemma-4-e4b-it",
|
||||
)
|
||||
ok = await provider.health_check()
|
||||
assert ok, "LM Studio health check failed"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not lmstudio_available(), reason="LM Studio not reachable at host.docker.internal:1234")
|
||||
@pytest.mark.asyncio
|
||||
async def test_lmstudio_classify():
|
||||
from ai.lmstudio_provider import LMStudioProvider
|
||||
from config import DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
provider = LMStudioProvider(
|
||||
base_url="http://host.docker.internal:1234",
|
||||
model="gemma-4-e4b-it",
|
||||
)
|
||||
result = await provider.classify(
|
||||
document_text="This document is an invoice for software development services.",
|
||||
existing_topics=["Finance", "Legal", "HR"],
|
||||
system_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||
)
|
||||
# Result should have some topics assigned or suggested
|
||||
assert isinstance(result.topics, list)
|
||||
assert isinstance(result.suggested_new_topics, list)
|
||||
@@ -0,0 +1,60 @@
|
||||
def test_get_settings_defaults(client):
|
||||
resp = client.get("/api/settings")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["active_provider"] == "lmstudio"
|
||||
assert "system_prompt" in data
|
||||
assert "providers" in data
|
||||
# API keys should be masked or empty
|
||||
for prov in ("anthropic", "openai"):
|
||||
key = data["providers"][prov].get("api_key", "")
|
||||
assert "****" not in key or len(key) <= 8 # masked or empty
|
||||
|
||||
|
||||
def test_patch_system_prompt(client):
|
||||
new_prompt = "Custom system prompt for testing."
|
||||
resp = client.patch("/api/settings", json={"system_prompt": new_prompt})
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp2 = client.get("/api/settings")
|
||||
assert resp2.json()["system_prompt"] == new_prompt
|
||||
|
||||
|
||||
def test_patch_active_provider(client):
|
||||
resp = client.patch("/api/settings", json={"active_provider": "ollama"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["active_provider"] == "ollama"
|
||||
|
||||
|
||||
def test_patch_invalid_provider(client):
|
||||
resp = client.patch("/api/settings", json={"active_provider": "unknown"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_patch_provider_config(client):
|
||||
resp = client.patch("/api/settings", json={
|
||||
"providers": {
|
||||
"ollama": {"model": "mistral", "base_url": "http://host.docker.internal:11434"}
|
||||
}
|
||||
})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["providers"]["ollama"]["model"] == "mistral"
|
||||
|
||||
|
||||
def test_masked_api_key_not_overwritten(client):
|
||||
"""Patching with a masked key should not overwrite the real stored key."""
|
||||
# First set a real key
|
||||
client.patch("/api/settings", json={"providers": {"anthropic": {"api_key": "sk-ant-realkey"}}})
|
||||
# Then patch with masked key (simulating frontend re-submitting)
|
||||
client.patch("/api/settings", json={"providers": {"anthropic": {"api_key": "****key"}}})
|
||||
# The stored key should still be the real one
|
||||
import services.storage as st
|
||||
settings = st.load_settings()
|
||||
assert settings["providers"]["anthropic"]["api_key"] == "sk-ant-realkey"
|
||||
|
||||
|
||||
def test_get_default_prompt(client):
|
||||
resp = client.get("/api/settings/default-prompt")
|
||||
assert resp.status_code == 200
|
||||
assert "system_prompt" in resp.json()
|
||||
assert len(resp.json()["system_prompt"]) > 0
|
||||
@@ -0,0 +1,72 @@
|
||||
def test_list_topics_empty(client):
|
||||
resp = client.get("/api/topics")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["topics"] == []
|
||||
|
||||
|
||||
def test_create_topic(client):
|
||||
resp = client.post("/api/topics", json={"name": "Finance", "description": "Financial docs", "color": "#ff0000"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "Finance"
|
||||
assert data["color"] == "#ff0000"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
def test_create_topic_deduplication(client):
|
||||
client.post("/api/topics", json={"name": "Finance"})
|
||||
resp = client.post("/api/topics", json={"name": "finance"}) # case-insensitive
|
||||
assert resp.status_code == 200
|
||||
topics = client.get("/api/topics").json()["topics"]
|
||||
assert len(topics) == 1
|
||||
|
||||
|
||||
def test_update_topic(client):
|
||||
create = client.post("/api/topics", json={"name": "Old Name"}).json()
|
||||
resp = client.patch(f"/api/topics/{create['id']}", json={"name": "New Name"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "New Name"
|
||||
|
||||
|
||||
def test_update_topic_not_found(client):
|
||||
resp = client.patch("/api/topics/nonexistent", json={"name": "X"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_topic(client):
|
||||
create = client.post("/api/topics", json={"name": "ToDelete"}).json()
|
||||
resp = client.delete(f"/api/topics/{create['id']}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["success"] is True
|
||||
|
||||
topics = client.get("/api/topics").json()["topics"]
|
||||
assert not any(t["name"] == "ToDelete" for t in topics)
|
||||
|
||||
|
||||
def test_delete_topic_cascades_to_documents(client, sample_txt):
|
||||
# Create a topic
|
||||
topic = client.post("/api/topics", json={"name": "Legal"}).json()
|
||||
|
||||
# Upload doc (no auto classify to control topics manually)
|
||||
with open(sample_txt, "rb") as f:
|
||||
upload = client.post(
|
||||
"/api/documents/upload",
|
||||
files={"file": ("sample.txt", f, "text/plain")},
|
||||
data={"auto_classify": "false"},
|
||||
).json()
|
||||
|
||||
# Manually set topic on the document via classify endpoint
|
||||
import services.storage as st
|
||||
st.update_document_topics(upload["id"], ["Legal"])
|
||||
|
||||
# Delete topic
|
||||
client.delete(f"/api/topics/{topic['id']}")
|
||||
|
||||
# Verify document no longer has the topic
|
||||
doc = client.get(f"/api/documents/{upload['id']}").json()
|
||||
assert "Legal" not in doc["topics"]
|
||||
|
||||
|
||||
def test_delete_topic_not_found(client):
|
||||
resp = client.delete("/api/topics/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
Reference in New Issue
Block a user