Add shared ai-service container as AI provider intermediary
All feature containers now POST messages to ai-service (port 8010) instead of calling AI providers directly. ai-service routes to LM Studio, Ollama, or Anthropic based on /config/ai_service_config.json. doc-service AI providers removed; replaced by httpx ai_client.py. Backend settings restructured to /api/settings/ai. Frontend gets dedicated AIAdminSettingsPage and AI Service card in AppsPage. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Process-level settings for ai-service, read from the environment and ``.env``."""

    # pydantic-settings v2 configuration. This replaces the inner ``class Config``,
    # which is the deprecated pydantic v1 spelling of the same option.
    model_config = {"env_file": ".env"}

    # Service name; used as the FastAPI application title.
    PROJECT_NAME: str = "ai-service"
    # Path of the JSON file that selects and configures the AI provider.
    CONFIG_PATH: str = "/config/ai_service_config.json"


# Shared settings instance imported by the other modules of the service.
settings = Settings()
|
||||
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.core.config import settings
|
||||
from app.routers import chat, health
|
||||
from app.services.config_reader import load_ai_config
|
||||
|
||||
logger = logging.getLogger("ai-service")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log the active AI provider and model once, before the app starts serving.

    Nothing is done after the ``yield`` (no shutdown work).
    """
    ai_config = await load_ai_config()
    active = ai_config.get("provider", "lmstudio")
    # The model lives in the config section named after the active provider.
    provider_section = ai_config.get(active, {})
    logger.info(
        "[ai-service] active provider: %s model: %s",
        active,
        provider_section.get("model", "unknown"),
    )
    yield
|
||||
|
||||
|
||||
# Application instance; `lifespan` runs its startup logging before requests are served.
app = FastAPI(
    title=settings.PROJECT_NAME,
    lifespan=lifespan,
)

# Mount the chat and health routers, each tagged for the generated API docs.
app.include_router(router=chat.router, tags=["chat"])
app.include_router(router=health.router, tags=["health"])
|
||||
@@ -0,0 +1,20 @@
|
||||
from app.providers.base import AIProvider
|
||||
|
||||
|
||||
def get_provider(ai_config: dict) -> AIProvider:
    """Build the provider object selected by ``ai_config["provider"]``.

    The provider's own settings are taken from the config section named
    after it (an empty dict when that section is missing). Provider modules
    are imported only when they are actually selected.

    Raises:
        ValueError: if the configured provider name is not recognised.
    """
    selected = ai_config.get("provider", "lmstudio")
    section = ai_config.get(selected, {})

    if selected == "anthropic":
        from app.providers.anthropic_provider import AnthropicProvider

        return AnthropicProvider(section)

    # Ollama and LM Studio share one OpenAI-compatible implementation.
    if selected in ("ollama", "lmstudio"):
        from app.providers.openai_compat import OpenAICompatProvider

        return OpenAICompatProvider(section, provider_name=selected)

    raise ValueError(f"Unknown AI provider: {selected!r}")
|
||||
|
||||
|
||||
__all__ = ["AIProvider", "get_provider"]
|
||||
@@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
|
||||
import anthropic
|
||||
|
||||
from app.providers.base import AIProvider
|
||||
from app.schemas.chat import ChatMessage
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
    """AIProvider implementation that talks to the Anthropic Messages API."""

    def __init__(self, config: dict) -> None:
        """Create the async client from the ``anthropic`` config section.

        Args:
            config: Provider settings. ``api_key`` and ``model`` are read,
                falling back to an empty key and a default model name.
        """
        self._client = anthropic.AsyncAnthropic(api_key=config.get("api_key", ""))
        self.model_name = config.get("model", "claude-haiku-4-5-20251001")
        self.provider_name = "anthropic"

    async def chat(
        self,
        messages: list[ChatMessage],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, int | None, int | None]:
        """Send ``messages`` to Anthropic and return the reply with token usage.

        Args:
            messages: Conversation to send; ``system`` messages are merged
                into a single system prompt.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.

        Returns:
            ``(content, input_tokens, output_tokens)`` where ``content`` is
            the concatenated text of all text blocks in the response.

        Raises:
            ProviderTimeoutError: the SDK reported a request timeout.
            ProviderConnectionError: connection failure or a non-success
                HTTP status from the API.
        """
        # Anthropic uses a top-level `system=` param, not a role in the messages array
        system_prompt = "\n".join(
            m.content for m in messages if m.role == "system"
        ).strip()
        conversation = [
            {"role": m.role, "content": m.content}
            for m in messages
            if m.role != "system"
        ]

        try:
            response = await self._client.messages.create(
                model=self.model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                system=system_prompt or anthropic.NOT_GIVEN,
                messages=conversation,
            )
        # APITimeoutError is a subclass of APIConnectionError, so it must be
        # caught first; otherwise timeouts are reported as connection errors.
        except anthropic.APITimeoutError as exc:
            raise ProviderTimeoutError(str(exc)) from exc
        except anthropic.APIConnectionError as exc:
            raise ProviderConnectionError(str(exc)) from exc
        except anthropic.APIStatusError as exc:
            raise ProviderConnectionError(
                f"Anthropic API error {exc.status_code}: {exc.message}"
            ) from exc

        # Join every text block instead of indexing content[0]: the list can be
        # empty, and a non-text first block has no `.text` attribute.
        content = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        return content, response.usage.input_tokens, response.usage.output_tokens
|
||||
|
||||
|
||||
class ProviderConnectionError(Exception):
    """Raised when the Anthropic API cannot be reached or answers with an error status."""
|
||||
|
||||
|
||||
class ProviderTimeoutError(Exception):
    """Raised when a request to the Anthropic API times out."""
|
||||
@@ -0,0 +1,23 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.schemas.chat import ChatMessage
|
||||
|
||||
|
||||
class AIProvider(ABC):
    """Common interface implemented by every AI provider backend."""

    # Placeholder identifiers; concrete providers set both in their constructor.
    provider_name: str = "unknown"
    model_name: str = "unknown"

    @abstractmethod
    async def chat(
        self,
        messages: list[ChatMessage],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, int | None, int | None]:
        """Send ``messages`` to the provider and return its reply.

        Returns:
            A ``(content, input_tokens, output_tokens)`` tuple; either token
            count may be ``None``.

        Raises:
            ProviderConnectionError: on network / auth failure
            ProviderTimeoutError: on request timeout
        """
        ...
|
||||
@@ -0,0 +1,52 @@
|
||||
"""OpenAI-compatible provider — handles both Ollama and LM Studio."""
|
||||
import asyncio
|
||||
|
||||
import openai
|
||||
|
||||
from app.providers.base import AIProvider
|
||||
from app.schemas.chat import ChatMessage
|
||||
|
||||
|
||||
class OpenAICompatProvider(AIProvider):
    """AIProvider for servers exposing an OpenAI-compatible chat completions API."""

    def __init__(self, config: dict, provider_name: str = "lmstudio") -> None:
        """Create the async client from a provider config section.

        Args:
            config: Provider settings. ``base_url``, ``api_key`` and
                ``model`` are read, each with a local default.
            provider_name: Name reported back to callers for this instance.
        """
        self._client = openai.AsyncOpenAI(
            base_url=config.get("base_url", "http://localhost:1234/v1"),
            # A placeholder is substituted when the key is missing or empty.
            api_key=config.get("api_key") or "not-required",
        )
        self.model_name = config.get("model", "local-model")
        self.provider_name = provider_name

    async def chat(
        self,
        messages: list[ChatMessage],
        max_tokens: int,
        temperature: float,
    ) -> tuple[str, int | None, int | None]:
        """Send ``messages`` to the server and return the reply with token usage.

        Returns:
            ``(content, input_tokens, output_tokens)``. ``content`` is ``""``
            when the server returns no choices or a null message body; the
            token counts are ``None`` when the response carries no usage.

        Raises:
            ProviderTimeoutError: the SDK reported a request timeout.
            ProviderConnectionError: connection failure or a non-success
                HTTP status from the server.
        """
        raw_messages = [{"role": m.role, "content": m.content} for m in messages]
        try:
            response = await self._client.chat.completions.create(
                model=self.model_name,
                messages=raw_messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        # APITimeoutError is a subclass of APIConnectionError, so it must be
        # caught first; otherwise timeouts are reported as connection errors.
        except openai.APITimeoutError as exc:
            raise ProviderTimeoutError(str(exc)) from exc
        except openai.APIConnectionError as exc:
            raise ProviderConnectionError(str(exc)) from exc
        except openai.APIStatusError as exc:
            raise ProviderConnectionError(
                f"API error {exc.status_code}: {exc.message}"
            ) from exc

        # Guard against an empty `choices` list instead of raising IndexError.
        choices = response.choices
        content = (choices[0].message.content or "") if choices else ""

        usage = response.usage
        input_tokens = usage.prompt_tokens if usage else None
        output_tokens = usage.completion_tokens if usage else None
        return content, input_tokens, output_tokens
|
||||
|
||||
|
||||
class ProviderConnectionError(Exception):
    """Raised when the OpenAI-compatible server cannot be reached or answers with an error status."""
|
||||
|
||||
|
||||
class ProviderTimeoutError(Exception):
    """Raised when a request to the OpenAI-compatible server times out."""
|
||||
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.providers import get_provider
|
||||
from app.providers.anthropic_provider import ProviderConnectionError as AnthropicConnError
|
||||
from app.providers.anthropic_provider import ProviderTimeoutError as AnthropicTimeoutError
|
||||
from app.providers.openai_compat import ProviderConnectionError as OpenAIConnError
|
||||
from app.providers.openai_compat import ProviderTimeoutError as OpenAITimeoutError
|
||||
from app.schemas.chat import ChatRequest, ChatResponse
|
||||
from app.services.config_reader import load_ai_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Matches text that consists solely of one Markdown code fence, capturing its body.
# IGNORECASE lets the optional language tag match "JSON" as well as "json";
# without it an upper-case tag would be returned as part of the body.
_FENCE_RE = re.compile(r"^```[a-z]*\n?(.*?)\n?```$", re.DOTALL | re.IGNORECASE)


def _strip_fences(text: str) -> str:
    """Return ``text`` without a surrounding Markdown code fence.

    Leading and trailing whitespace is always removed. When the remaining
    text is exactly one fenced block, only the block's body is returned;
    otherwise the stripped text is returned unchanged.
    """
    stripped = text.strip()
    fenced = _FENCE_RE.match(stripped)
    return fenced.group(1).strip() if fenced else stripped
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
    """Forward a chat request to the configured AI provider and return its reply."""
    # Status codes produced here:
    #   503 - the configured provider is unknown
    #   504 - the provider call timed out (never retried)
    #   502 - the provider still failed after all retries
    config = await load_ai_config()

    provider_name = config.get("provider", "lmstudio")
    if provider_name not in ("anthropic", "ollama", "lmstudio"):
        raise HTTPException(status_code=503, detail=f"Unknown provider configured: {provider_name!r}")

    try:
        provider = get_provider(config)
    except ValueError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    timeout = float(config.get("timeout_seconds", 60))
    # Clamp to zero so at least one attempt is always made; a negative value
    # would skip the loop entirely and leave `content` unbound.
    max_retries = max(0, int(config.get("max_retries", 2)))

    for attempt in range(max_retries + 1):
        try:
            content, input_tokens, output_tokens = await asyncio.wait_for(
                provider.chat(request.messages, request.max_tokens, request.temperature),
                timeout=timeout,
            )
            break
        except (asyncio.TimeoutError, AnthropicTimeoutError, OpenAITimeoutError) as exc:
            # Don't retry on timeout — the model is busy; fail fast
            raise HTTPException(status_code=504, detail="AI provider timed out") from exc
        except (AnthropicConnError, OpenAIConnError) as exc:
            if attempt >= max_retries:
                raise HTTPException(status_code=502, detail=f"AI provider error: {exc}") from exc
            # Linear back-off between attempts: 0.5s, 1.0s, ...
            await asyncio.sleep(0.5 * (attempt + 1))

    # Models often wrap JSON in a Markdown fence; remove it for JSON responses.
    if request.response_format == "json":
        content = _strip_fences(content)

    return ChatResponse(
        content=content,
        provider=provider.provider_name,
        model=provider.model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )
|
||||
@@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.services.config_reader import load_ai_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Liveness probe: answers with a fixed payload and never touches the provider config.
@router.get("/health")
async def health() -> dict:
    payload = {"status": "ok"}
    return payload
|
||||
|
||||
|
||||
# Reports the active provider, its model, and whether its minimum settings are present.
# No request is made to the provider itself.
@router.get("/health/provider")
async def provider_status() -> dict:
    ai_config = await load_ai_config()
    active = ai_config.get("provider", "lmstudio")
    section = ai_config.get(active, {})

    # "configured" means we have the minimum required fields for the provider:
    # an API key for Anthropic, base_url and model for everything else.
    required_fields = ("api_key",) if active == "anthropic" else ("base_url", "model")
    is_configured = all(section.get(field) for field in required_fields)

    return {
        "provider": active,
        "model": section.get("model", ""),
        "configured": is_configured,
    }
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
# One message of a conversation.
class ChatMessage(BaseModel):
    # Who authored the message; restricted to the three listed roles.
    role: Literal["system", "user", "assistant"]
    # Message text.
    content: str
|
||||
|
||||
|
||||
# Request body of POST /chat.
class ChatRequest(BaseModel):
    # Conversation to send; must contain at least one message.
    messages: list[ChatMessage]
    # Upper bound on generated tokens.
    max_tokens: int = 2048
    # Sampling temperature.
    temperature: float = 0.0
    # "json" asks the service to strip Markdown code fences from the reply.
    response_format: Literal["json", "text"] = "text"

    @field_validator("messages")
    @classmethod
    def messages_not_empty(cls, v: list) -> list:
        # Accept any non-empty list; reject an empty conversation.
        if v:
            return v
        raise ValueError("messages must not be empty")
|
||||
|
||||
|
||||
# Response body of POST /chat.
class ChatResponse(BaseModel):
    # Text returned by the provider.
    content: str
    # Name of the provider that produced the reply.
    provider: str
    # Model that produced the reply.
    model: str
    # Token usage; None when the provider did not report it.
    input_tokens: int | None = None
    output_tokens: int | None = None
|
||||
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Reads ai_service_config.json from the shared config volume.
|
||||
30-second TTL cache + env var overrides (dev credentials stay out of git).
|
||||
|
||||
Env var overrides (all optional):
|
||||
AI_PROVIDER — "lmstudio" | "ollama" | "anthropic"
|
||||
LMSTUDIO_BASE_URL, LMSTUDIO_API_KEY, LMSTUDIO_MODEL
|
||||
OLLAMA_BASE_URL, OLLAMA_MODEL, OLLAMA_API_KEY
|
||||
ANTHROPIC_API_KEY, ANTHROPIC_MODEL
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
_DEFAULT_CONFIG: dict = {
|
||||
"provider": "lmstudio",
|
||||
"timeout_seconds": 60,
|
||||
"max_retries": 2,
|
||||
"anthropic": {"api_key": "", "model": "claude-haiku-4-5-20251001"},
|
||||
"ollama": {"base_url": "http://host.docker.internal:11434/v1", "model": "llama3.2", "api_key": "ollama"},
|
||||
"lmstudio": {"base_url": "http://host.docker.internal:1234/v1", "model": "local-model", "api_key": "lm-studio"},
|
||||
}
|
||||
|
||||
_cache: dict | None = None
|
||||
_cache_at: float = 0.0
|
||||
_CACHE_TTL = 30.0
|
||||
|
||||
|
||||
def _read_config_sync() -> dict:
    """Read the provider config from disk and apply env var overrides.

    Falls back to a copy of the built-in defaults when the config file does
    not exist. Blocking; intended to be run in a worker thread.

    Raises:
        ValueError: if the file's top-level JSON value is not an object.
            (``json.JSONDecodeError``, raised for malformed JSON, is itself
            a ``ValueError`` subclass.)
    """
    path = Path(settings.CONFIG_PATH)
    try:
        # Open directly instead of checking exists() first: the file can
        # disappear between the check and the open. JSON is read as UTF-8
        # rather than with the locale's default encoding.
        with path.open(encoding="utf-8") as f:
            loaded = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return _apply_env_overrides(deepcopy(_DEFAULT_CONFIG))

    if not isinstance(loaded, dict):
        raise ValueError(
            f"{path}: expected a JSON object at the top level, got {type(loaded).__name__}"
        )
    return _apply_env_overrides(loaded)
|
||||
|
||||
|
||||
def _apply_env_overrides(config: dict) -> dict:
    """Return a deep copy of ``config`` with environment overrides applied.

    ``AI_PROVIDER`` replaces the active provider. Each provider section is
    created if missing, and each of its fields is replaced by the matching
    environment variable when that variable is set to a non-empty value.
    The input dict is not modified.
    """
    env = os.environ
    merged = deepcopy(config)

    active = env.get("AI_PROVIDER")
    if active:
        merged["provider"] = active

    # section name -> {config field: environment variable}
    overrides = {
        "lmstudio": {
            "base_url": "LMSTUDIO_BASE_URL",
            "api_key": "LMSTUDIO_API_KEY",
            "model": "LMSTUDIO_MODEL",
        },
        "ollama": {
            "base_url": "OLLAMA_BASE_URL",
            "model": "OLLAMA_MODEL",
            "api_key": "OLLAMA_API_KEY",
        },
        "anthropic": {
            "api_key": "ANTHROPIC_API_KEY",
            "model": "ANTHROPIC_MODEL",
        },
    }
    for section_name, fields in overrides.items():
        section = merged.setdefault(section_name, {})
        for field, var_name in fields.items():
            value = env.get(var_name)
            if value:
                section[field] = value

    return merged
|
||||
|
||||
|
||||
async def load_ai_config() -> dict:
    """Return the AI config, re-reading it from disk at most once per TTL.

    A cached copy younger than ``_CACHE_TTL`` seconds is returned as-is;
    otherwise the file is read in a worker thread and the cache replaced.
    The returned dict is the cached object itself, not a copy.
    """
    global _cache, _cache_at
    started = time.monotonic()
    is_fresh = _cache is not None and (started - _cache_at) < _CACHE_TTL
    if not is_fresh:
        _cache = await asyncio.to_thread(_read_config_sync)
        # The timestamp is the time taken before the read started.
        _cache_at = started
    return _cache
|
||||
Reference in New Issue
Block a user