"""Anthropic-backed implementation of the ``AIProvider`` interface."""

import json
import re
from typing import Optional

import anthropic

from ai.base import AIProvider, ClassificationResult

# Maximum number of characters of document text included in a prompt.
MAX_AI_CHARS = 8_000


class AnthropicProvider(AIProvider):
    """Classify documents and suggest topics through the Anthropic Messages API."""

    def __init__(self, api_key: str, model: str = "claude-sonnet-4-6"):
        """Store the credentials and model name used for every request.

        Args:
            api_key: Key passed to ``anthropic.AsyncAnthropic``.
            model: Model identifier sent with each ``messages.create`` call.
        """
        self._api_key = api_key
        self._model = model

    def _client(self):
        """Return a new ``AsyncAnthropic`` client built from the stored API key."""
        return anthropic.AsyncAnthropic(api_key=self._api_key)

    async def classify(
        self,
        document_text: str,
        existing_topics: list[str],
        system_prompt: str,
    ) -> ClassificationResult:
        """Ask the model to classify a document against the existing topics.

        Args:
            document_text: Text to classify; only the first ``MAX_AI_CHARS``
                characters are sent.
            existing_topics: Topic names listed in the prompt; an empty list is
                rendered as ``(none yet)``.
            system_prompt: System prompt forwarded unchanged to the API.

        Returns:
            The parsed ``ClassificationResult``; a default-constructed result
            when the reply contains no usable JSON object.
        """
        topics_str = ", ".join(existing_topics) if existing_topics else "(none yet)"
        user_msg = (
            f"Existing topics: [{topics_str}]\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        client = self._client()
        response = await client.messages.create(
            model=self._model,
            max_tokens=1024,
            system=system_prompt,
            messages=[{"role": "user", "content": user_msg}],
        )
        return _parse_classification(_response_text(response))

    async def suggest_topics(
        self,
        document_text: str,
        system_prompt: str,
    ) -> list[str]:
        """Ask the model for a handful of topic names describing a document.

        Args:
            document_text: Text to analyse; only the first ``MAX_AI_CHARS``
                characters are sent.
            system_prompt: System prompt forwarded unchanged to the API.

        Returns:
            The suggested topic names, or an empty list when the reply
            contains no usable JSON object.
        """
        user_msg = (
            "Suggest 3-5 topic names for this document. "
            "Return ONLY valid JSON: {\"suggested_topics\": [\"topic1\", \"topic2\"]}\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        client = self._client()
        response = await client.messages.create(
            model=self._model,
            max_tokens=256,
            system=system_prompt,
            messages=[{"role": "user", "content": user_msg}],
        )
        return _parse_suggestions(_response_text(response))

    async def health_check(self) -> bool:
        """Return True when a minimal request to the API succeeds, else False."""
        try:
            client = self._client()
            await client.messages.create(
                model=self._model,
                max_tokens=5,
                messages=[{"role": "user", "content": "ping"}],
            )
            return True
        except Exception:
            # Deliberately best-effort: any failure just means "not healthy".
            return False


def _response_text(response) -> str:
    """Return the text of the first content block, or "" when unavailable.

    An empty content list or a first block without a string ``text``
    attribute yields "" so the parsers fall back to their defaults instead
    of raising.
    """
    try:
        first_block = response.content[0]
    except (AttributeError, IndexError, KeyError, TypeError):
        return ""
    text = getattr(first_block, "text", None)
    return text if isinstance(text, str) else ""


def _strip_code_fences(text: str) -> str:
    """Remove Markdown code-fence markers and surrounding whitespace."""
    text = re.sub(r"```(?:json)?\s*", "", text)
    # A second pass catches a fence formed by the first removal.
    text = re.sub(r"```", "", text)
    return text.strip()


def _extract_json_object(text: str) -> Optional[dict]:
    """Return the first JSON object that decodes successfully within *text*.

    Decoding is attempted from each ``{`` in turn, so braces in surrounding
    prose or trailing text after the object do not prevent a valid object
    from being found. Returns None when no object decodes.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)
    return None


def _as_str_list(value) -> list[str]:
    """Coerce a decoded JSON value into a list of strings.

    A bare string becomes a one-element list, a list keeps only its string
    items, and anything else (including ``null``) becomes an empty list.
    """
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        return [item for item in value if isinstance(item, str)]
    return []


def _parse_classification(raw: str) -> ClassificationResult:
    """Build a ``ClassificationResult`` from the model's raw reply text.

    Reads the ``assigned_topics``, ``new_topic_suggestions`` and
    ``reasoning`` keys of the first JSON object found. Returns a
    default-constructed result when no JSON object can be decoded.
    """
    data = _extract_json_object(_strip_code_fences(raw))
    if data is None:
        return ClassificationResult()
    reasoning = data.get("reasoning")
    return ClassificationResult(
        topics=_as_str_list(data.get("assigned_topics")),
        suggested_new_topics=_as_str_list(data.get("new_topic_suggestions")),
        reasoning="" if reasoning is None else str(reasoning),
    )


def _parse_suggestions(raw: str) -> list[str]:
    """Return the ``suggested_topics`` list from the model's raw reply text.

    Returns an empty list when no JSON object can be decoded or the key is
    missing or not list-like.
    """
    data = _extract_json_object(_strip_code_fences(raw))
    if data is None:
        return []
    return _as_str_list(data.get("suggested_topics"))