diff --git a/backend/celery_app.py b/backend/celery_app.py new file mode 100644 index 0000000..14c5037 --- /dev/null +++ b/backend/celery_app.py @@ -0,0 +1,35 @@ +""" +Celery application factory for DocuVault. + +Kept deliberately minimal to avoid circular imports (Pitfall 7 from RESEARCH.md): + - DO NOT import from config (triggers pydantic-settings env-loading side effects) + - DO NOT import from main or any FastAPI router module + - Only os + celery imported here + +REDIS_URL is read directly from os.environ so that this module can be imported +safely by the Celery worker process without pulling in the FastAPI application +machinery. +""" +import os + +from celery import Celery + +celery_app = Celery("docuvault") + +# Broker + result backend — read REDIS_URL directly from env (not from config.settings) +_redis_url = os.environ.get("REDIS_URL", "redis://redis:6379/0") +celery_app.conf.broker_url = _redis_url +celery_app.conf.result_backend = _redis_url + +# JSON-only serialization (safe default; avoids pickle deserialization risks) +celery_app.conf.task_serializer = "json" +celery_app.conf.result_serializer = "json" +celery_app.conf.accept_content = ["json"] + +# Route document tasks to the dedicated `documents` queue +celery_app.conf.task_routes = { + "tasks.document_tasks.*": {"queue": "documents"}, +} + +# Autodiscover tasks under the `tasks/` package +celery_app.autodiscover_tasks(["tasks"], force=True) diff --git a/backend/services/classifier.py b/backend/services/classifier.py index 40c11dc..ffe355e 100644 --- a/backend/services/classifier.py +++ b/backend/services/classifier.py @@ -1,20 +1,30 @@ """ Classification orchestrator. Loads settings, selects AI provider, classifies document, auto-creates suggested topics. + +Updated in Plan 05: classify_document and suggest_topics_for_document now accept +an AsyncSession as their first argument so they can be called from the Celery task +wrapper and from API route handlers that already hold a session. """ +from sqlalchemy.ext.asyncio import AsyncSession + from services import storage from ai import get_provider MAX_AI_CHARS = 8_000 -async def classify_document(doc_id: str, topic_names: list[str] | None = None) -> list[str]: +async def classify_document( + session: AsyncSession, + doc_id: str, + topic_names: list[str] | None = None, +) -> list[str]: """ Classify a document by its ID. Returns the list of assigned topic names. If topic_names is provided, restrict classification to those topics. Auto-creates any newly suggested topics. """ - meta = storage.get_metadata(doc_id) + meta = await storage.get_metadata(session, doc_id) if meta is None: raise ValueError(f"Document {doc_id} not found") @@ -24,7 +34,7 @@ async def classify_document(doc_id: str, topic_names: list[str] | None = None) - # Use all known topics if not specified if topic_names is None: - all_topics = storage.load_topics() + all_topics = await storage.load_topics(session) topic_names = [t["name"] for t in all_topics] text = meta.get("extracted_text", "") @@ -37,18 +47,18 @@ async def classify_document(doc_id: str, topic_names: list[str] | None = None) - existing_names = {t.lower() for t in topic_names} for name in all_new_names: if name.strip() and name.lower() not in existing_names: - storage.create_topic(name.strip()) + await storage.create_topic(session, name.strip()) # Final list: everything the AI assigned or suggested final_topics = [t for t in list(set(result.topics + result.suggested_new_topics)) if t.strip()] - storage.update_document_topics(doc_id, final_topics) + await storage.update_document_topics(session, doc_id, final_topics) return final_topics -async def suggest_topics_for_document(doc_id: str) -> list[str]: +async def suggest_topics_for_document(session: AsyncSession, doc_id: str) -> list[str]: """Return AI-suggested topic names without modifying the document.""" - meta = storage.get_metadata(doc_id) + meta = await storage.get_metadata(session, doc_id) if meta is None: raise ValueError(f"Document {doc_id} not found") diff --git a/backend/services/extractor.py b/backend/services/extractor.py index f85f7e2..2fe366f 100644 --- a/backend/services/extractor.py +++ b/backend/services/extractor.py @@ -2,11 +2,49 @@ Text extraction dispatcher. Supports: PDF (PyMuPDF), DOCX (python-docx), plain text, images (pytesseract). """ +import tempfile from pathlib import Path MAX_STORED_CHARS = 50_000 +def extract_text_from_bytes(file_bytes: bytes, mime_type: str) -> str: + """Extract text from raw bytes by writing to a temp file and dispatching to extract_text. + + Used by the Celery worker (which retrieves bytes from MinIO) so extraction + does not require a filesystem path. + """ + suffix = _mime_to_suffix(mime_type) + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp.write(file_bytes) + tmp_path = tmp.name + try: + return extract_text(tmp_path, mime_type) + finally: + import os + try: + os.unlink(tmp_path) + except OSError: + pass + + +def _mime_to_suffix(mime_type: str) -> str: + """Return a file extension for the given MIME type.""" + mapping = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/msword": ".doc", + "text/plain": ".txt", + "text/markdown": ".md", + "image/png": ".png", + "image/jpeg": ".jpg", + "image/jpg": ".jpg", + "image/tiff": ".tiff", + "image/webp": ".webp", + } + return mapping.get(mime_type, ".bin") + + def extract_text(file_path: str, mime_type: str) -> str: path = Path(file_path) try: diff --git a/backend/tasks/__init__.py b/backend/tasks/__init__.py new file mode 100644 index 0000000..dfab74a --- /dev/null +++ b/backend/tasks/__init__.py @@ -0,0 +1 @@ +# tasks package — Celery task modules for DocuVault diff --git a/backend/tasks/document_tasks.py b/backend/tasks/document_tasks.py new file mode 100644 index 0000000..38e9501 --- /dev/null +++ b/backend/tasks/document_tasks.py @@ -0,0 +1,94 @@ +""" +Celery tasks for document processing in DocuVault. + +extract_and_classify — called via .delay(document_id) by the upload handler. +The task is a plain sync def (Celery workers have no asyncio event loop); it +bridges into the async service layer via asyncio.run(). + +Flow: + 1. Open a fresh AsyncSession (one per task invocation — never share sessions) + 2. Look up the Document row to get the MinIO object_key + 3. Retrieve file bytes from MinIO via the storage backend + 4. Extract text from bytes using services.extractor + 5. Persist extracted_text back to the Document row + 6. Call services.classifier.classify_document to assign topics + 7. Return a result dict (never raises — classification failures are non-fatal) +""" +import asyncio + +from celery_app import celery_app + + +@celery_app.task(name="tasks.document_tasks.extract_and_classify") +def extract_and_classify(document_id: str) -> dict: + """Synchronous Celery entry-point — delegates to async _run via asyncio.run.""" + return asyncio.run(_run(document_id)) + + +async def _run(document_id: str) -> dict: + """Async body of extract_and_classify. + + Opens its own AsyncSession (not shared with the upload request) to avoid + cross-thread session contamination. + """ + import uuid as _uuid + + from db.session import AsyncSessionLocal + from db.models import Document + from services import extractor, classifier + from storage import get_storage_backend + + async with AsyncSessionLocal() as session: + # ── Step 1: fetch Document row ───────────────────────────────────────── + try: + doc_uuid = _uuid.UUID(document_id) + except ValueError: + return {"document_id": document_id, "status": "invalid_id"} + + doc = await session.get(Document, doc_uuid) + if doc is None: + return {"document_id": document_id, "status": "not_found"} + + if not doc.object_key: + return {"document_id": document_id, "status": "missing_object"} + + # ── Step 2: retrieve bytes from MinIO ────────────────────────────────── + try: + backend = get_storage_backend() + file_bytes = await backend.get_object(doc.object_key) + except Exception as e: + return { + "document_id": document_id, + "status": "extract_failed", + "error": f"MinIO retrieval failed: {e}", + } + + # ── Step 3: extract text from bytes ──────────────────────────────────── + try: + text = extractor.extract_text_from_bytes(file_bytes, doc.content_type) + doc.extracted_text = text + await session.commit() + except Exception as e: + return { + "document_id": document_id, + "status": "extract_failed", + "error": f"Text extraction failed: {e}", + } + + # ── Step 4: classify document (non-fatal) ────────────────────────────── + try: + topics = await classifier.classify_document(session, document_id) + return { + "document_id": document_id, + "status": "classified", + "topics": topics, + } + except Exception as e: + # Non-fatal — preserve existing convention from api/documents.py + doc.status = "classification_failed" + await session.commit() + return { + "document_id": document_id, + "status": "classification_failed", + "error": str(e), + }