1882edfff6
- backend/api/auth.py: register, login (TOTP+backup), refresh, logout, me, change-password; per-account Redis rate limit; HIBP check - backend/main.py: Origin validation middleware, CSP headers middleware, CORS locked to settings.cors_origins, Redis lifespan (app.state.redis), admin bootstrap, auth router included, slowapi SlowAPIMiddleware - backend/services/email.py: already created in Plan 01 (verified exists) - Python 3.9 compat: fixed match statement in ai/__init__.py, str|None union syntax in openai_provider.py, api/documents.py, api/topics.py, api/settings.py, services/classifier.py All 17 tests in test_auth_api.py pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
import json
|
|
import re
|
|
from openai import AsyncOpenAI
|
|
from ai.base import AIProvider, ClassificationResult
|
|
|
|
# Upper bound on how many characters of document text are placed in a prompt;
# longer documents are truncated by slicing to this length before sending.
MAX_AI_CHARS = 8_000
|
|
|
|
|
|
class OpenAIProvider(AIProvider):
    """AI provider that talks to a chat-completions endpoint via ``AsyncOpenAI``.

    The instance only stores connection settings; a new client object is
    built from them for every request.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o", base_url=None):  # type: ignore[type-arg]
        """Store the connection settings without creating a client.

        Args:
            api_key: Key handed to the client. A falsy value is replaced by
                the literal ``"placeholder"`` when the client is built.
            model: Model name sent with every request.
            base_url: Optional endpoint override forwarded to the client;
                ``None`` leaves the client's default in place.
        """
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    def _client(self) -> AsyncOpenAI:
        """Build a new ``AsyncOpenAI`` client from the stored settings."""
        # NOTE(review): the "placeholder" fallback presumably exists for
        # keyless OpenAI-compatible servers reached via base_url - confirm.
        return AsyncOpenAI(api_key=self._api_key or "placeholder", base_url=self._base_url)

    async def _request_text(self, system_prompt: str, user_msg: str, max_tokens: int) -> str:
        """Send one system + user exchange and return the reply text.

        Returns an empty string when the response carries no choices or the
        first choice has no content, so callers fall back to their empty
        result instead of raising ``IndexError``.
        """
        response = await self._client().chat.completions.create(
            model=self._model,
            max_tokens=max_tokens,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_msg},
            ],
        )
        choices = response.choices
        if not choices:
            return ""
        return choices[0].message.content or ""

    async def classify(
        self,
        document_text: str,
        existing_topics: list[str],
        system_prompt: str,
    ) -> ClassificationResult:
        """Ask the model to classify a document against the known topics.

        Args:
            document_text: Document body; only the first ``MAX_AI_CHARS``
                characters are sent.
            existing_topics: Topic names listed in the prompt; an empty list
                is rendered as ``(none yet)``.
            system_prompt: System message sent ahead of the user message.

        Returns:
            The result parsed from the reply by ``_parse_classification``.
        """
        topics_str = ", ".join(existing_topics) if existing_topics else "(none yet)"
        user_msg = (
            f"Existing topics: [{topics_str}]\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        raw = await self._request_text(system_prompt, user_msg, max_tokens=1024)
        return _parse_classification(raw)

    async def suggest_topics(
        self,
        document_text: str,
        system_prompt: str,
    ) -> list[str]:
        """Ask the model for topic-name suggestions for a document.

        Args:
            document_text: Document body; only the first ``MAX_AI_CHARS``
                characters are sent.
            system_prompt: System message sent ahead of the user message.

        Returns:
            The topic names parsed from the reply by ``_parse_suggestions``.
        """
        user_msg = (
            "Suggest 3-5 topic names for this document. "
            "Return ONLY valid JSON: {\"suggested_topics\": [\"topic1\", \"topic2\"]}\n\n"
            f"Document text:\n{document_text[:MAX_AI_CHARS]}"
        )
        raw = await self._request_text(system_prompt, user_msg, max_tokens=256)
        return _parse_suggestions(raw)

    async def health_check(self) -> bool:
        """Return ``True`` if a minimal completion request succeeds, else ``False``."""
        try:
            await self._client().chat.completions.create(
                model=self._model,
                max_tokens=5,
                messages=[{"role": "user", "content": "ping"}],
            )
            return True
        except Exception:
            # Deliberately broad: this is a yes/no probe, so any failure
            # while issuing the request is reported as "unhealthy".
            return False
|
|
|
|
|
|
def _strip_code_fences(text: str) -> str:
|
|
text = re.sub(r"```(?:json)?\s*", "", text)
|
|
text = re.sub(r"```", "", text)
|
|
return text.strip()
|
|
|
|
|
|
def _parse_classification(raw: str) -> ClassificationResult:
    """Build a ``ClassificationResult`` from a model reply.

    Code fences are stripped, then the first JSON object that decodes
    cleanly is used. Decoding stops at the end of that object, so prose
    after the JSON (even prose containing braces) does not break parsing.

    Args:
        raw: Reply text from the model; may contain fences or prose.

    Returns:
        A result filled from the ``assigned_topics``,
        ``new_topic_suggestions`` and ``reasoning`` keys. Topic values that
        are not lists become empty lists, non-string items are dropped, and
        non-string reasoning becomes ``""``. A default
        ``ClassificationResult()`` is returned when no JSON object is found.
    """
    text = _strip_code_fences(raw)
    decoder = json.JSONDecoder()

    data = None
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores
            # whatever follows it.
            data, _ = decoder.raw_decode(text, start)
            break
        except json.JSONDecodeError:
            start = text.find("{", start + 1)

    if data is None:
        return ClassificationResult()

    def _string_list(key: str) -> list[str]:
        # The model may emit null or a scalar here; never pass that on.
        value = data.get(key)
        if not isinstance(value, list):
            return []
        return [item for item in value if isinstance(item, str)]

    reasoning = data.get("reasoning")
    return ClassificationResult(
        topics=_string_list("assigned_topics"),
        suggested_new_topics=_string_list("new_topic_suggestions"),
        reasoning=reasoning if isinstance(reasoning, str) else "",
    )
|
|
|
|
|
|
def _parse_suggestions(raw: str) -> list[str]:
    """Extract the ``suggested_topics`` list from a model reply.

    Code fences are stripped, then the first JSON object that decodes
    cleanly is used. Decoding stops at the end of that object, so prose
    after the JSON (even prose containing braces) does not break parsing.

    Args:
        raw: Reply text from the model; may contain fences or prose.

    Returns:
        The string items of ``suggested_topics``. An empty list is returned
        when no JSON object is found, the key is missing, or its value is
        not a list (for example ``null``).
    """
    text = _strip_code_fences(raw)
    decoder = json.JSONDecoder()

    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores
            # whatever follows it.
            data, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue

        suggestions = data.get("suggested_topics")
        if not isinstance(suggestions, list):
            return []
        return [item for item in suggestions if isinstance(item, str)]

    return []
|