Add shared ai-service container as AI provider intermediary

All feature containers now POST messages to ai-service (port 8010) instead
of calling AI providers directly. ai-service routes to LM Studio, Ollama,
or Anthropic based on /config/ai_service_config.json. doc-service AI
providers removed; replaced by httpx ai_client.py. Backend settings
restructured to /api/settings/ai. Frontend gets dedicated AIAdminSettingsPage
and AI Service card in AppsPage.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
curo1305
2026-04-14 12:30:45 +02:00
parent 52a2967f61
commit 88c1ea297e
47 changed files with 1354 additions and 497 deletions
+12
View File
@@ -0,0 +1,12 @@
AI_PROVIDER=lmstudio
LMSTUDIO_BASE_URL=http://host.docker.internal:1234/v1
LMSTUDIO_API_KEY=your-lmstudio-api-key
LMSTUDIO_MODEL=local-model
OLLAMA_BASE_URL=http://host.docker.internal:11434/v1
OLLAMA_MODEL=llama3.2
OLLAMA_API_KEY=ollama
ANTHROPIC_API_KEY=sk-ant-your-key-here
ANTHROPIC_MODEL=claude-haiku-4-5-20251001
+33
View File
@@ -0,0 +1,33 @@
# ── Stage 1: dependency installation ─────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /app
RUN pip install --upgrade pip
COPY pyproject.toml .
RUN pip install --prefix=/install .
# ── Stage 2: runtime ──────────────────────────────────────────────────────────
FROM python:3.12-slim
# Create non-root user (UID/GID 1001)
RUN groupadd --gid 1001 appuser && \
useradd --uid 1001 --gid 1001 --no-create-home --shell /bin/sh appuser
# Pre-create the config directory with correct ownership
RUN mkdir -p /config && chown -R appuser:appuser /config
WORKDIR /app
COPY --from=builder /install /usr/local
COPY --chown=appuser:appuser app ./app
COPY --chown=appuser:appuser scripts ./scripts
RUN chmod +x scripts/start.sh scripts/start_dev.sh
USER appuser
EXPOSE 8010
CMD ["sh", "scripts/start.sh"]
View File
+12
View File
@@ -0,0 +1,12 @@
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
PROJECT_NAME: str = "ai-service"
CONFIG_PATH: str = "/config/ai_service_config.json"
class Config:
env_file = ".env"
settings = Settings()
+25
View File
@@ -0,0 +1,25 @@
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.config import settings
from app.routers import chat, health
from app.services.config_reader import load_ai_config
logger = logging.getLogger("ai-service")
@asynccontextmanager
async def lifespan(app: FastAPI):
config = await load_ai_config()
provider = config.get("provider", "lmstudio")
model = config.get(provider, {}).get("model", "unknown")
logger.info("[ai-service] active provider: %s model: %s", provider, model)
yield
app = FastAPI(title=settings.PROJECT_NAME, lifespan=lifespan)
app.include_router(chat.router, tags=["chat"])
app.include_router(health.router, tags=["health"])
@@ -0,0 +1,20 @@
from app.providers.base import AIProvider
def get_provider(ai_config: dict) -> AIProvider:
"""Return an AIProvider instance for the active provider in the config."""
provider_name = ai_config.get("provider", "lmstudio")
provider_cfg = ai_config.get(provider_name, {})
match provider_name:
case "anthropic":
from app.providers.anthropic_provider import AnthropicProvider
return AnthropicProvider(provider_cfg)
case "ollama" | "lmstudio":
from app.providers.openai_compat import OpenAICompatProvider
return OpenAICompatProvider(provider_cfg, provider_name=provider_name)
case _:
raise ValueError(f"Unknown AI provider: {provider_name!r}")
__all__ = ["AIProvider", "get_provider"]
@@ -0,0 +1,54 @@
import asyncio
import anthropic
from app.providers.base import AIProvider
from app.schemas.chat import ChatMessage
class AnthropicProvider(AIProvider):
def __init__(self, config: dict) -> None:
self._client = anthropic.AsyncAnthropic(api_key=config.get("api_key", ""))
self.model_name = config.get("model", "claude-haiku-4-5-20251001")
self.provider_name = "anthropic"
async def chat(
self,
messages: list[ChatMessage],
max_tokens: int,
temperature: float,
) -> tuple[str, int | None, int | None]:
# Anthropic uses a top-level `system=` param, not a role in the messages array
system_content = ""
user_messages = []
for msg in messages:
if msg.role == "system":
system_content += msg.content + "\n"
else:
user_messages.append({"role": msg.role, "content": msg.content})
try:
response = await self._client.messages.create(
model=self.model_name,
max_tokens=max_tokens,
temperature=temperature,
system=system_content.strip() or anthropic.NOT_GIVEN,
messages=user_messages,
)
except anthropic.APIConnectionError as exc:
raise ProviderConnectionError(str(exc)) from exc
except anthropic.APITimeoutError as exc:
raise ProviderTimeoutError(str(exc)) from exc
except anthropic.APIStatusError as exc:
raise ProviderConnectionError(f"Anthropic API error {exc.status_code}: {exc.message}") from exc
content = response.content[0].text
return content, response.usage.input_tokens, response.usage.output_tokens
class ProviderConnectionError(Exception):
pass
class ProviderTimeoutError(Exception):
pass
+23
View File
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from app.schemas.chat import ChatMessage
class AIProvider(ABC):
provider_name: str = "unknown"
model_name: str = "unknown"
@abstractmethod
async def chat(
self,
messages: list[ChatMessage],
max_tokens: int,
temperature: float,
) -> tuple[str, int | None, int | None]:
"""
Send messages to the provider and return (content, input_tokens, output_tokens).
Raises:
ProviderConnectionError: on network / auth failure
ProviderTimeoutError: on request timeout
"""
...
@@ -0,0 +1,52 @@
"""OpenAI-compatible provider — handles both Ollama and LM Studio."""
import asyncio
import openai
from app.providers.base import AIProvider
from app.schemas.chat import ChatMessage
class OpenAICompatProvider(AIProvider):
def __init__(self, config: dict, provider_name: str = "lmstudio") -> None:
self._client = openai.AsyncOpenAI(
base_url=config.get("base_url", "http://localhost:1234/v1"),
api_key=config.get("api_key") or "not-required",
)
self.model_name = config.get("model", "local-model")
self.provider_name = provider_name
async def chat(
self,
messages: list[ChatMessage],
max_tokens: int,
temperature: float,
) -> tuple[str, int | None, int | None]:
raw_messages = [{"role": m.role, "content": m.content} for m in messages]
try:
response = await self._client.chat.completions.create(
model=self.model_name,
messages=raw_messages,
max_tokens=max_tokens,
temperature=temperature,
)
except openai.APIConnectionError as exc:
raise ProviderConnectionError(str(exc)) from exc
except openai.APITimeoutError as exc:
raise ProviderTimeoutError(str(exc)) from exc
except openai.APIStatusError as exc:
raise ProviderConnectionError(f"API error {exc.status_code}: {exc.message}") from exc
content = response.choices[0].message.content or ""
usage = response.usage
input_tokens = usage.prompt_tokens if usage else None
output_tokens = usage.completion_tokens if usage else None
return content, input_tokens, output_tokens
class ProviderConnectionError(Exception):
pass
class ProviderTimeoutError(Exception):
pass
+70
View File
@@ -0,0 +1,70 @@
import asyncio
import re
from fastapi import APIRouter, HTTPException
from app.providers import get_provider
from app.providers.anthropic_provider import ProviderConnectionError as AnthropicConnError
from app.providers.anthropic_provider import ProviderTimeoutError as AnthropicTimeoutError
from app.providers.openai_compat import ProviderConnectionError as OpenAIConnError
from app.providers.openai_compat import ProviderTimeoutError as OpenAITimeoutError
from app.schemas.chat import ChatRequest, ChatResponse
from app.services.config_reader import load_ai_config
router = APIRouter()
_FENCE_RE = re.compile(r"^```[a-z]*\n?(.*?)\n?```$", re.DOTALL)
def _strip_fences(text: str) -> str:
m = _FENCE_RE.match(text.strip())
return m.group(1).strip() if m else text.strip()
@router.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
config = await load_ai_config()
provider_name = config.get("provider", "lmstudio")
if provider_name not in ("anthropic", "ollama", "lmstudio"):
raise HTTPException(status_code=503, detail=f"Unknown provider configured: {provider_name!r}")
try:
provider = get_provider(config)
except ValueError as exc:
raise HTTPException(status_code=503, detail=str(exc))
timeout = config.get("timeout_seconds", 60)
max_retries = config.get("max_retries", 2)
last_exc: Exception | None = None
for attempt in range(max_retries + 1):
try:
content, input_tokens, output_tokens = await asyncio.wait_for(
provider.chat(request.messages, request.max_tokens, request.temperature),
timeout=float(timeout),
)
break
except asyncio.TimeoutError as exc:
last_exc = exc
# Don't retry on timeout — the model is busy; fail fast
raise HTTPException(status_code=504, detail="AI provider timed out") from exc
except (AnthropicConnError, OpenAIConnError) as exc:
last_exc = exc
if attempt < max_retries:
await asyncio.sleep(0.5 * (attempt + 1))
continue
raise HTTPException(status_code=502, detail=f"AI provider error: {exc}") from exc
except (AnthropicTimeoutError, OpenAITimeoutError) as exc:
raise HTTPException(status_code=504, detail="AI provider timed out") from exc
if request.response_format == "json":
content = _strip_fences(content)
return ChatResponse(
content=content,
provider=provider.provider_name,
model=provider.model_name,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
+30
View File
@@ -0,0 +1,30 @@
from fastapi import APIRouter
from app.services.config_reader import load_ai_config
router = APIRouter()
@router.get("/health")
async def health() -> dict:
return {"status": "ok"}
@router.get("/health/provider")
async def provider_status() -> dict:
config = await load_ai_config()
provider = config.get("provider", "lmstudio")
pcfg = config.get(provider, {})
model = pcfg.get("model", "")
# "configured" means we have the minimum required fields for the provider
if provider == "anthropic":
configured = bool(pcfg.get("api_key"))
else:
configured = bool(pcfg.get("base_url") and pcfg.get("model"))
return {
"provider": provider,
"model": model,
"configured": configured,
}
+30
View File
@@ -0,0 +1,30 @@
from typing import Literal
from pydantic import BaseModel, field_validator
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
content: str
class ChatRequest(BaseModel):
messages: list[ChatMessage]
max_tokens: int = 2048
temperature: float = 0.0
response_format: Literal["json", "text"] = "text"
@field_validator("messages")
@classmethod
def messages_not_empty(cls, v: list) -> list:
if not v:
raise ValueError("messages must not be empty")
return v
class ChatResponse(BaseModel):
content: str
provider: str
model: str
input_tokens: int | None = None
output_tokens: int | None = None
@@ -0,0 +1,81 @@
"""
Reads ai_service_config.json from the shared config volume.
30-second TTL cache + env var overrides (dev credentials stay out of git).
Env var overrides (all optional):
AI_PROVIDER — "lmstudio" | "ollama" | "anthropic"
LMSTUDIO_BASE_URL, LMSTUDIO_API_KEY, LMSTUDIO_MODEL
OLLAMA_BASE_URL, OLLAMA_MODEL, OLLAMA_API_KEY
ANTHROPIC_API_KEY, ANTHROPIC_MODEL
"""
import asyncio
import json
import os
import time
from copy import deepcopy
from pathlib import Path
from app.core.config import settings
_DEFAULT_CONFIG: dict = {
"provider": "lmstudio",
"timeout_seconds": 60,
"max_retries": 2,
"anthropic": {"api_key": "", "model": "claude-haiku-4-5-20251001"},
"ollama": {"base_url": "http://host.docker.internal:11434/v1", "model": "llama3.2", "api_key": "ollama"},
"lmstudio": {"base_url": "http://host.docker.internal:1234/v1", "model": "local-model", "api_key": "lm-studio"},
}
_cache: dict | None = None
_cache_at: float = 0.0
_CACHE_TTL = 30.0
def _read_config_sync() -> dict:
path = Path(settings.CONFIG_PATH)
if not path.exists():
return _apply_env_overrides(deepcopy(_DEFAULT_CONFIG))
with open(path) as f:
return _apply_env_overrides(json.load(f))
def _apply_env_overrides(config: dict) -> dict:
cfg = deepcopy(config)
if v := os.environ.get("AI_PROVIDER"):
cfg["provider"] = v
lms = cfg.setdefault("lmstudio", {})
if v := os.environ.get("LMSTUDIO_BASE_URL"):
lms["base_url"] = v
if v := os.environ.get("LMSTUDIO_API_KEY"):
lms["api_key"] = v
if v := os.environ.get("LMSTUDIO_MODEL"):
lms["model"] = v
oll = cfg.setdefault("ollama", {})
if v := os.environ.get("OLLAMA_BASE_URL"):
oll["base_url"] = v
if v := os.environ.get("OLLAMA_MODEL"):
oll["model"] = v
if v := os.environ.get("OLLAMA_API_KEY"):
oll["api_key"] = v
ant = cfg.setdefault("anthropic", {})
if v := os.environ.get("ANTHROPIC_API_KEY"):
ant["api_key"] = v
if v := os.environ.get("ANTHROPIC_MODEL"):
ant["model"] = v
return cfg
async def load_ai_config() -> dict:
global _cache, _cache_at
now = time.monotonic()
if _cache is not None and (now - _cache_at) < _CACHE_TTL:
return _cache
data = await asyncio.to_thread(_read_config_sync)
_cache = data
_cache_at = now
return data
+29
View File
@@ -0,0 +1,29 @@
[build-system]
requires = ["setuptools>=45"]
build-backend = "setuptools.build_meta"
[project]
name = "ai-service"
version = "0.1.0"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.111",
"uvicorn[standard]>=0.29",
"pydantic-settings>=2.2",
"anthropic>=0.28",
"openai>=1.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8",
"pytest-asyncio>=0.23",
"httpx>=0.27",
"ruff>=0.4",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
[tool.ruff]
line-length = 100
+4
View File
@@ -0,0 +1,4 @@
#!/bin/sh
set -e
echo "[ai-service] starting uvicorn..."
exec uvicorn app.main:app --host 0.0.0.0 --port 8010
+4
View File
@@ -0,0 +1,4 @@
#!/bin/sh
set -e
echo "[ai-service] starting uvicorn (dev)..."
exec uvicorn app.main:app --host 0.0.0.0 --port 8010 --reload
+57
View File
@@ -0,0 +1,57 @@
import pytest
from httpx import ASGITransport, AsyncClient
from app.main import app
@pytest.fixture
async def ai_client():
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
yield client
# ---------------------------------------------------------------------------
# Config fixtures
# ---------------------------------------------------------------------------
LMSTUDIO_CONFIG = {
"provider": "lmstudio",
"timeout_seconds": 10,
"max_retries": 0,
"lmstudio": {
"base_url": "http://fake-lmstudio/v1",
"model": "test-model",
"api_key": "test-key",
},
}
OLLAMA_CONFIG = {
"provider": "ollama",
"timeout_seconds": 10,
"max_retries": 0,
"ollama": {
"base_url": "http://fake-ollama/v1",
"model": "llama3.2",
"api_key": "ollama",
},
}
ANTHROPIC_CONFIG = {
"provider": "anthropic",
"timeout_seconds": 10,
"max_retries": 0,
"anthropic": {
"api_key": "sk-ant-test",
"model": "claude-haiku-4-5-20251001",
},
}
MISSING_KEY_ANTHROPIC_CONFIG = {
"provider": "anthropic",
"timeout_seconds": 10,
"max_retries": 0,
"anthropic": {
"api_key": "",
"model": "claude-haiku-4-5-20251001",
},
}
+221
View File
@@ -0,0 +1,221 @@
"""Tests for POST /chat."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from tests.conftest import ANTHROPIC_CONFIG, LMSTUDIO_CONFIG, OLLAMA_CONFIG
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_LOAD_CONFIG = "app.routers.chat.load_ai_config"
_PROVIDER_CHAT = "app.providers.openai_compat.OpenAICompatProvider.chat"
_ANTHROPIC_CHAT = "app.providers.anthropic_provider.AnthropicProvider.chat"
MESSAGES = [{"role": "user", "content": "Hello"}]
SYSTEM_MESSAGES = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"},
]
def _mock_chat_response(content="ok", input_tokens=10, output_tokens=5):
return AsyncMock(return_value=(content, input_tokens, output_tokens))
# ---------------------------------------------------------------------------
# Success: each provider
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_lmstudio_success(ai_client):
with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(
_PROVIDER_CHAT, new=_mock_chat_response("lmstudio reply")
):
resp = await ai_client.post("/chat", json={"messages": MESSAGES})
assert resp.status_code == 200
data = resp.json()
assert data["content"] == "lmstudio reply"
assert data["provider"] == "lmstudio"
assert data["model"] == "test-model"
assert data["input_tokens"] == 10
assert data["output_tokens"] == 5
@pytest.mark.asyncio
async def test_chat_ollama_success(ai_client):
with patch(_LOAD_CONFIG, return_value=OLLAMA_CONFIG), patch(
_PROVIDER_CHAT, new=_mock_chat_response("ollama reply")
):
resp = await ai_client.post("/chat", json={"messages": MESSAGES})
assert resp.status_code == 200
data = resp.json()
assert data["content"] == "ollama reply"
assert data["provider"] == "ollama"
@pytest.mark.asyncio
async def test_chat_anthropic_success(ai_client):
with patch(_LOAD_CONFIG, return_value=ANTHROPIC_CONFIG), patch(
_ANTHROPIC_CHAT, new=_mock_chat_response("anthropic reply")
):
resp = await ai_client.post("/chat", json={"messages": MESSAGES})
assert resp.status_code == 200
data = resp.json()
assert data["content"] == "anthropic reply"
assert data["provider"] == "anthropic"
# ---------------------------------------------------------------------------
# response_format
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_response_format_json_strips_fences(ai_client):
fenced = "```json\n{\"key\": \"value\"}\n```"
with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(
_PROVIDER_CHAT, new=_mock_chat_response(fenced)
):
resp = await ai_client.post(
"/chat",
json={"messages": MESSAGES, "response_format": "json"},
)
assert resp.status_code == 200
assert resp.json()["content"] == '{"key": "value"}'
@pytest.mark.asyncio
async def test_response_format_text_preserves_fences(ai_client):
fenced = "```python\nprint('hi')\n```"
with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(
_PROVIDER_CHAT, new=_mock_chat_response(fenced)
):
resp = await ai_client.post(
"/chat",
json={"messages": MESSAGES, "response_format": "text"},
)
assert resp.status_code == 200
assert "```" in resp.json()["content"]
# ---------------------------------------------------------------------------
# Validation errors
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_missing_messages_returns_422(ai_client):
resp = await ai_client.post("/chat", json={})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_empty_messages_returns_422(ai_client):
resp = await ai_client.post("/chat", json={"messages": []})
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# Provider errors
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_connection_error_returns_502(ai_client):
from app.providers.openai_compat import ProviderConnectionError
with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(
_PROVIDER_CHAT, side_effect=ProviderConnectionError("refused")
):
resp = await ai_client.post("/chat", json={"messages": MESSAGES})
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_chat_timeout_returns_504(ai_client):
async def _slow(*_args, **_kwargs):
await asyncio.sleep(100)
with patch(_LOAD_CONFIG, return_value={**LMSTUDIO_CONFIG, "timeout_seconds": 0.01}), patch(
_PROVIDER_CHAT, new=_slow
):
resp = await ai_client.post("/chat", json={"messages": MESSAGES})
assert resp.status_code == 504
@pytest.mark.asyncio
async def test_chat_unknown_provider_returns_503(ai_client):
bad_config = {**LMSTUDIO_CONFIG, "provider": "unknown-llm"}
with patch(_LOAD_CONFIG, return_value=bad_config):
resp = await ai_client.post("/chat", json={"messages": MESSAGES})
assert resp.status_code == 503
# ---------------------------------------------------------------------------
# Anthropic system message extraction
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_anthropic_system_message_extracted(ai_client):
"""System-role messages must not appear in the user_messages list sent to Anthropic."""
captured_kwargs: dict = {}
async def _fake_create(**kwargs):
captured_kwargs.update(kwargs)
mock_resp = MagicMock()
mock_resp.content = [MagicMock(text="ok")]
mock_resp.usage = MagicMock(input_tokens=5, output_tokens=2)
return mock_resp
with patch(_LOAD_CONFIG, return_value=ANTHROPIC_CONFIG), patch(
"anthropic.AsyncAnthropic.messages",
new_callable=lambda: type(
"Messages",
(),
{"create": staticmethod(AsyncMock(side_effect=_fake_create))},
),
):
resp = await ai_client.post("/chat", json={"messages": SYSTEM_MESSAGES})
# Whether the call succeeded or not, no system role should reach the messages list
if "messages" in captured_kwargs:
roles = [m["role"] for m in captured_kwargs["messages"]]
assert "system" not in roles
# ---------------------------------------------------------------------------
# Parameter forwarding
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_max_tokens_and_temperature_forwarded(ai_client):
captured: dict = {}
async def _capture(messages, max_tokens, temperature):
captured["max_tokens"] = max_tokens
captured["temperature"] = temperature
return ("ok", 1, 1)
with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(_PROVIDER_CHAT, new=_capture):
resp = await ai_client.post(
"/chat",
json={"messages": MESSAGES, "max_tokens": 512, "temperature": 0.7},
)
assert resp.status_code == 200
assert captured["max_tokens"] == 512
assert captured["temperature"] == pytest.approx(0.7)
+38
View File
@@ -0,0 +1,38 @@
"""Tests for GET /health and GET /health/provider."""
from unittest.mock import patch
import pytest
from tests.conftest import ANTHROPIC_CONFIG, LMSTUDIO_CONFIG, MISSING_KEY_ANTHROPIC_CONFIG
_LOAD_CONFIG = "app.routers.health.load_ai_config"
@pytest.mark.asyncio
async def test_health_returns_ok(ai_client):
resp = await ai_client.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_provider_status_configured(ai_client):
with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG):
resp = await ai_client.get("/health/provider")
assert resp.status_code == 200
data = resp.json()
assert data["provider"] == "lmstudio"
assert data["model"] == "test-model"
assert data["configured"] is True
@pytest.mark.asyncio
async def test_provider_status_not_configured_when_api_key_missing(ai_client):
with patch(_LOAD_CONFIG, return_value=MISSING_KEY_ANTHROPIC_CONFIG):
resp = await ai_client.get("/health/provider")
assert resp.status_code == 200
data = resp.json()
assert data["provider"] == "anthropic"
assert data["configured"] is False
+1
View File
@@ -6,6 +6,7 @@ class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:password@db:5432/destroying_sap"
DATA_DIR: str = "/data/documents"
CONFIG_PATH: str = "/config/doc_service_config.json"
AI_SERVICE_URL: str = "http://ai-service:8010"
class Config:
env_file = ".env"
@@ -17,7 +17,7 @@ from app.models.category import DocumentCategory
from app.models.category_assignment import CategoryAssignment
from app.models.document import Document
from app.schemas.document import DocumentOut, DocumentStatusOut, DocumentTypeUpdate
from app.services.ai import get_provider
from app.services.ai_client import AIServiceError, classify_document
from app.services.config_reader import load_doc_config
from app.services.storage import delete_file, get_upload_path, save_upload
@@ -91,9 +91,7 @@ async def process_document(doc_id: str) -> None:
try:
text = await asyncio.to_thread(_extract_pdf_text, doc.file_path)
config = await load_doc_config()
provider = get_provider(config["ai"])
result = await provider.classify_document(text)
result = await classify_document(text)
doc.raw_text = text[:500_000] # cap stored text at 500k chars
doc.extracted_data = json.dumps(result)
@@ -1,23 +0,0 @@
from app.services.ai.base import AIProvider
def get_provider(ai_config: dict) -> AIProvider:
"""
Factory: return an AIProvider instance based on the 'provider' key in the AI config section.
ai_config is the 'ai' section of doc_service_config.json, loaded fresh per processing job.
"""
provider_name = ai_config.get("provider", "anthropic")
provider_cfg = ai_config.get(provider_name, {})
match provider_name:
case "anthropic":
from app.services.ai.anthropic_provider import AnthropicProvider
return AnthropicProvider(provider_cfg)
case "ollama" | "lmstudio":
from app.services.ai.openai_compat import OpenAICompatProvider
return OpenAICompatProvider(provider_cfg)
case _:
raise ValueError(f"Unknown AI provider: {provider_name!r}")
__all__ = ["AIProvider", "get_provider"]
@@ -1,31 +0,0 @@
import json
from anthropic import AsyncAnthropic
from app.services.ai.base import AIProvider, SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
class AnthropicProvider(AIProvider):
def __init__(self, config: dict) -> None:
self._client = AsyncAnthropic(api_key=config["api_key"])
self._model = config.get("model", "claude-haiku-4-5-20251001")
async def classify_document(self, text: str) -> dict:
message = await self._client.messages.create(
model=self._model,
max_tokens=2048,
system=SYSTEM_PROMPT,
messages=[{
"role": "user",
"content": USER_PROMPT_TEMPLATE.format(text=text[:100_000]),
}],
)
raw = message.content[0].text.strip()
return _parse_json(raw)
def _parse_json(raw: str) -> dict:
# Strip accidental markdown fences despite explicit instruction not to include them
if raw.startswith("```"):
raw = raw.split("\n", 1)[1].rsplit("```", 1)[0]
return json.loads(raw)
@@ -1,36 +0,0 @@
"""
OpenAI-compatible provider for Ollama and LM Studio.
Both expose an OpenAI-compatible /v1/chat/completions endpoint.
"""
import json
from openai import AsyncOpenAI
from app.services.ai.base import AIProvider, SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
class OpenAICompatProvider(AIProvider):
def __init__(self, config: dict) -> None:
self._client = AsyncOpenAI(
base_url=config["base_url"],
api_key=config.get("api_key", "not-required"),
)
self._model = config["model"]
async def classify_document(self, text: str) -> dict:
response = await self._client.chat.completions.create(
model=self._model,
temperature=0,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT_TEMPLATE.format(text=text[:100_000])},
],
)
raw = response.choices[0].message.content.strip()
return _parse_json(raw)
def _parse_json(raw: str) -> dict:
if raw.startswith("```"):
raw = raw.split("\n", 1)[1].rsplit("```", 1)[0]
return json.loads(raw)
@@ -0,0 +1,49 @@
"""HTTP client for the shared ai-service container."""
import json
import httpx
from app.core.config import settings
from app.services.prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE
_client = httpx.AsyncClient(timeout=120.0)
class AIServiceError(Exception):
pass
async def classify_document(text: str) -> dict:
"""
Send document text to ai-service for classification.
Returns the parsed JSON result dict.
Raises AIServiceError on HTTP errors or unexpected response shapes.
"""
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT_TEMPLATE.format(text=text[:50_000])},
]
try:
response = await _client.post(
f"{settings.AI_SERVICE_URL}/chat",
json={
"messages": messages,
"max_tokens": 2048,
"temperature": 0,
"response_format": "json",
},
)
except httpx.RequestError as exc:
raise AIServiceError(f"Could not reach ai-service: {exc}") from exc
if response.status_code != 200:
raise AIServiceError(
f"ai-service returned {response.status_code}: {response.text[:200]}"
)
try:
content = response.json()["content"]
return json.loads(content)
except (KeyError, json.JSONDecodeError) as exc:
raise AIServiceError(f"Unexpected ai-service response: {exc}") from exc
@@ -1,18 +1,9 @@
"""
Reads doc_service_config.json from the shared config volume.
Caches the result for 30 seconds to avoid hitting the filesystem on every request.
Uses asyncio.to_thread so the synchronous file read doesn't block the event loop.
30-second TTL cache + env var overrides.
Env var overrides (take precedence over the JSON config file, never committed):
AI_PROVIDER — "lmstudio" | "ollama" | "anthropic"
LMSTUDIO_BASE_URL — e.g. http://host.docker.internal:1234/v1
LMSTUDIO_API_KEY
LMSTUDIO_MODEL
OLLAMA_BASE_URL
OLLAMA_MODEL
OLLAMA_API_KEY
ANTHROPIC_API_KEY
ANTHROPIC_MODEL
Env var overrides (all optional):
DOC_MAX_PDF_MB — max upload size in megabytes (e.g. "50")
"""
import asyncio
import json
@@ -24,15 +15,6 @@ from pathlib import Path
from app.core.config import settings
_DEFAULT_CONFIG: dict = {
"ai": {
# Default: LM Studio running on the host machine at port 1234.
# Inside Docker, host.docker.internal resolves to the host; for local
# dev outside Docker use http://localhost:1234/v1 instead.
"provider": "lmstudio",
"anthropic": {"api_key": "", "model": "claude-haiku-4-5-20251001"},
"ollama": {"base_url": "http://host.docker.internal:11434/v1", "model": "llama3.2", "api_key": "ollama"},
"lmstudio": {"base_url": "http://host.docker.internal:1234/v1", "model": "local-model", "api_key": "lm-studio"},
},
"documents": {"max_pdf_bytes": 20 * 1024 * 1024},
}
@@ -52,43 +34,13 @@ def _read_config_sync() -> dict:
def _apply_env_overrides(config: dict) -> dict:
"""
Merge environment variable overrides into the config dict.
Env vars win over whatever is stored in the JSON file.
This lets the dev .env file pin the AI connection without writing to the
shared volume (which would affect all users).
"""
cfg = deepcopy(config)
ai = cfg.setdefault("ai", {})
if provider := os.environ.get("AI_PROVIDER"):
ai["provider"] = provider
# LM Studio
lms = ai.setdefault("lmstudio", {})
if v := os.environ.get("LMSTUDIO_BASE_URL"):
lms["base_url"] = v
if v := os.environ.get("LMSTUDIO_API_KEY"):
lms["api_key"] = v
if v := os.environ.get("LMSTUDIO_MODEL"):
lms["model"] = v
# Ollama
oll = ai.setdefault("ollama", {})
if v := os.environ.get("OLLAMA_BASE_URL"):
oll["base_url"] = v
if v := os.environ.get("OLLAMA_MODEL"):
oll["model"] = v
if v := os.environ.get("OLLAMA_API_KEY"):
oll["api_key"] = v
# Anthropic
ant = ai.setdefault("anthropic", {})
if v := os.environ.get("ANTHROPIC_API_KEY"):
ant["api_key"] = v
if v := os.environ.get("ANTHROPIC_MODEL"):
ant["model"] = v
docs = cfg.setdefault("documents", {})
if v := os.environ.get("DOC_MAX_PDF_MB"):
try:
docs["max_pdf_bytes"] = int(v) * 1024 * 1024
except ValueError:
pass
return cfg
@@ -1,5 +1,3 @@
from abc import ABC, abstractmethod
SYSTEM_PROMPT = (
"You are a financial document analysis assistant. "
"Given the text extracted from a PDF document, return ONLY a JSON object "
@@ -23,10 +21,3 @@ suggested_categories (array of 2 to 5 short category name strings a user might w
Document text:
{text}"""
class AIProvider(ABC):
@abstractmethod
async def classify_document(self, text: str) -> dict:
"""Return structured extraction dict from document text."""
...
+1 -2
View File
@@ -13,8 +13,7 @@ dependencies = [
"asyncpg>=0.29",
"alembic>=1.13",
"pydantic-settings>=2.2",
"anthropic>=0.28",
"openai>=1.0",
"httpx>=0.27",
"pdfplumber>=0.11",
"aiofiles>=23.0",
"python-multipart>=0.0.9",
+7 -6
View File
@@ -75,12 +75,13 @@ MOCK_AI_RESULT = {
@pytest.fixture
def mock_ai():
"""Patch the AI classify_document call to return MOCK_AI_RESULT."""
provider_mock = AsyncMock()
provider_mock.classify_document = AsyncMock(return_value=MOCK_AI_RESULT)
with patch("app.routers.documents.get_provider", return_value=provider_mock):
yield provider_mock
def mock_ai_service():
"""Patch classify_document to return MOCK_AI_RESULT without hitting ai-service."""
with patch(
"app.services.ai_client.classify_document",
new=AsyncMock(return_value=MOCK_AI_RESULT),
) as mock:
yield mock
# ── HTTP client ────────────────────────────────────────────────────────────────
+28 -2
View File
@@ -189,7 +189,7 @@ async def test_cannot_assign_other_users_category(client, other_client, minimal_
# ── AI processing integration (with mock AI) ──────────────────────────────────
async def test_processing_sets_extracted_data(client, invoice_pdf, mock_ai):
async def test_processing_sets_extracted_data(client, invoice_pdf, mock_ai_service):
"""Upload + wait for background processing; verify extracted_data is populated."""
r = await client.post("/documents/upload", files=_pdf_upload("invoice.pdf", invoice_pdf))
assert r.status_code == 202
@@ -217,9 +217,35 @@ async def test_processing_sets_extracted_data(client, invoice_pdf, mock_ai):
assert len(extracted["suggested_categories"]) > 0
# ── Graceful degradation when ai-service is unavailable ──────────────────────
async def test_processing_fails_gracefully_when_ai_service_502(client, invoice_pdf):
"""When ai-service returns an error, document status should be 'failed', not crash."""
from app.services.ai_client import AIServiceError
with patch(
"app.services.ai_client.classify_document",
side_effect=AIServiceError("ai-service returned 502"),
):
r = await client.post("/documents/upload", files=_pdf_upload("fail.pdf", invoice_pdf))
assert r.status_code == 202
doc_id = r.json()["id"]
import asyncio
for _ in range(20):
status_r = await client.get(f"/documents/{doc_id}/status")
if status_r.json()["status"] in ("done", "failed"):
break
await asyncio.sleep(0.1)
doc = (await client.get(f"/documents/{doc_id}")).json()
assert doc["status"] == "failed"
assert "ai-service" in (doc.get("error_message") or "").lower()
# ── Live tests (require real PDFs in tests/pdfs/) ─────────────────────────────
async def test_live_upload_real_pdf(client, real_pdfs, mock_ai):
async def test_live_upload_real_pdf(client, real_pdfs, mock_ai_service):
"""Upload each real PDF from tests/pdfs/ and verify it reaches 'done'."""
import asyncio
for pdf_path in real_pdfs: