32d67de1ca
- Add backend/celery_app.py: Celery("docuvault") with Redis broker, JSON
serialization, and tasks.document_tasks.* routed to documents queue;
reads REDIS_URL directly from os.environ (no config import — Pitfall 7)
- Add backend/tasks/__init__.py: empty package marker
- Add backend/tasks/document_tasks.py: sync extract_and_classify Celery task
that calls asyncio.run(_run()) to retrieve bytes from MinIO, extract text
via extractor, and classify via classifier; classification failure is non-fatal
- Update backend/services/classifier.py: classify_document and
suggest_topics_for_document now accept session: AsyncSession as first arg;
all storage.* calls updated to async session-injection pattern
- Add extract_text_from_bytes helper to services/extractor.py for bytes-based
extraction (used by Celery worker, which retrieves bytes from MinIO)
70 lines
2.6 KiB
Python
"""
Classification orchestrator.

Loads settings, selects the AI provider, classifies a document, and auto-creates
any new topics the AI assigns or suggests.

Updated in Plan 05: classify_document and suggest_topics_for_document now accept
an AsyncSession as their first argument, so they can be called both from the
Celery task wrapper and from API route handlers that already hold a session.
"""
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from services import storage
|
|
from ai import get_provider
|
|
|
|
# Only the first MAX_AI_CHARS characters of a document's extracted text are
# sent to the AI provider, both for classification and for topic suggestions.
MAX_AI_CHARS = 8_000
|
|
|
|
|
|
async def classify_document(
    session: AsyncSession,
    doc_id: str,
    topic_names: list[str] | None = None,
) -> list[str]:
    """
    Classify a document by its ID and persist the resulting topics.

    Args:
        session: Async database session used for every storage call.
        doc_id: ID of the document to classify.
        topic_names: Optional candidate topic names offered to the AI provider.
            When None, every topic in the registry is offered.

    Returns:
        The topic names written to the document: everything the AI assigned
        or suggested, with surrounding whitespace stripped, blank names
        dropped, and case-insensitive duplicates removed. Assigned topics come
        first, then suggestions, in provider order. A name that matches an
        existing registry topic case-insensitively uses the registry spelling.

    Raises:
        ValueError: If no document with ``doc_id`` exists.

    Any returned topic that is not yet in the topic registry is created via
    ``storage.create_topic``.
    """
    meta = await storage.get_metadata(session, doc_id)
    if meta is None:
        raise ValueError(f"Document {doc_id} not found")

    settings = storage.load_settings()
    system_prompt = settings.get("system_prompt", "")
    provider = get_provider(settings)

    # Check existence against the full registry, not the (possibly restricted)
    # candidate list. Otherwise topics that are registered but not in
    # topic_names would be created a second time.
    registry_names = [t["name"] for t in await storage.load_topics(session)]
    if topic_names is None:
        topic_names = registry_names

    # extracted_text may be missing or None; slicing None would raise TypeError.
    text = meta.get("extracted_text") or ""
    result = await provider.classify(text[:MAX_AI_CHARS], topic_names, system_prompt)

    # Case-insensitive key -> canonical spelling already stored in the registry.
    known = {name.strip().casefold(): name for name in registry_names}

    final_topics: list[str] = []
    seen: set[str] = set()
    for raw_name in [*result.topics, *result.suggested_new_topics]:
        name = raw_name.strip()
        key = name.casefold()
        if not name or key in seen:
            continue
        seen.add(key)
        if key in known:
            name = known[key]
        else:
            # Register the new topic once, and remember it so a case variant
            # later in the same response is not created again.
            await storage.create_topic(session, name)
            known[key] = name
        final_topics.append(name)

    await storage.update_document_topics(session, doc_id, final_topics)
    return final_topics
|
|
|
|
|
|
async def suggest_topics_for_document(session: AsyncSession, doc_id: str) -> list[str]:
    """
    Return AI-suggested topic names for a document without modifying it.

    Args:
        session: Async database session used to read the document metadata.
        doc_id: ID of the document to inspect.

    Returns:
        The result of the provider's ``suggest_topics`` for the first
        MAX_AI_CHARS characters of the document's extracted text.

    Raises:
        ValueError: If no document with ``doc_id`` exists.
    """
    meta = await storage.get_metadata(session, doc_id)
    if meta is None:
        raise ValueError(f"Document {doc_id} not found")

    settings = storage.load_settings()
    system_prompt = settings.get("system_prompt", "")
    provider = get_provider(settings)

    # extracted_text may be missing or None; fall back to "" so the slice
    # below never raises TypeError.
    text = meta.get("extracted_text") or ""
    return await provider.suggest_topics(text[:MAX_AI_CHARS], system_prompt)
|