feat(01-05): introduce celery_app + tasks/document_tasks + session-aware classifier

- Add backend/celery_app.py: Celery("docuvault") with Redis broker, JSON
  serialization, and tasks.document_tasks.* routed to documents queue;
  reads REDIS_URL directly from os.environ (no config import — Pitfall 7)
- Add backend/tasks/__init__.py: empty package marker
- Add backend/tasks/document_tasks.py: sync extract_and_classify Celery task
  that calls asyncio.run(_run()) to retrieve bytes from MinIO, extract text
  via extractor, and classify via classifier; classification failure is non-fatal
- Update backend/services/classifier.py: classify_document and
  suggest_topics_for_document now accept session: AsyncSession as first arg;
  all storage.* calls updated to async session-injection pattern
- Add extract_text_from_bytes helper to services/extractor.py for bytes-based
  extraction (used by Celery worker, which retrieves bytes from MinIO)
This commit is contained in:
curo1305
2026-05-22 09:45:33 +02:00
parent 5d21c6f588
commit 32d67de1ca
5 changed files with 185 additions and 7 deletions
+35
View File
@@ -0,0 +1,35 @@
"""
Celery application factory for DocuVault.
Kept deliberately minimal to avoid circular imports (Pitfall 7 from RESEARCH.md):
- DO NOT import from config (triggers pydantic-settings env-loading side effects)
- DO NOT import from main or any FastAPI router module
- Only os + celery imported here
REDIS_URL is read directly from os.environ so that this module can be imported
safely by the Celery worker process without pulling in the FastAPI application
machinery.
"""
import os
from celery import Celery
celery_app = Celery("docuvault")

# Broker + result backend — read REDIS_URL directly from env (not from config.settings)
_redis_url = os.environ.get("REDIS_URL", "redis://redis:6379/0")
celery_app.conf.broker_url = _redis_url
celery_app.conf.result_backend = _redis_url

# JSON-only serialization (safe default; avoids pickle deserialization risks)
celery_app.conf.task_serializer = "json"
celery_app.conf.result_serializer = "json"
celery_app.conf.accept_content = ["json"]

# Route document tasks to the dedicated `documents` queue
celery_app.conf.task_routes = {
    "tasks.document_tasks.*": {"queue": "documents"},
}

# Autodiscover tasks under the `tasks/` package.
# autodiscover_tasks() imports "<package>.<related_name>" and related_name
# defaults to "tasks", which would look for a non-existent `tasks.tasks`
# module and silently register nothing. Name the real module explicitly so
# `tasks.document_tasks` is imported and its tasks are registered.
celery_app.autodiscover_tasks(
    ["tasks"],
    related_name="document_tasks",
    force=True,
)
+17 -7
View File
@@ -1,20 +1,30 @@
"""
Classification orchestrator.
Loads settings, selects AI provider, classifies document, auto-creates suggested topics.
Updated in Plan 05: classify_document and suggest_topics_for_document now accept
an AsyncSession as their first argument so they can be called from the Celery task
wrapper and from API route handlers that already hold a session.
"""
from sqlalchemy.ext.asyncio import AsyncSession
from services import storage
from ai import get_provider
MAX_AI_CHARS = 8_000
async def classify_document(doc_id: str, topic_names: list[str] | None = None) -> list[str]:
async def classify_document(
session: AsyncSession,
doc_id: str,
topic_names: list[str] | None = None,
) -> list[str]:
"""
Classify a document by its ID. Returns the list of assigned topic names.
If topic_names is provided, restrict classification to those topics.
Auto-creates any newly suggested topics.
"""
meta = storage.get_metadata(doc_id)
meta = await storage.get_metadata(session, doc_id)
if meta is None:
raise ValueError(f"Document {doc_id} not found")
@@ -24,7 +34,7 @@ async def classify_document(doc_id: str, topic_names: list[str] | None = None) -
# Use all known topics if not specified
if topic_names is None:
all_topics = storage.load_topics()
all_topics = await storage.load_topics(session)
topic_names = [t["name"] for t in all_topics]
text = meta.get("extracted_text", "")
@@ -37,18 +47,18 @@ async def classify_document(doc_id: str, topic_names: list[str] | None = None) -
existing_names = {t.lower() for t in topic_names}
for name in all_new_names:
if name.strip() and name.lower() not in existing_names:
storage.create_topic(name.strip())
await storage.create_topic(session, name.strip())
# Final list: everything the AI assigned or suggested
final_topics = [t for t in list(set(result.topics + result.suggested_new_topics)) if t.strip()]
storage.update_document_topics(doc_id, final_topics)
await storage.update_document_topics(session, doc_id, final_topics)
return final_topics
async def suggest_topics_for_document(doc_id: str) -> list[str]:
async def suggest_topics_for_document(session: AsyncSession, doc_id: str) -> list[str]:
"""Return AI-suggested topic names without modifying the document."""
meta = storage.get_metadata(doc_id)
meta = await storage.get_metadata(session, doc_id)
if meta is None:
raise ValueError(f"Document {doc_id} not found")
+38
View File
@@ -2,11 +2,49 @@
Text extraction dispatcher.
Supports: PDF (PyMuPDF), DOCX (python-docx), plain text, images (pytesseract).
"""
import tempfile
from pathlib import Path
MAX_STORED_CHARS = 50_000
def extract_text_from_bytes(file_bytes: bytes, mime_type: str) -> str:
    """Extract text from raw bytes via a short-lived temporary file.

    The bytes are written to a temp file whose suffix is derived from
    ``mime_type`` (see ``_mime_to_suffix``), and that file's path is handed to
    ``extract_text``. Used by the Celery worker, which retrieves bytes from
    MinIO and therefore has no filesystem path of its own.

    Args:
        file_bytes: Raw document content.
        mime_type: MIME type of the content; forwarded to ``extract_text``.

    Returns:
        Whatever ``extract_text`` returns for the temporary file.
    """
    import os

    suffix = _mime_to_suffix(mime_type)
    # delete=False so the file survives being closed; it is closed before
    # extract_text is given its path and removed explicitly below.
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = tmp.name
    try:
        # The write lives inside the try so a failed write (e.g. disk full)
        # still reaches the cleanup in `finally` instead of leaking the file.
        with tmp:
            tmp.write(file_bytes)
        return extract_text(tmp_path, mime_type)
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not mask the
            # extraction result or the original exception.
            pass
def _mime_to_suffix(mime_type: str) -> str:
"""Return a file extension for the given MIME type."""
mapping = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/msword": ".doc",
"text/plain": ".txt",
"text/markdown": ".md",
"image/png": ".png",
"image/jpeg": ".jpg",
"image/jpg": ".jpg",
"image/tiff": ".tiff",
"image/webp": ".webp",
}
return mapping.get(mime_type, ".bin")
def extract_text(file_path: str, mime_type: str) -> str:
path = Path(file_path)
try:
+1
View File
@@ -0,0 +1 @@
# tasks package — Celery task modules for DocuVault
+94
View File
@@ -0,0 +1,94 @@
"""
Celery tasks for document processing in DocuVault.
extract_and_classify — called via .delay(document_id) by the upload handler.
The task is a plain sync def (Celery workers have no asyncio event loop); it
bridges into the async service layer via asyncio.run().
Flow:
1. Open a fresh AsyncSession (one per task invocation — never share sessions)
2. Look up the Document row to get the MinIO object_key
3. Retrieve file bytes from MinIO via the storage backend
4. Extract text from bytes using services.extractor
5. Persist extracted_text back to the Document row
6. Call services.classifier.classify_document to assign topics
7. Return a result dict (never raises — classification failures are non-fatal)
"""
import asyncio
from celery_app import celery_app
@celery_app.task(name="tasks.document_tasks.extract_and_classify")
def extract_and_classify(document_id: str) -> dict:
    """Celery entry point: run the async pipeline to completion and return its result.

    The task itself is a plain synchronous function; ``asyncio.run`` drives
    the ``_run`` coroutine for this single invocation and hands back the
    result dict it produces.
    """
    pipeline = _run(document_id)
    outcome = asyncio.run(pipeline)
    return outcome
async def _run(document_id: str) -> dict:
"""Async body of extract_and_classify.
Opens its own AsyncSession (not shared with the upload request) to avoid
cross-thread session contamination.
"""
import uuid as _uuid
from db.session import AsyncSessionLocal
from db.models import Document
from services import extractor, classifier
from storage import get_storage_backend
async with AsyncSessionLocal() as session:
# ── Step 1: fetch Document row ─────────────────────────────────────────
try:
doc_uuid = _uuid.UUID(document_id)
except ValueError:
return {"document_id": document_id, "status": "invalid_id"}
doc = await session.get(Document, doc_uuid)
if doc is None:
return {"document_id": document_id, "status": "not_found"}
if not doc.object_key:
return {"document_id": document_id, "status": "missing_object"}
# ── Step 2: retrieve bytes from MinIO ──────────────────────────────────
try:
backend = get_storage_backend()
file_bytes = await backend.get_object(doc.object_key)
except Exception as e:
return {
"document_id": document_id,
"status": "extract_failed",
"error": f"MinIO retrieval failed: {e}",
}
# ── Step 3: extract text from bytes ────────────────────────────────────
try:
text = extractor.extract_text_from_bytes(file_bytes, doc.content_type)
doc.extracted_text = text
await session.commit()
except Exception as e:
return {
"document_id": document_id,
"status": "extract_failed",
"error": f"Text extraction failed: {e}",
}
# ── Step 4: classify document (non-fatal) ──────────────────────────────
try:
topics = await classifier.classify_document(session, document_id)
return {
"document_id": document_id,
"status": "classified",
"topics": topics,
}
except Exception as e:
# Non-fatal — preserve existing convention from api/documents.py
doc.status = "classification_failed"
await session.commit()
return {
"document_id": document_id,
"status": "classification_failed",
"error": str(e),
}