chore: initial commit — existing single-user document scanner codebase
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
from ai.base import AIProvider, ClassificationResult
|
||||
from ai.anthropic_provider import AnthropicProvider
|
||||
from ai.openai_provider import OpenAIProvider
|
||||
from ai.ollama_provider import OllamaProvider
|
||||
from ai.lmstudio_provider import LMStudioProvider
|
||||
|
||||
|
||||
def get_provider(settings: dict) -> AIProvider:
|
||||
active = settings.get("active_provider", "lmstudio")
|
||||
providers = settings.get("providers", {})
|
||||
cfg = providers.get(active, {})
|
||||
|
||||
match active:
|
||||
case "anthropic":
|
||||
return AnthropicProvider(
|
||||
api_key=cfg.get("api_key", ""),
|
||||
model=cfg.get("model", "claude-sonnet-4-6"),
|
||||
)
|
||||
case "openai":
|
||||
return OpenAIProvider(
|
||||
api_key=cfg.get("api_key", ""),
|
||||
model=cfg.get("model", "gpt-4o"),
|
||||
base_url=cfg.get("base_url") or None,
|
||||
)
|
||||
case "ollama":
|
||||
return OllamaProvider(
|
||||
base_url=cfg.get("base_url", "http://host.docker.internal:11434"),
|
||||
model=cfg.get("model", "llama3.2"),
|
||||
)
|
||||
case "lmstudio":
|
||||
return LMStudioProvider(
|
||||
base_url=cfg.get("base_url", "http://host.docker.internal:1234"),
|
||||
model=cfg.get("model", "gemma-4-e4b-it"),
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unknown AI provider: {active}")
|
||||
@@ -0,0 +1,103 @@
|
||||
import json
|
||||
import re
|
||||
import anthropic
|
||||
from ai.base import AIProvider, ClassificationResult
|
||||
|
||||
MAX_AI_CHARS = 8_000
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "claude-sonnet-4-6"):
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
|
||||
def _client(self):
|
||||
return anthropic.AsyncAnthropic(api_key=self._api_key)
|
||||
|
||||
async def classify(
|
||||
self,
|
||||
document_text: str,
|
||||
existing_topics: list[str],
|
||||
system_prompt: str,
|
||||
) -> ClassificationResult:
|
||||
topics_str = ", ".join(existing_topics) if existing_topics else "(none yet)"
|
||||
user_msg = (
|
||||
f"Existing topics: [{topics_str}]\n\n"
|
||||
f"Document text:\n{document_text[:MAX_AI_CHARS]}"
|
||||
)
|
||||
client = self._client()
|
||||
response = await client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=1024,
|
||||
system=system_prompt,
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
)
|
||||
raw = response.content[0].text
|
||||
return _parse_classification(raw)
|
||||
|
||||
async def suggest_topics(
|
||||
self,
|
||||
document_text: str,
|
||||
system_prompt: str,
|
||||
) -> list[str]:
|
||||
user_msg = (
|
||||
"Suggest 3-5 topic names for this document. "
|
||||
"Return ONLY valid JSON: {\"suggested_topics\": [\"topic1\", \"topic2\"]}\n\n"
|
||||
f"Document text:\n{document_text[:MAX_AI_CHARS]}"
|
||||
)
|
||||
client = self._client()
|
||||
response = await client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=256,
|
||||
system=system_prompt,
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
)
|
||||
raw = response.content[0].text
|
||||
return _parse_suggestions(raw)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
client = self._client()
|
||||
await client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=5,
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _strip_code_fences(text: str) -> str:
|
||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"```", "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _parse_classification(raw: str) -> ClassificationResult:
|
||||
raw = _strip_code_fences(raw)
|
||||
# Try to find JSON object
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return ClassificationResult(
|
||||
topics=data.get("assigned_topics", []),
|
||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return ClassificationResult()
|
||||
|
||||
|
||||
def _parse_suggestions(raw: str) -> list[str]:
|
||||
raw = _strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return data.get("suggested_topics", [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return []
|
||||
@@ -0,0 +1,32 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassificationResult:
|
||||
topics: list[str] = field(default_factory=list)
|
||||
suggested_new_topics: list[str] = field(default_factory=list)
|
||||
reasoning: str = ""
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
@abstractmethod
|
||||
async def classify(
|
||||
self,
|
||||
document_text: str,
|
||||
existing_topics: list[str],
|
||||
system_prompt: str,
|
||||
) -> ClassificationResult:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def suggest_topics(
|
||||
self,
|
||||
document_text: str,
|
||||
system_prompt: str,
|
||||
) -> list[str]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> bool:
|
||||
...
|
||||
@@ -0,0 +1,10 @@
|
||||
from ai.openai_provider import OpenAIProvider
|
||||
|
||||
|
||||
class LMStudioProvider(OpenAIProvider):
|
||||
def __init__(self, base_url: str = "http://host.docker.internal:1234", model: str = "gemma-4-e4b-it"):
|
||||
super().__init__(
|
||||
api_key="lm-studio",
|
||||
model=model,
|
||||
base_url=base_url.rstrip("/") + "/v1",
|
||||
)
|
||||
@@ -0,0 +1,10 @@
|
||||
from ai.openai_provider import OpenAIProvider
|
||||
|
||||
|
||||
class OllamaProvider(OpenAIProvider):
|
||||
def __init__(self, base_url: str = "http://host.docker.internal:11434", model: str = "llama3.2"):
|
||||
super().__init__(
|
||||
api_key="ollama",
|
||||
model=model,
|
||||
base_url=base_url.rstrip("/") + "/v1",
|
||||
)
|
||||
@@ -0,0 +1,104 @@
|
||||
import json
|
||||
import re
|
||||
from openai import AsyncOpenAI
|
||||
from ai.base import AIProvider, ClassificationResult
|
||||
|
||||
MAX_AI_CHARS = 8_000
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str | None = None):
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
self._base_url = base_url
|
||||
|
||||
def _client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(api_key=self._api_key or "placeholder", base_url=self._base_url)
|
||||
|
||||
async def classify(
|
||||
self,
|
||||
document_text: str,
|
||||
existing_topics: list[str],
|
||||
system_prompt: str,
|
||||
) -> ClassificationResult:
|
||||
topics_str = ", ".join(existing_topics) if existing_topics else "(none yet)"
|
||||
user_msg = (
|
||||
f"Existing topics: [{topics_str}]\n\n"
|
||||
f"Document text:\n{document_text[:MAX_AI_CHARS]}"
|
||||
)
|
||||
response = await self._client().chat.completions.create(
|
||||
model=self._model,
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_msg},
|
||||
],
|
||||
)
|
||||
raw = response.choices[0].message.content or ""
|
||||
return _parse_classification(raw)
|
||||
|
||||
async def suggest_topics(
|
||||
self,
|
||||
document_text: str,
|
||||
system_prompt: str,
|
||||
) -> list[str]:
|
||||
user_msg = (
|
||||
"Suggest 3-5 topic names for this document. "
|
||||
"Return ONLY valid JSON: {\"suggested_topics\": [\"topic1\", \"topic2\"]}\n\n"
|
||||
f"Document text:\n{document_text[:MAX_AI_CHARS]}"
|
||||
)
|
||||
response = await self._client().chat.completions.create(
|
||||
model=self._model,
|
||||
max_tokens=256,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_msg},
|
||||
],
|
||||
)
|
||||
raw = response.choices[0].message.content or ""
|
||||
return _parse_suggestions(raw)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
await self._client().chat.completions.create(
|
||||
model=self._model,
|
||||
max_tokens=5,
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _strip_code_fences(text: str) -> str:
|
||||
text = re.sub(r"```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"```", "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _parse_classification(raw: str) -> ClassificationResult:
|
||||
raw = _strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return ClassificationResult(
|
||||
topics=data.get("assigned_topics", []),
|
||||
suggested_new_topics=data.get("new_topic_suggestions", []),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return ClassificationResult()
|
||||
|
||||
|
||||
def _parse_suggestions(raw: str) -> list[str]:
|
||||
raw = _strip_code_fences(raw)
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
data = json.loads(match.group())
|
||||
return data.get("suggested_topics", [])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return []
|
||||
Reference in New Issue
Block a user