feat(01-05): introduce celery_app + tasks/document_tasks + session-aware classifier
- Add backend/celery_app.py: Celery("docuvault") with Redis broker, JSON
serialization, and tasks.document_tasks.* routed to documents queue;
reads REDIS_URL directly from os.environ (no config import — Pitfall 7)
- Add backend/tasks/__init__.py: empty package marker
- Add backend/tasks/document_tasks.py: sync extract_and_classify Celery task
that calls asyncio.run(_run()) to retrieve bytes from MinIO, extract text
via extractor, and classify via classifier; classification failure is non-fatal
- Update backend/services/classifier.py: classify_document and
suggest_topics_for_document now accept session: AsyncSession as first arg;
all storage.* calls updated to async session-injection pattern
- Add extract_text_from_bytes helper to services/extractor.py for bytes-based
extraction (used by Celery worker, which retrieves bytes from MinIO)
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Celery application factory for DocuVault.
|
||||
|
||||
Kept deliberately minimal to avoid circular imports (Pitfall 7 from RESEARCH.md):
|
||||
- DO NOT import from config (triggers pydantic-settings env-loading side effects)
|
||||
- DO NOT import from main or any FastAPI router module
|
||||
- Only os + celery imported here
|
||||
|
||||
REDIS_URL is read directly from os.environ so that this module can be imported
|
||||
safely by the Celery worker process without pulling in the FastAPI application
|
||||
machinery.
|
||||
"""
|
||||
import os
|
||||
|
||||
from celery import Celery
|
||||
|
||||
celery_app = Celery("docuvault")
|
||||
|
||||
# Broker + result backend — read REDIS_URL directly from env (not from config.settings)
|
||||
_redis_url = os.environ.get("REDIS_URL", "redis://redis:6379/0")
|
||||
celery_app.conf.broker_url = _redis_url
|
||||
celery_app.conf.result_backend = _redis_url
|
||||
|
||||
# JSON-only serialization (safe default; avoids pickle deserialization risks)
|
||||
celery_app.conf.task_serializer = "json"
|
||||
celery_app.conf.result_serializer = "json"
|
||||
celery_app.conf.accept_content = ["json"]
|
||||
|
||||
# Route document tasks to the dedicated `documents` queue
|
||||
celery_app.conf.task_routes = {
|
||||
"tasks.document_tasks.*": {"queue": "documents"},
|
||||
}
|
||||
|
||||
# Autodiscover tasks under the `tasks/` package
|
||||
celery_app.autodiscover_tasks(["tasks"], force=True)
|
||||
@@ -1,20 +1,30 @@
|
||||
"""
|
||||
Classification orchestrator.
|
||||
Loads settings, selects AI provider, classifies document, auto-creates suggested topics.
|
||||
|
||||
Updated in Plan 05: classify_document and suggest_topics_for_document now accept
|
||||
an AsyncSession as their first argument so they can be called from the Celery task
|
||||
wrapper and from API route handlers that already hold a session.
|
||||
"""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from services import storage
|
||||
from ai import get_provider
|
||||
|
||||
MAX_AI_CHARS = 8_000
|
||||
|
||||
|
||||
async def classify_document(doc_id: str, topic_names: list[str] | None = None) -> list[str]:
|
||||
async def classify_document(
|
||||
session: AsyncSession,
|
||||
doc_id: str,
|
||||
topic_names: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Classify a document by its ID. Returns the list of assigned topic names.
|
||||
If topic_names is provided, restrict classification to those topics.
|
||||
Auto-creates any newly suggested topics.
|
||||
"""
|
||||
meta = storage.get_metadata(doc_id)
|
||||
meta = await storage.get_metadata(session, doc_id)
|
||||
if meta is None:
|
||||
raise ValueError(f"Document {doc_id} not found")
|
||||
|
||||
@@ -24,7 +34,7 @@ async def classify_document(doc_id: str, topic_names: list[str] | None = None) -
|
||||
|
||||
# Use all known topics if not specified
|
||||
if topic_names is None:
|
||||
all_topics = storage.load_topics()
|
||||
all_topics = await storage.load_topics(session)
|
||||
topic_names = [t["name"] for t in all_topics]
|
||||
|
||||
text = meta.get("extracted_text", "")
|
||||
@@ -37,18 +47,18 @@ async def classify_document(doc_id: str, topic_names: list[str] | None = None) -
|
||||
existing_names = {t.lower() for t in topic_names}
|
||||
for name in all_new_names:
|
||||
if name.strip() and name.lower() not in existing_names:
|
||||
storage.create_topic(name.strip())
|
||||
await storage.create_topic(session, name.strip())
|
||||
|
||||
# Final list: everything the AI assigned or suggested
|
||||
final_topics = [t for t in list(set(result.topics + result.suggested_new_topics)) if t.strip()]
|
||||
|
||||
storage.update_document_topics(doc_id, final_topics)
|
||||
await storage.update_document_topics(session, doc_id, final_topics)
|
||||
return final_topics
|
||||
|
||||
|
||||
async def suggest_topics_for_document(doc_id: str) -> list[str]:
|
||||
async def suggest_topics_for_document(session: AsyncSession, doc_id: str) -> list[str]:
|
||||
"""Return AI-suggested topic names without modifying the document."""
|
||||
meta = storage.get_metadata(doc_id)
|
||||
meta = await storage.get_metadata(session, doc_id)
|
||||
if meta is None:
|
||||
raise ValueError(f"Document {doc_id} not found")
|
||||
|
||||
|
||||
@@ -2,11 +2,49 @@
|
||||
Text extraction dispatcher.
|
||||
Supports: PDF (PyMuPDF), DOCX (python-docx), plain text, images (pytesseract).
|
||||
"""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
MAX_STORED_CHARS = 50_000
|
||||
|
||||
|
||||
def extract_text_from_bytes(file_bytes: bytes, mime_type: str) -> str:
|
||||
"""Extract text from raw bytes by writing to a temp file and dispatching to extract_text.
|
||||
|
||||
Used by the Celery worker (which retrieves bytes from MinIO) so extraction
|
||||
does not require a filesystem path.
|
||||
"""
|
||||
suffix = _mime_to_suffix(mime_type)
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
tmp.write(file_bytes)
|
||||
tmp_path = tmp.name
|
||||
try:
|
||||
return extract_text(tmp_path, mime_type)
|
||||
finally:
|
||||
import os
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _mime_to_suffix(mime_type: str) -> str:
|
||||
"""Return a file extension for the given MIME type."""
|
||||
mapping = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/msword": ".doc",
|
||||
"text/plain": ".txt",
|
||||
"text/markdown": ".md",
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/jpg": ".jpg",
|
||||
"image/tiff": ".tiff",
|
||||
"image/webp": ".webp",
|
||||
}
|
||||
return mapping.get(mime_type, ".bin")
|
||||
|
||||
|
||||
def extract_text(file_path: str, mime_type: str) -> str:
|
||||
path = Path(file_path)
|
||||
try:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# tasks package — Celery task modules for DocuVault
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Celery tasks for document processing in DocuVault.
|
||||
|
||||
extract_and_classify — called via .delay(document_id) by the upload handler.
|
||||
The task is a plain sync def (Celery workers have no asyncio event loop); it
|
||||
bridges into the async service layer via asyncio.run().
|
||||
|
||||
Flow:
|
||||
1. Open a fresh AsyncSession (one per task invocation — never share sessions)
|
||||
2. Look up the Document row to get the MinIO object_key
|
||||
3. Retrieve file bytes from MinIO via the storage backend
|
||||
4. Extract text from bytes using services.extractor
|
||||
5. Persist extracted_text back to the Document row
|
||||
6. Call services.classifier.classify_document to assign topics
|
||||
7. Return a result dict (never raises — classification failures are non-fatal)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from celery_app import celery_app
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.document_tasks.extract_and_classify")
|
||||
def extract_and_classify(document_id: str) -> dict:
|
||||
"""Synchronous Celery entry-point — delegates to async _run via asyncio.run."""
|
||||
return asyncio.run(_run(document_id))
|
||||
|
||||
|
||||
async def _run(document_id: str) -> dict:
|
||||
"""Async body of extract_and_classify.
|
||||
|
||||
Opens its own AsyncSession (not shared with the upload request) to avoid
|
||||
cross-thread session contamination.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from db.session import AsyncSessionLocal
|
||||
from db.models import Document
|
||||
from services import extractor, classifier
|
||||
from storage import get_storage_backend
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
# ── Step 1: fetch Document row ─────────────────────────────────────────
|
||||
try:
|
||||
doc_uuid = _uuid.UUID(document_id)
|
||||
except ValueError:
|
||||
return {"document_id": document_id, "status": "invalid_id"}
|
||||
|
||||
doc = await session.get(Document, doc_uuid)
|
||||
if doc is None:
|
||||
return {"document_id": document_id, "status": "not_found"}
|
||||
|
||||
if not doc.object_key:
|
||||
return {"document_id": document_id, "status": "missing_object"}
|
||||
|
||||
# ── Step 2: retrieve bytes from MinIO ──────────────────────────────────
|
||||
try:
|
||||
backend = get_storage_backend()
|
||||
file_bytes = await backend.get_object(doc.object_key)
|
||||
except Exception as e:
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"status": "extract_failed",
|
||||
"error": f"MinIO retrieval failed: {e}",
|
||||
}
|
||||
|
||||
# ── Step 3: extract text from bytes ────────────────────────────────────
|
||||
try:
|
||||
text = extractor.extract_text_from_bytes(file_bytes, doc.content_type)
|
||||
doc.extracted_text = text
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"status": "extract_failed",
|
||||
"error": f"Text extraction failed: {e}",
|
||||
}
|
||||
|
||||
# ── Step 4: classify document (non-fatal) ──────────────────────────────
|
||||
try:
|
||||
topics = await classifier.classify_document(session, document_id)
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"status": "classified",
|
||||
"topics": topics,
|
||||
}
|
||||
except Exception as e:
|
||||
# Non-fatal — preserve existing convention from api/documents.py
|
||||
doc.status = "classification_failed"
|
||||
await session.commit()
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"status": "classification_failed",
|
||||
"error": str(e),
|
||||
}
|
||||
Reference in New Issue
Block a user