Add shared ai-service container as AI provider intermediary
All feature containers now POST messages to ai-service (port 8010) instead of calling AI providers directly. ai-service routes to LM Studio, Ollama, or Anthropic based on /config/ai_service_config.json. doc-service AI providers removed; replaced by httpx ai_client.py. Backend settings restructured to /api/settings/ai. Frontend gets dedicated AIAdminSettingsPage and AI Service card in AppsPage. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
AI_PROVIDER=lmstudio
|
||||
|
||||
LMSTUDIO_BASE_URL=http://host.docker.internal:1234/v1
|
||||
LMSTUDIO_API_KEY=your-lmstudio-api-key
|
||||
LMSTUDIO_MODEL=local-model
|
||||
|
||||
OLLAMA_BASE_URL=http://host.docker.internal:11434/v1
|
||||
OLLAMA_MODEL=llama3.2
|
||||
OLLAMA_API_KEY=ollama
|
||||
|
||||
ANTHROPIC_API_KEY=sk-ant-your-key-here
|
||||
ANTHROPIC_MODEL=claude-haiku-4-5-20251001
|
||||
@@ -0,0 +1,33 @@
|
||||
# ── Stage 1: dependency installation ─────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN pip install --upgrade pip
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --prefix=/install .
|
||||
|
||||
# ── Stage 2: runtime ──────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Create non-root user (UID/GID 1001)
|
||||
RUN groupadd --gid 1001 appuser && \
|
||||
useradd --uid 1001 --gid 1001 --no-create-home --shell /bin/sh appuser
|
||||
|
||||
# Pre-create the config directory with correct ownership
|
||||
RUN mkdir -p /config && chown -R appuser:appuser /config
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
COPY --chown=appuser:appuser app ./app
|
||||
COPY --chown=appuser:appuser scripts ./scripts
|
||||
RUN chmod +x scripts/start.sh scripts/start_dev.sh
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8010
|
||||
|
||||
CMD ["sh", "scripts/start.sh"]
|
||||
@@ -0,0 +1,12 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Process-level settings for ai-service, loaded by pydantic-settings."""

    # Passed to FastAPI as the application title.
    PROJECT_NAME: str = "ai-service"
    # Path of the JSON file the config reader loads provider settings from.
    CONFIG_PATH: str = "/config/ai_service_config.json"

    class Config:
        # Dotenv file consulted by pydantic-settings in addition to the environment.
        env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.core.config import settings
|
||||
from app.routers import chat, health
|
||||
from app.services.config_reader import load_ai_config
|
||||
|
||||
logger = logging.getLogger("ai-service")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the AI config once at startup and log the active provider and model.

    Nothing is torn down on shutdown; control is simply handed to the app.
    """
    ai_config = await load_ai_config()
    active = ai_config.get("provider", "lmstudio")
    section = ai_config.get(active, {})
    logger.info(
        "[ai-service] active provider: %s model: %s",
        active,
        section.get("model", "unknown"),
    )
    yield
|
||||
|
||||
|
||||
app = FastAPI(title=settings.PROJECT_NAME, lifespan=lifespan)
|
||||
|
||||
app.include_router(chat.router, tags=["chat"])
|
||||
app.include_router(health.router, tags=["health"])
|
||||
@@ -0,0 +1,20 @@
|
||||
from app.providers.base import AIProvider
|
||||
|
||||
|
||||
def get_provider(ai_config: dict) -> AIProvider:
    """Build the provider object selected by ``ai_config["provider"]``.

    The provider's own section (``ai_config[<name>]``) is passed to its
    constructor. Provider modules are imported lazily so that only the
    selected SDK is loaded.

    Raises:
        ValueError: if the configured provider name is not recognised.
    """
    name = ai_config.get("provider", "lmstudio")
    section = ai_config.get(name, {})

    if name == "anthropic":
        from app.providers.anthropic_provider import AnthropicProvider

        return AnthropicProvider(section)

    if name in ("ollama", "lmstudio"):
        from app.providers.openai_compat import OpenAICompatProvider

        return OpenAICompatProvider(section, provider_name=name)

    raise ValueError(f"Unknown AI provider: {name!r}")
|
||||
|
||||
|
||||
__all__ = ["AIProvider", "get_provider"]
|
||||
@@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
|
||||
import anthropic
|
||||
|
||||
from app.providers.base import AIProvider
|
||||
from app.schemas.chat import ChatMessage
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
    """AIProvider implementation backed by the Anthropic Messages API."""

    def __init__(self, config: dict) -> None:
        """Create the async client from the ``anthropic`` config section.

        Args:
            config: mapping with optional ``api_key`` and ``model`` keys.
        """
        self._client = anthropic.AsyncAnthropic(api_key=config.get("api_key", ""))
        self.model_name = config.get("model", "claude-haiku-4-5-20251001")
        self.provider_name = "anthropic"

    async def chat(
        self,
        messages: list[ChatMessage],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, int | None, int | None]:
        """Send ``messages`` to Anthropic and return (content, input_tokens, output_tokens).

        Raises:
            ProviderTimeoutError: the SDK reported a request timeout.
            ProviderConnectionError: connection failure or non-2xx API status.
        """
        # Anthropic uses a top-level `system=` param, not a role in the messages array
        system_content = "\n".join(m.content for m in messages if m.role == "system").strip()
        user_messages = [
            {"role": m.role, "content": m.content} for m in messages if m.role != "system"
        ]

        try:
            response = await self._client.messages.create(
                model=self.model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                system=system_content or anthropic.NOT_GIVEN,
                messages=user_messages,
            )
        # APITimeoutError is a subclass of APIConnectionError in the SDK, so it
        # must be caught first or timeouts are misreported as connection errors.
        except anthropic.APITimeoutError as exc:
            raise ProviderTimeoutError(str(exc)) from exc
        except anthropic.APIConnectionError as exc:
            raise ProviderConnectionError(str(exc)) from exc
        except anthropic.APIStatusError as exc:
            raise ProviderConnectionError(
                f"Anthropic API error {exc.status_code}: {exc.message}"
            ) from exc

        # Concatenate every text block; an empty or non-text reply yields "" instead
        # of raising IndexError/AttributeError on `content[0].text`.
        content = "".join(
            block.text
            for block in response.content
            if isinstance(getattr(block, "text", None), str)
        )
        return content, response.usage.input_tokens, response.usage.output_tokens
|
||||
|
||||
|
||||
class ProviderConnectionError(Exception):
    """Raised by AnthropicProvider.chat on a connection failure or API status error."""

    pass
|
||||
|
||||
|
||||
class ProviderTimeoutError(Exception):
    """Raised by AnthropicProvider.chat when the SDK reports a request timeout."""

    pass
|
||||
@@ -0,0 +1,23 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.schemas.chat import ChatMessage
|
||||
|
||||
|
||||
class AIProvider(ABC):
    """Abstract interface implemented by every chat backend."""

    # Class-level fallbacks; concrete providers assign real values in __init__.
    provider_name: str = "unknown"
    model_name: str = "unknown"

    @abstractmethod
    async def chat(
        self,
        messages: list[ChatMessage],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, int | None, int | None]:
        """
        Send messages to the provider and return (content, input_tokens, output_tokens).

        Token counts may be None when the backend does not report usage.

        Raises:
            ProviderConnectionError: on network / auth failure
            ProviderTimeoutError: on request timeout
        """
        ...
|
||||
@@ -0,0 +1,52 @@
|
||||
"""OpenAI-compatible provider — handles both Ollama and LM Studio."""
|
||||
import asyncio
|
||||
|
||||
import openai
|
||||
|
||||
from app.providers.base import AIProvider
|
||||
from app.schemas.chat import ChatMessage
|
||||
|
||||
|
||||
class OpenAICompatProvider(AIProvider):
    """AIProvider for OpenAI-compatible chat-completions endpoints (Ollama, LM Studio)."""

    def __init__(self, config: dict, provider_name: str = "lmstudio") -> None:
        """Create the async client from a provider config section.

        Args:
            config: mapping with optional ``base_url``, ``api_key`` and ``model`` keys.
            provider_name: label reported back to callers ("ollama" or "lmstudio").
        """
        self._client = openai.AsyncOpenAI(
            base_url=config.get("base_url", "http://localhost:1234/v1"),
            # The SDK requires a non-empty key even when the server ignores it.
            api_key=config.get("api_key") or "not-required",
        )
        self.model_name = config.get("model", "local-model")
        self.provider_name = provider_name

    async def chat(
        self,
        messages: list[ChatMessage],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, int | None, int | None]:
        """Send ``messages`` and return (content, input_tokens, output_tokens).

        Token counts are None when the server omits ``usage``.

        Raises:
            ProviderTimeoutError: the SDK reported a request timeout.
            ProviderConnectionError: connection failure, non-2xx API status,
                or a response without any choices.
        """
        raw_messages = [{"role": m.role, "content": m.content} for m in messages]
        try:
            response = await self._client.chat.completions.create(
                model=self.model_name,
                messages=raw_messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        # APITimeoutError is a subclass of APIConnectionError in the SDK, so it
        # must be caught first or timeouts are misreported as connection errors.
        except openai.APITimeoutError as exc:
            raise ProviderTimeoutError(str(exc)) from exc
        except openai.APIConnectionError as exc:
            raise ProviderConnectionError(str(exc)) from exc
        except openai.APIStatusError as exc:
            raise ProviderConnectionError(f"API error {exc.status_code}: {exc.message}") from exc

        # A malformed reply with no choices is surfaced as a provider error
        # rather than an IndexError.
        if not response.choices:
            raise ProviderConnectionError("API returned no choices")

        content = response.choices[0].message.content or ""
        usage = response.usage
        input_tokens = usage.prompt_tokens if usage else None
        output_tokens = usage.completion_tokens if usage else None
        return content, input_tokens, output_tokens
|
||||
|
||||
|
||||
class ProviderConnectionError(Exception):
    """Raised by OpenAICompatProvider.chat on a connection failure or API status error."""

    pass
|
||||
|
||||
|
||||
class ProviderTimeoutError(Exception):
    """Raised by OpenAICompatProvider.chat when the SDK reports a request timeout."""

    pass
|
||||
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.providers import get_provider
|
||||
from app.providers.anthropic_provider import ProviderConnectionError as AnthropicConnError
|
||||
from app.providers.anthropic_provider import ProviderTimeoutError as AnthropicTimeoutError
|
||||
from app.providers.openai_compat import ProviderConnectionError as OpenAIConnError
|
||||
from app.providers.openai_compat import ProviderTimeoutError as OpenAITimeoutError
|
||||
from app.schemas.chat import ChatRequest, ChatResponse
|
||||
from app.services.config_reader import load_ai_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Matches text that consists of exactly one Markdown code fence, with an optional
# alphabetic language tag. IGNORECASE lets tags such as "JSON" be consumed as the
# tag instead of leaking into the captured payload.
_FENCE_RE = re.compile(r"^```[a-z]*\n?(.*?)\n?```$", re.DOTALL | re.IGNORECASE)


def _strip_fences(text: str) -> str:
    """Return ``text`` without a surrounding Markdown code fence.

    Leading/trailing whitespace is always removed. If the remaining text is a
    single fenced block, only the fence body (stripped) is returned; otherwise
    the stripped text is returned unchanged.
    """
    stripped = text.strip()
    m = _FENCE_RE.match(stripped)
    return m.group(1).strip() if m else stripped
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
    """Forward a chat request to the configured AI provider.

    Connection errors are retried up to ``max_retries`` times with a linearly
    growing delay; timeouts are never retried.

    Raises:
        HTTPException 503: the configured provider is unknown.
        HTTPException 504: the provider call timed out.
        HTTPException 502: the provider was unreachable or returned an API
            error after all retries.
    """
    config = await load_ai_config()

    provider_name = config.get("provider", "lmstudio")
    if provider_name not in ("anthropic", "ollama", "lmstudio"):
        raise HTTPException(status_code=503, detail=f"Unknown provider configured: {provider_name!r}")

    try:
        provider = get_provider(config)
    except ValueError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    timeout = float(config.get("timeout_seconds", 60))
    # Clamp so a negative value still yields one attempt; otherwise the loop
    # body would never run and `content` would be unbound below.
    max_retries = max(0, int(config.get("max_retries", 2)))

    for attempt in range(max_retries + 1):
        try:
            content, input_tokens, output_tokens = await asyncio.wait_for(
                provider.chat(request.messages, request.max_tokens, request.temperature),
                timeout=timeout,
            )
            break
        except asyncio.TimeoutError as exc:
            # Don't retry on timeout — the model is busy; fail fast
            raise HTTPException(status_code=504, detail="AI provider timed out") from exc
        except (AnthropicTimeoutError, OpenAITimeoutError) as exc:
            raise HTTPException(status_code=504, detail="AI provider timed out") from exc
        except (AnthropicConnError, OpenAIConnError) as exc:
            if attempt < max_retries:
                # Linear back-off: 0.5s, 1.0s, 1.5s, ...
                await asyncio.sleep(0.5 * (attempt + 1))
                continue
            raise HTTPException(status_code=502, detail=f"AI provider error: {exc}") from exc

    if request.response_format == "json":
        # Models often wrap JSON in Markdown fences; hand callers the bare payload.
        content = _strip_fences(content)

    return ChatResponse(
        content=content,
        provider=provider.provider_name,
        model=provider.model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )
|
||||
@@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.services.config_reader import load_ai_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
async def health() -> dict:
    """Liveness probe: returns a fixed payload without reading config or calling a provider."""
    payload = {"status": "ok"}
    return payload
|
||||
|
||||
|
||||
@router.get("/health/provider")
async def provider_status() -> dict:
    """Report the active provider, its model, and whether its required fields are set.

    No request is made to the provider itself; only the loaded config is inspected.
    """
    cfg = await load_ai_config()
    active = cfg.get("provider", "lmstudio")
    section = cfg.get(active, {})

    # "configured" means we have the minimum required fields for the provider
    required = ("api_key",) if active == "anthropic" else ("base_url", "model")
    is_configured = all(bool(section.get(field)) for field in required)

    return {
        "provider": active,
        "model": section.get("model", ""),
        "configured": is_configured,
    }
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """One turn of a conversation."""

    # Only these three roles pass validation.
    role: Literal["system", "user", "assistant"]
    content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body accepted by POST /chat."""

    messages: list[ChatMessage]
    max_tokens: int = 2048
    temperature: float = 0.0
    # "json" makes the route strip Markdown fences from the reply.
    response_format: Literal["json", "text"] = "text"

    @field_validator("messages")
    @classmethod
    def messages_not_empty(cls, v: list) -> list:
        """Reject an empty ``messages`` list; otherwise pass it through unchanged."""
        if v:
            return v
        raise ValueError("messages must not be empty")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Response body returned by POST /chat."""

    content: str
    provider: str
    model: str
    # Token counts default to None for responses that carry no usage data.
    input_tokens: int | None = None
    output_tokens: int | None = None
|
||||
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Reads ai_service_config.json from the shared config volume.
|
||||
30-second TTL cache + env var overrides (dev credentials stay out of git).
|
||||
|
||||
Env var overrides (all optional):
|
||||
AI_PROVIDER — "lmstudio" | "ollama" | "anthropic"
|
||||
LMSTUDIO_BASE_URL, LMSTUDIO_API_KEY, LMSTUDIO_MODEL
|
||||
OLLAMA_BASE_URL, OLLAMA_MODEL, OLLAMA_API_KEY
|
||||
ANTHROPIC_API_KEY, ANTHROPIC_MODEL
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
_DEFAULT_CONFIG: dict = {
|
||||
"provider": "lmstudio",
|
||||
"timeout_seconds": 60,
|
||||
"max_retries": 2,
|
||||
"anthropic": {"api_key": "", "model": "claude-haiku-4-5-20251001"},
|
||||
"ollama": {"base_url": "http://host.docker.internal:11434/v1", "model": "llama3.2", "api_key": "ollama"},
|
||||
"lmstudio": {"base_url": "http://host.docker.internal:1234/v1", "model": "local-model", "api_key": "lm-studio"},
|
||||
}
|
||||
|
||||
_cache: dict | None = None
|
||||
_cache_at: float = 0.0
|
||||
_CACHE_TTL = 30.0
|
||||
|
||||
|
||||
def _read_config_sync() -> dict:
    """Read the config JSON from ``settings.CONFIG_PATH`` and apply env overrides.

    Falls back to a copy of ``_DEFAULT_CONFIG`` when the file does not exist.
    Blocking; intended to be run in a worker thread.

    Raises:
        json.JSONDecodeError: the file exists but is not valid JSON.
        OSError: the file exists but cannot be read.
    """
    # Open directly instead of exists()-then-open so a file removed in between
    # is handled the same way as a file that was never there.
    try:
        # JSON is UTF-8; don't depend on the container's locale encoding.
        with open(settings.CONFIG_PATH, encoding="utf-8") as f:
            raw = json.load(f)
    except FileNotFoundError:
        raw = deepcopy(_DEFAULT_CONFIG)
    return _apply_env_overrides(raw)
|
||||
|
||||
|
||||
def _apply_env_overrides(config: dict) -> dict:
    """Return a deep copy of ``config`` with non-empty env vars layered on top.

    ``AI_PROVIDER`` replaces the top-level ``provider``; the remaining variables
    write into the ``lmstudio`` / ``ollama`` / ``anthropic`` sections, which are
    created (empty) when missing. The input mapping is never mutated.
    """
    merged = deepcopy(config)

    provider_override = os.environ.get("AI_PROVIDER")
    if provider_override:
        merged["provider"] = provider_override

    # (section, ((ENV_VAR, field), ...)) in the order the sections are created.
    env_bindings = (
        (
            "lmstudio",
            (
                ("LMSTUDIO_BASE_URL", "base_url"),
                ("LMSTUDIO_API_KEY", "api_key"),
                ("LMSTUDIO_MODEL", "model"),
            ),
        ),
        (
            "ollama",
            (
                ("OLLAMA_BASE_URL", "base_url"),
                ("OLLAMA_MODEL", "model"),
                ("OLLAMA_API_KEY", "api_key"),
            ),
        ),
        (
            "anthropic",
            (
                ("ANTHROPIC_API_KEY", "api_key"),
                ("ANTHROPIC_MODEL", "model"),
            ),
        ),
    )

    for section_name, bindings in env_bindings:
        section = merged.setdefault(section_name, {})
        for env_name, field in bindings:
            value = os.environ.get(env_name)
            if value:  # unset and empty variables are both ignored
                section[field] = value

    return merged
|
||||
|
||||
|
||||
async def load_ai_config() -> dict:
    """Return the AI config, re-reading it from disk at most once per ``_CACHE_TTL`` seconds.

    The blocking file read runs in a worker thread. The returned dict is the
    cached object itself, so callers should treat it as read-only.
    """
    global _cache, _cache_at
    started = time.monotonic()
    cache_is_fresh = _cache is not None and (started - _cache_at) < _CACHE_TTL
    if cache_is_fresh:
        return _cache

    fresh = await asyncio.to_thread(_read_config_sync)
    # Timestamp is taken before the read, so the TTL counts from the request start.
    _cache, _cache_at = fresh, started
    return fresh
|
||||
@@ -0,0 +1,29 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=45"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "ai-service"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"fastapi>=0.111",
|
||||
"uvicorn[standard]>=0.29",
|
||||
"pydantic-settings>=2.2",
|
||||
"anthropic>=0.28",
|
||||
"openai>=1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8",
|
||||
"pytest-asyncio>=0.23",
|
||||
"httpx>=0.27",
|
||||
"ruff>=0.4",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
@@ -0,0 +1,4 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
echo "[ai-service] starting uvicorn..."
|
||||
exec uvicorn app.main:app --host 0.0.0.0 --port 8010
|
||||
@@ -0,0 +1,4 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
echo "[ai-service] starting uvicorn (dev)..."
|
||||
exec uvicorn app.main:app --host 0.0.0.0 --port 8010 --reload
|
||||
@@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def ai_client():
    """Yield an httpx client wired directly to the ASGI app (no network involved)."""
    transport = ASGITransport(app=app)
    client = AsyncClient(transport=transport, base_url="http://test")
    async with client:
        yield client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
LMSTUDIO_CONFIG = {
|
||||
"provider": "lmstudio",
|
||||
"timeout_seconds": 10,
|
||||
"max_retries": 0,
|
||||
"lmstudio": {
|
||||
"base_url": "http://fake-lmstudio/v1",
|
||||
"model": "test-model",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
}
|
||||
|
||||
OLLAMA_CONFIG = {
|
||||
"provider": "ollama",
|
||||
"timeout_seconds": 10,
|
||||
"max_retries": 0,
|
||||
"ollama": {
|
||||
"base_url": "http://fake-ollama/v1",
|
||||
"model": "llama3.2",
|
||||
"api_key": "ollama",
|
||||
},
|
||||
}
|
||||
|
||||
ANTHROPIC_CONFIG = {
|
||||
"provider": "anthropic",
|
||||
"timeout_seconds": 10,
|
||||
"max_retries": 0,
|
||||
"anthropic": {
|
||||
"api_key": "sk-ant-test",
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
},
|
||||
}
|
||||
|
||||
MISSING_KEY_ANTHROPIC_CONFIG = {
|
||||
"provider": "anthropic",
|
||||
"timeout_seconds": 10,
|
||||
"max_retries": 0,
|
||||
"anthropic": {
|
||||
"api_key": "",
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
"""Tests for POST /chat."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import ANTHROPIC_CONFIG, LMSTUDIO_CONFIG, OLLAMA_CONFIG
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LOAD_CONFIG = "app.routers.chat.load_ai_config"
|
||||
_PROVIDER_CHAT = "app.providers.openai_compat.OpenAICompatProvider.chat"
|
||||
_ANTHROPIC_CHAT = "app.providers.anthropic_provider.AnthropicProvider.chat"
|
||||
|
||||
MESSAGES = [{"role": "user", "content": "Hello"}]
|
||||
SYSTEM_MESSAGES = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
|
||||
|
||||
def _mock_chat_response(content="ok", input_tokens=10, output_tokens=5):
    """Build an AsyncMock that stands in for a provider's ``chat`` coroutine."""
    reply = (content, input_tokens, output_tokens)
    return AsyncMock(return_value=reply)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Success: each provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_lmstudio_success(ai_client):
    """LM Studio: reply text, provider, model and token counts are echoed back."""
    fake_chat = _mock_chat_response("lmstudio reply")
    with (
        patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG),
        patch(_PROVIDER_CHAT, new=fake_chat),
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 200
    body = resp.json()
    expected = {
        "content": "lmstudio reply",
        "provider": "lmstudio",
        "model": "test-model",
        "input_tokens": 10,
        "output_tokens": 5,
    }
    assert {key: body[key] for key in expected} == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_ollama_success(ai_client):
    """Ollama: the reply and provider label come back in the response."""
    fake_chat = _mock_chat_response("ollama reply")
    with (
        patch(_LOAD_CONFIG, return_value=OLLAMA_CONFIG),
        patch(_PROVIDER_CHAT, new=fake_chat),
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 200
    body = resp.json()
    assert (body["content"], body["provider"]) == ("ollama reply", "ollama")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_anthropic_success(ai_client):
    """Anthropic: the reply and provider label come back in the response."""
    fake_chat = _mock_chat_response("anthropic reply")
    with (
        patch(_LOAD_CONFIG, return_value=ANTHROPIC_CONFIG),
        patch(_ANTHROPIC_CHAT, new=fake_chat),
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 200
    body = resp.json()
    assert (body["content"], body["provider"]) == ("anthropic reply", "anthropic")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# response_format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_response_format_json_strips_fences(ai_client):
    """With response_format=json, a fenced reply is returned as bare JSON text."""
    fenced = "```json\n{\"key\": \"value\"}\n```"
    payload = {"messages": MESSAGES, "response_format": "json"}
    with (
        patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG),
        patch(_PROVIDER_CHAT, new=_mock_chat_response(fenced)),
    ):
        resp = await ai_client.post("/chat", json=payload)

    assert resp.status_code == 200
    assert resp.json()["content"] == '{"key": "value"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_response_format_text_preserves_fences(ai_client):
    """With response_format=text, Markdown fences in the reply are left intact."""
    fenced = "```python\nprint('hi')\n```"
    payload = {"messages": MESSAGES, "response_format": "text"}
    with (
        patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG),
        patch(_PROVIDER_CHAT, new=_mock_chat_response(fenced)),
    ):
        resp = await ai_client.post("/chat", json=payload)

    assert resp.status_code == 200
    returned = resp.json()["content"]
    assert "```" in returned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_missing_messages_returns_422(ai_client):
    """A body without ``messages`` fails request validation."""
    empty_body: dict = {}
    resp = await ai_client.post("/chat", json=empty_body)
    assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_empty_messages_returns_422(ai_client):
    """An empty ``messages`` list is rejected by the schema validator."""
    body = {"messages": []}
    resp = await ai_client.post("/chat", json=body)
    assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_connection_error_returns_502(ai_client):
    """A provider connection error (with retries disabled) maps to HTTP 502."""
    from app.providers.openai_compat import ProviderConnectionError

    failure = ProviderConnectionError("refused")
    with (
        patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG),
        patch(_PROVIDER_CHAT, side_effect=failure),
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 502
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_timeout_returns_504(ai_client):
    """A provider call that outlives ``timeout_seconds`` maps to HTTP 504."""

    async def _never_finishes(*_args, **_kwargs):
        # Far longer than the 0.01s timeout configured below.
        await asyncio.sleep(100)

    short_timeout_config = {**LMSTUDIO_CONFIG, "timeout_seconds": 0.01}
    with (
        patch(_LOAD_CONFIG, return_value=short_timeout_config),
        patch(_PROVIDER_CHAT, new=_never_finishes),
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 504
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_unknown_provider_returns_503(ai_client):
    """An unrecognised provider name in the config maps to HTTP 503."""
    bad_config = dict(LMSTUDIO_CONFIG)
    bad_config["provider"] = "unknown-llm"

    with patch(_LOAD_CONFIG, return_value=bad_config):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 503
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anthropic system message extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_anthropic_system_message_extracted(ai_client):
    """System-role messages must be sent via ``system=``, never inside ``messages``."""
    fake_response = MagicMock()
    fake_response.content = [MagicMock(text="ok")]
    fake_response.usage = MagicMock(input_tokens=5, output_tokens=2)

    fake_create = AsyncMock(return_value=fake_response)
    fake_client = MagicMock()
    fake_client.messages.create = fake_create

    # Replace the client constructor rather than a class attribute: this does not
    # depend on how the SDK exposes `messages` on the client.
    with patch(_LOAD_CONFIG, return_value=ANTHROPIC_CONFIG), patch(
        "anthropic.AsyncAnthropic", return_value=fake_client
    ):
        resp = await ai_client.post("/chat", json={"messages": SYSTEM_MESSAGES})

    # Assert unconditionally so the test cannot pass without the call being made.
    assert resp.status_code == 200
    fake_create.assert_awaited_once()
    sent = fake_create.await_args.kwargs
    assert [m["role"] for m in sent["messages"]] == ["user"]
    assert sent["system"] == "You are helpful."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parameter forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_tokens_and_temperature_forwarded(ai_client):
    """``max_tokens`` and ``temperature`` from the request reach the provider unchanged."""
    captured: dict = {}

    # Patched in as a plain function on the class, so it is bound like a method
    # and receives the provider instance as its first positional argument.
    async def _capture(_self, messages, max_tokens, temperature):
        captured["max_tokens"] = max_tokens
        captured["temperature"] = temperature
        return ("ok", 1, 1)

    with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(_PROVIDER_CHAT, new=_capture):
        resp = await ai_client.post(
            "/chat",
            json={"messages": MESSAGES, "max_tokens": 512, "temperature": 0.7},
        )

    assert resp.status_code == 200
    assert captured["max_tokens"] == 512
    assert captured["temperature"] == pytest.approx(0.7)
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Tests for GET /health and GET /health/provider."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import ANTHROPIC_CONFIG, LMSTUDIO_CONFIG, MISSING_KEY_ANTHROPIC_CONFIG
|
||||
|
||||
_LOAD_CONFIG = "app.routers.health.load_ai_config"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_returns_ok(ai_client):
    """GET /health answers 200 with the fixed liveness payload."""
    resp = await ai_client.get("/health")
    observed = (resp.status_code, resp.json())
    assert observed == (200, {"status": "ok"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_provider_status_configured(ai_client):
    """A complete LM Studio section is reported as configured."""
    with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG):
        resp = await ai_client.get("/health/provider")

    assert resp.status_code == 200
    body = resp.json()
    assert (body["provider"], body["model"]) == ("lmstudio", "test-model")
    assert body["configured"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_provider_status_not_configured_when_api_key_missing(ai_client):
    """An Anthropic section with an empty api_key is reported as not configured."""
    with patch(_LOAD_CONFIG, return_value=MISSING_KEY_ANTHROPIC_CONFIG):
        resp = await ai_client.get("/health/provider")

    assert resp.status_code == 200
    body = resp.json()
    assert body["provider"] == "anthropic"
    assert body["configured"] is False
|
||||
@@ -6,6 +6,7 @@ class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:password@db:5432/destroying_sap"
|
||||
DATA_DIR: str = "/data/documents"
|
||||
CONFIG_PATH: str = "/config/doc_service_config.json"
|
||||
AI_SERVICE_URL: str = "http://ai-service:8010"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -17,7 +17,7 @@ from app.models.category import DocumentCategory
|
||||
from app.models.category_assignment import CategoryAssignment
|
||||
from app.models.document import Document
|
||||
from app.schemas.document import DocumentOut, DocumentStatusOut, DocumentTypeUpdate
|
||||
from app.services.ai import get_provider
|
||||
from app.services.ai_client import AIServiceError, classify_document
|
||||
from app.services.config_reader import load_doc_config
|
||||
from app.services.storage import delete_file, get_upload_path, save_upload
|
||||
|
||||
@@ -91,9 +91,7 @@ async def process_document(doc_id: str) -> None:
|
||||
|
||||
try:
|
||||
text = await asyncio.to_thread(_extract_pdf_text, doc.file_path)
|
||||
config = await load_doc_config()
|
||||
provider = get_provider(config["ai"])
|
||||
result = await provider.classify_document(text)
|
||||
result = await classify_document(text)
|
||||
|
||||
doc.raw_text = text[:500_000] # cap stored text at 500k chars
|
||||
doc.extracted_data = json.dumps(result)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from app.services.ai.base import AIProvider
|
||||
|
||||
|
||||
def get_provider(ai_config: dict) -> AIProvider:
    """
    Factory: build the AIProvider selected by the 'provider' key of the AI config section.

    ai_config is the 'ai' section of doc_service_config.json, loaded fresh per processing job.
    Raises ValueError for an unrecognised provider name.
    """
    name = ai_config.get("provider", "anthropic")
    section = ai_config.get(name, {})

    if name == "anthropic":
        from app.services.ai.anthropic_provider import AnthropicProvider

        return AnthropicProvider(section)

    if name in ("ollama", "lmstudio"):
        from app.services.ai.openai_compat import OpenAICompatProvider

        return OpenAICompatProvider(section)

    raise ValueError(f"Unknown AI provider: {name!r}")
|
||||
|
||||
|
||||
__all__ = ["AIProvider", "get_provider"]
|
||||
@@ -1,31 +0,0 @@
|
||||
import json
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.services.ai.base import AIProvider, SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
    """Document classifier backed by the Anthropic Messages API."""

    def __init__(self, config: dict) -> None:
        """Build the client; ``config["api_key"]`` is required, ``model`` is optional."""
        self._client = AsyncAnthropic(api_key=config["api_key"])
        self._model = config.get("model", "claude-haiku-4-5-20251001")

    async def classify_document(self, text: str) -> dict:
        """Classify ``text`` (truncated to 100k chars) and return the parsed JSON reply."""
        prompt = USER_PROMPT_TEMPLATE.format(text=text[:100_000])
        message = await self._client.messages.create(
            model=self._model,
            max_tokens=2048,
            system=SYSTEM_PROMPT,
            messages=[{"role": "user", "content": prompt}],
        )
        first_block = message.content[0]
        return _parse_json(first_block.text.strip())
|
||||
|
||||
|
||||
def _parse_json(raw: str) -> dict:
    """Parse a model reply as JSON, tolerating a surrounding Markdown code fence.

    Raises:
        json.JSONDecodeError: the (unfenced) text is not valid JSON.
    """
    text = raw.strip()
    # Strip accidental markdown fences despite explicit instruction not to include them
    if text.startswith("```"):
        opening_line, newline, remainder = text.partition("\n")
        # Multi-line fence: drop the whole opening line (``` plus any language tag).
        # Single-line fence: there is no newline to split on, so drop just the backticks.
        text = remainder if newline else opening_line[3:]
        text = text.rsplit("```", 1)[0]
    return json.loads(text)
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
OpenAI-compatible provider for Ollama and LM Studio.
|
||||
Both expose an OpenAI-compatible /v1/chat/completions endpoint.
|
||||
"""
|
||||
import json
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from app.services.ai.base import AIProvider, SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class OpenAICompatProvider(AIProvider):
    """Document classifier for OpenAI-compatible chat completion endpoints."""

    def __init__(self, config: dict) -> None:
        """Create the async OpenAI client from the provider config.

        Args:
            config: Provider settings. ``base_url`` and ``model`` are
                required; ``api_key`` defaults to ``"not-required"``.

        Raises:
            KeyError: If ``base_url`` or ``model`` is missing from ``config``.
        """
        self._client = AsyncOpenAI(
            base_url=config["base_url"],
            api_key=config.get("api_key", "not-required"),
        )
        self._model = config["model"]

    async def classify_document(self, text: str) -> dict:
        """Send the document text to the model and return the parsed JSON reply.

        Only the first 100,000 characters of ``text`` are placed in the prompt.

        Raises:
            ValueError: If the response has no choices or no text content, or
                the reply is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError`` subclass).
        """
        response = await self._client.chat.completions.create(
            model=self._model,
            temperature=0,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": USER_PROMPT_TEMPLATE.format(text=text[:100_000])},
            ],
        )
        # The message content is nullable and the choices list may be empty;
        # fail with a clear error instead of AttributeError/IndexError.
        choices = response.choices
        content = choices[0].message.content if choices else None
        if content is None:
            raise ValueError("AI provider returned a response with no text content")
        return _parse_json(content.strip())
|
||||
|
||||
|
||||
def _parse_json(raw: str) -> dict:
    """Parse the model reply as JSON, tolerating a surrounding markdown fence.

    Handles fenced replies with or without a language tag, with or without a
    closing fence, and fences written on a single line.

    Raises:
        json.JSONDecodeError: If the text left after removing the fence is
            not valid JSON.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence and everything from the last closing fence on.
        text = text[3:].rsplit("```", 1)[0]
        # A language tag (e.g. "json") may follow the opening fence on its own
        # line; discard that line unless it already holds the JSON payload.
        first_line, newline, remainder = text.partition("\n")
        if newline and not first_line.lstrip().startswith(("{", "[")):
            text = remainder
    return json.loads(text)
|
||||
@@ -0,0 +1,49 @@
|
||||
"""HTTP client for the shared ai-service container."""
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
|
||||
|
||||
_client = httpx.AsyncClient(timeout=120.0)
|
||||
|
||||
|
||||
class AIServiceError(Exception):
    """Raised when ai-service cannot be reached or its reply cannot be used."""
|
||||
|
||||
|
||||
async def classify_document(text: str) -> dict:
    """
    Send document text to ai-service for classification.

    Only the first 50,000 characters of ``text`` are included in the prompt.

    Returns the parsed JSON result dict.
    Raises AIServiceError on HTTP errors or unexpected response shapes.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": USER_PROMPT_TEMPLATE.format(text=text[:50_000])},
    ]

    try:
        response = await _client.post(
            f"{settings.AI_SERVICE_URL}/chat",
            json={
                "messages": messages,
                "max_tokens": 2048,
                "temperature": 0,
                "response_format": "json",
            },
        )
    except httpx.RequestError as exc:
        raise AIServiceError(f"Could not reach ai-service: {exc}") from exc

    if response.status_code != 200:
        raise AIServiceError(
            f"ai-service returned {response.status_code}: {response.text[:200]}"
        )

    try:
        content = response.json()["content"]
        result = json.loads(content)
    except (KeyError, TypeError, ValueError) as exc:
        # TypeError: the body is not a JSON object, or "content" is not a string.
        # ValueError: covers json.JSONDecodeError from either decoding step.
        raise AIServiceError(f"Unexpected ai-service response: {exc}") from exc

    # The contract is a dict; reject JSON that decodes to a list/scalar/null.
    if not isinstance(result, dict):
        raise AIServiceError(
            "Unexpected ai-service response: expected a JSON object, "
            f"got {type(result).__name__}"
        )
    return result
|
||||
@@ -1,18 +1,9 @@
|
||||
"""
|
||||
Reads doc_service_config.json from the shared config volume.
|
||||
Caches the result for 30 seconds to avoid hitting the filesystem on every request.
|
||||
Uses asyncio.to_thread so the synchronous file read doesn't block the event loop.
|
||||
30-second TTL cache + env var overrides.
|
||||
|
||||
Env var overrides (take precedence over the JSON config file, never committed):
|
||||
AI_PROVIDER — "lmstudio" | "ollama" | "anthropic"
|
||||
LMSTUDIO_BASE_URL — e.g. http://host.docker.internal:1234/v1
|
||||
LMSTUDIO_API_KEY
|
||||
LMSTUDIO_MODEL
|
||||
OLLAMA_BASE_URL
|
||||
OLLAMA_MODEL
|
||||
OLLAMA_API_KEY
|
||||
ANTHROPIC_API_KEY
|
||||
ANTHROPIC_MODEL
|
||||
Env var overrides (all optional):
|
||||
DOC_MAX_PDF_MB — max upload size in megabytes (e.g. "50")
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
@@ -24,15 +15,6 @@ from pathlib import Path
|
||||
from app.core.config import settings
|
||||
|
||||
# Built-in default configuration values.
_DEFAULT_CONFIG: dict = {
    "ai": {
        # Default: LM Studio running on the host machine at port 1234.
        # Inside Docker, host.docker.internal resolves to the host; for local
        # dev outside Docker use http://localhost:1234/v1 instead.
        "provider": "lmstudio",
        "anthropic": {
            "api_key": "",
            "model": "claude-haiku-4-5-20251001",
        },
        "ollama": {
            "base_url": "http://host.docker.internal:11434/v1",
            "model": "llama3.2",
            "api_key": "ollama",
        },
        "lmstudio": {
            "base_url": "http://host.docker.internal:1234/v1",
            "model": "local-model",
            "api_key": "lm-studio",
        },
    },
    "documents": {
        # 20 MiB upload limit.
        "max_pdf_bytes": 20 * 1024 * 1024,
    },
}
|
||||
|
||||
@@ -52,43 +34,13 @@ def _read_config_sync() -> dict:
|
||||
|
||||
|
||||
def _apply_env_overrides(config: dict) -> dict:
    """
    Return a deep copy of ``config`` with environment variable overrides applied.

    Any override variable that is set to a non-empty value replaces the
    matching entry from the JSON config; the input dict is never mutated.
    This lets the dev .env file pin the AI connection without writing to the
    shared volume (which would affect all users).

    A DOC_MAX_PDF_MB value that is not an integer is ignored.
    """
    merged = deepcopy(config)
    ai_section = merged.setdefault("ai", {})

    provider = os.environ.get("AI_PROVIDER")
    if provider:
        ai_section["provider"] = provider

    # (provider section, ((env var, config key), ...)), applied in this order.
    provider_overrides = (
        ("lmstudio", (
            ("LMSTUDIO_BASE_URL", "base_url"),
            ("LMSTUDIO_API_KEY", "api_key"),
            ("LMSTUDIO_MODEL", "model"),
        )),
        ("ollama", (
            ("OLLAMA_BASE_URL", "base_url"),
            ("OLLAMA_MODEL", "model"),
            ("OLLAMA_API_KEY", "api_key"),
        )),
        ("anthropic", (
            ("ANTHROPIC_API_KEY", "api_key"),
            ("ANTHROPIC_MODEL", "model"),
        )),
    )
    for section_name, env_to_key in provider_overrides:
        section = ai_section.setdefault(section_name, {})
        for env_name, key in env_to_key:
            value = os.environ.get(env_name)
            if value:
                section[key] = value

    documents = merged.setdefault("documents", {})
    max_mb = os.environ.get("DOC_MAX_PDF_MB")
    if max_mb:
        try:
            documents["max_pdf_bytes"] = int(max_mb) * 1024 * 1024
        except ValueError:
            # Not an integer: keep whatever limit the config already carries.
            pass
    return merged
|
||||
|
||||
|
||||
|
||||
-9
@@ -1,5 +1,3 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a financial document analysis assistant. "
|
||||
"Given the text extracted from a PDF document, return ONLY a JSON object "
|
||||
@@ -23,10 +21,3 @@ suggested_categories (array of 2 to 5 short category name strings a user might w
|
||||
|
||||
Document text:
|
||||
{text}"""
|
||||
|
||||
|
||||
class AIProvider(ABC):
    """Abstract interface implemented by each AI backend."""

    @abstractmethod
    async def classify_document(self, text: str) -> dict:
        """Return structured extraction dict from document text."""
        ...
|
||||
@@ -13,8 +13,7 @@ dependencies = [
|
||||
"asyncpg>=0.29",
|
||||
"alembic>=1.13",
|
||||
"pydantic-settings>=2.2",
|
||||
"anthropic>=0.28",
|
||||
"openai>=1.0",
|
||||
"httpx>=0.27",
|
||||
"pdfplumber>=0.11",
|
||||
"aiofiles>=23.0",
|
||||
"python-multipart>=0.0.9",
|
||||
|
||||
@@ -75,12 +75,13 @@ MOCK_AI_RESULT = {
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai():
    """Yield a stub provider whose classify_document resolves to MOCK_AI_RESULT.

    While the fixture is active, ``app.routers.documents.get_provider`` is
    patched to return the stub.
    """
    fake_provider = AsyncMock()
    fake_provider.classify_document = AsyncMock(return_value=MOCK_AI_RESULT)
    provider_patch = patch("app.routers.documents.get_provider", return_value=fake_provider)
    with provider_patch:
        yield fake_provider
|
||||
def mock_ai_service():
    """Replace ai_client.classify_document with a stub returning MOCK_AI_RESULT.

    No request is sent to ai-service while the fixture is active; the stub
    itself is yielded to the test.
    """
    stub = AsyncMock(return_value=MOCK_AI_RESULT)
    with patch("app.services.ai_client.classify_document", new=stub):
        yield stub
|
||||
|
||||
|
||||
# ── HTTP client ────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -189,7 +189,7 @@ async def test_cannot_assign_other_users_category(client, other_client, minimal_
|
||||
|
||||
# ── AI processing integration (with mock AI) ──────────────────────────────────
|
||||
|
||||
async def test_processing_sets_extracted_data(client, invoice_pdf, mock_ai):
|
||||
async def test_processing_sets_extracted_data(client, invoice_pdf, mock_ai_service):
|
||||
"""Upload + wait for background processing; verify extracted_data is populated."""
|
||||
r = await client.post("/documents/upload", files=_pdf_upload("invoice.pdf", invoice_pdf))
|
||||
assert r.status_code == 202
|
||||
@@ -217,9 +217,35 @@ async def test_processing_sets_extracted_data(client, invoice_pdf, mock_ai):
|
||||
assert len(extracted["suggested_categories"]) > 0
|
||||
|
||||
|
||||
# ── Graceful degradation when ai-service is unavailable ──────────────────────
|
||||
|
||||
async def test_processing_fails_gracefully_when_ai_service_502(client, invoice_pdf):
    """An AIServiceError during classification must mark the document 'failed'."""
    import asyncio

    from app.services.ai_client import AIServiceError

    failing_classify = patch(
        "app.services.ai_client.classify_document",
        side_effect=AIServiceError("ai-service returned 502"),
    )
    with failing_classify:
        upload = await client.post("/documents/upload", files=_pdf_upload("fail.pdf", invoice_pdf))
        assert upload.status_code == 202
        doc_id = upload.json()["id"]

        # Poll up to 20 times, 0.1 s apart, until processing reaches a final state.
        for _attempt in range(20):
            status_response = await client.get(f"/documents/{doc_id}/status")
            if status_response.json()["status"] in ("done", "failed"):
                break
            await asyncio.sleep(0.1)

    document = (await client.get(f"/documents/{doc_id}")).json()
    assert document["status"] == "failed"
    assert "ai-service" in (document.get("error_message") or "").lower()
|
||||
|
||||
|
||||
# ── Live tests (require real PDFs in tests/pdfs/) ─────────────────────────────
|
||||
|
||||
async def test_live_upload_real_pdf(client, real_pdfs, mock_ai):
|
||||
async def test_live_upload_real_pdf(client, real_pdfs, mock_ai_service):
|
||||
"""Upload each real PDF from tests/pdfs/ and verify it reaches 'done'."""
|
||||
import asyncio
|
||||
for pdf_path in real_pdfs:
|
||||
|
||||
Reference in New Issue
Block a user