Files
kite/backend/ai/openai_provider.py
T
2026-05-22 08:53:28 +02:00

105 lines
3.3 KiB
Python

import json
import logging
import re

from openai import AsyncOpenAI

from ai.base import AIProvider, ClassificationResult
# Document text is truncated to this many characters before it is placed in
# a prompt (see OpenAIProvider.classify / suggest_topics).
MAX_AI_CHARS: int = 8000
class OpenAIProvider(AIProvider):
    """AIProvider implementation backed by the OpenAI Chat Completions API.

    ``base_url`` (optional) points the SDK client at a different API endpoint
    than the SDK default.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str | None = None):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def _client(self) -> AsyncOpenAI:
        """Build a new AsyncOpenAI client for this provider's settings.

        Every call returns a fresh client that owns its own HTTP connection
        pool, so callers must use it as ``async with self._client() as c:``
        to have that pool closed deterministically.
        """
        # When no key is configured a dummy value is sent instead of letting
        # the SDK fall back to its own key lookup; presumably for keyless
        # endpoints reached via base_url — NOTE(review): confirm intent.
        return AsyncOpenAI(api_key=self._api_key or "placeholder", base_url=self._base_url)

    async def _complete(self, system_prompt: str, user_msg: str, max_tokens: int) -> str:
        """Send one system + user chat turn and return the reply text.

        Returns "" when the response has no choices or the message content
        is empty/None, so parsers never see None or hit an IndexError.
        """
        async with self._client() as client:
            response = await client.chat.completions.create(
                model=self._model,
                max_tokens=max_tokens,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_msg},
                ],
            )
        if not response.choices:
            return ""
        return response.choices[0].message.content or ""

    async def classify(
        self,
        document_text: str,
        existing_topics: list[str],
        system_prompt: str,
    ) -> ClassificationResult:
        """Ask the model to classify a document against the known topics.

        Args:
            document_text: Document body; only the first MAX_AI_CHARS
                characters are sent.
            existing_topics: Topic names listed in the prompt; an empty list
                is rendered as "(none yet)".
            system_prompt: System message sent ahead of the user message.

        Returns:
            The model reply parsed by ``_parse_classification``.
        """
        topics_str = ", ".join(existing_topics) if existing_topics else "(none yet)"
        user_msg = (
            f"Existing topics: [{topics_str}]\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        raw = await self._complete(system_prompt, user_msg, max_tokens=1024)
        return _parse_classification(raw)

    async def suggest_topics(
        self,
        document_text: str,
        system_prompt: str,
    ) -> list[str]:
        """Ask the model for candidate topic names for a document.

        The prompt requests 3-5 names as JSON; that count is not enforced
        here. Only the first MAX_AI_CHARS characters of the document are sent.

        Returns:
            The topic names parsed by ``_parse_suggestions``.
        """
        user_msg = (
            "Suggest 3-5 topic names for this document. "
            "Return ONLY valid JSON: {\"suggested_topics\": [\"topic1\", \"topic2\"]}\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        raw = await self._complete(system_prompt, user_msg, max_tokens=256)
        return _parse_suggestions(raw)

    async def health_check(self) -> bool:
        """Return True if a minimal completion request succeeds, else False.

        Deliberately best-effort: any exception is logged and reported as
        "unhealthy" instead of being raised to the caller.
        """
        try:
            async with self._client() as client:
                await client.chat.completions.create(
                    model=self._model,
                    max_tokens=5,
                    messages=[{"role": "user", "content": "ping"}],
                )
        except Exception as exc:  # any failure means "unhealthy"
            logging.getLogger(__name__).warning("OpenAI health check failed: %s", exc)
            return False
        return True
def _strip_code_fences(text: str) -> str:
text = re.sub(r"```(?:json)?\s*", "", text)
text = re.sub(r"```", "", text)
return text.strip()
def _parse_classification(raw: str) -> ClassificationResult:
raw = _strip_code_fences(raw)
match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
try:
data = json.loads(match.group())
return ClassificationResult(
topics=data.get("assigned_topics", []),
suggested_new_topics=data.get("new_topic_suggestions", []),
reasoning=data.get("reasoning", ""),
)
except json.JSONDecodeError:
pass
return ClassificationResult()
def _parse_suggestions(raw: str) -> list[str]:
raw = _strip_code_fences(raw)
match = re.search(r"\{.*\}", raw, re.DOTALL)
if match:
try:
data = json.loads(match.group())
return data.get("suggested_topics", [])
except json.JSONDecodeError:
pass
return []