"""OpenAI-backed implementation of the ``AIProvider`` interface.

Uses the Chat Completions API (through ``openai.AsyncOpenAI``) to classify a
document into topics and to suggest topic names. The model's reply is parsed
defensively: a malformed reply yields an empty result instead of an error.
"""

import json
import logging
import re

from openai import AsyncOpenAI

from ai.base import AIProvider, ClassificationResult

logger = logging.getLogger(__name__)

# Maximum number of characters of document text sent to the model per request.
MAX_AI_CHARS = 8_000

# Shared decoder; raw_decode() lets us parse a JSON value embedded in prose.
_JSON_DECODER = json.JSONDecoder()


class OpenAIProvider(AIProvider):
    """``AIProvider`` backed by an OpenAI (or OpenAI-compatible) chat endpoint.

    Args:
        api_key: API key handed to the client. An empty value is replaced by
            the string "placeholder" (see ``_client``).
        model: Chat model name sent with every request.
        base_url: Optional API base URL override, e.g. for an
            OpenAI-compatible server; ``None`` keeps the client's default.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str | None = None):
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def _client(self) -> AsyncOpenAI:
        """Build a new async client from the stored settings.

        Callers use the result as ``async with self._client() as client:`` so
        its HTTP resources are released when the request finishes rather than
        leaking one connection pool per call.
        """
        # NOTE(review): the "placeholder" fallback presumably lets the client
        # be constructed when no key is configured (requests then fail and
        # health_check reports False) — confirm this is the intent.
        return AsyncOpenAI(api_key=self._api_key or "placeholder", base_url=self._base_url)

    async def _complete(self, system_prompt: str, user_msg: str, max_tokens: int) -> str:
        """Send one system + user chat turn and return the reply text.

        Returns "" when the response has no choices or the message has no
        content. Exceptions raised by the OpenAI client propagate.
        """
        async with self._client() as client:
            response = await client.chat.completions.create(
                model=self._model,
                max_tokens=max_tokens,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_msg},
                ],
            )
        if not response.choices:
            return ""
        return response.choices[0].message.content or ""

    async def classify(
        self,
        document_text: str,
        existing_topics: list[str],
        system_prompt: str,
    ) -> ClassificationResult:
        """Ask the model which topics a document belongs to.

        Args:
            document_text: Document body; only the first ``MAX_AI_CHARS``
                characters are sent.
            existing_topics: Known topic names, listed in the prompt; an empty
                list is rendered as "(none yet)".
            system_prompt: System message for the model. Presumably it asks
                for JSON with the keys ``assigned_topics``,
                ``new_topic_suggestions`` and ``reasoning`` — the keys parsed
                here; confirm against the prompt source.

        Returns:
            The parsed classification, or ``ClassificationResult()`` when the
            reply contains no decodable JSON object.
        """
        topics_str = ", ".join(existing_topics) if existing_topics else "(none yet)"
        user_msg = (
            f"Existing topics: [{topics_str}]\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        raw = await self._complete(system_prompt, user_msg, max_tokens=1024)
        return _parse_classification(raw)

    async def suggest_topics(
        self,
        document_text: str,
        system_prompt: str,
    ) -> list[str]:
        """Ask the model for 3-5 topic names describing a document.

        Args:
            document_text: Document body; only the first ``MAX_AI_CHARS``
                characters are sent.
            system_prompt: System message for the model.

        Returns:
            The suggested topic names (non-empty, whitespace-stripped
            strings), or ``[]`` when the reply cannot be parsed.
        """
        user_msg = (
            "Suggest 3-5 topic names for this document. "
            "Return ONLY valid JSON: {\"suggested_topics\": [\"topic1\", \"topic2\"]}\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        raw = await self._complete(system_prompt, user_msg, max_tokens=256)
        return _parse_suggestions(raw)

    async def health_check(self) -> bool:
        """Return True if a minimal 5-token "ping" completion succeeds.

        Any exception is logged and reported as False; this method never raises.
        """
        try:
            async with self._client() as client:
                await client.chat.completions.create(
                    model=self._model,
                    max_tokens=5,
                    messages=[{"role": "user", "content": "ping"}],
                )
        except Exception:
            # Deliberately broad: a health check reports failure instead of raising,
            # but the cause is logged so it is not silently lost.
            logger.warning("OpenAI health check failed (model=%s)", self._model, exc_info=True)
            return False
        return True


def _strip_code_fences(text: str) -> str:
    """Remove Markdown code fences (``` or ```json) and surrounding whitespace."""
    # The pattern matches opening and closing fences alike (the "json" tag and
    # trailing whitespace are optional), so a single pass removes them all.
    return re.sub(r"```(?:json)?\s*", "", text).strip()


def _extract_json_object(raw: str) -> dict | None:
    """Return the first JSON object embedded in a model reply, or None.

    Code fences are stripped first. Decoding is then attempted at each "{" in
    turn, so surrounding prose — even prose containing braces — or a second
    object after the first does not prevent extraction.
    """
    text = _strip_code_fences(raw)
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = _JSON_DECODER.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return obj  # decoding from "{" always yields a dict
    return None


def _as_str_list(value: object) -> list[str]:
    """Coerce a JSON value into a list of non-empty, stripped strings.

    A bare string becomes a one-element list; non-list values and non-string
    items are dropped.
    """
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, list):
        return []
    return [item.strip() for item in value if isinstance(item, str) and item.strip()]


def _parse_classification(raw: str) -> ClassificationResult:
    """Parse a classification reply; returns ``ClassificationResult()`` if no JSON object is found."""
    data = _extract_json_object(raw)
    if data is None:
        return ClassificationResult()
    reasoning = data.get("reasoning", "")
    return ClassificationResult(
        topics=_as_str_list(data.get("assigned_topics", [])),
        suggested_new_topics=_as_str_list(data.get("new_topic_suggestions", [])),
        reasoning=reasoning if isinstance(reasoning, str) else "",
    )


def _parse_suggestions(raw: str) -> list[str]:
    """Parse a topic-suggestion reply; returns ``[]`` if no JSON object is found."""
    data = _extract_json_object(raw)
    if data is None:
        return []
    return _as_str_list(data.get("suggested_topics", []))