feat(setup): provider registry and interactive setup wizard
- setup/providers.py: registry for 8 providers (3 local, 5 cloud), frozen dataclasses - setup/wizard.py: questionary-based wizard — provider select, model input, API key collected via vault.writer (not config.yaml), connectivity check, test call via litellm Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,104 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Provider:
|
||||||
|
id: str
|
||||||
|
display_name: str
|
||||||
|
requires_key: bool
|
||||||
|
default_model: str
|
||||||
|
litellm_prefix: str
|
||||||
|
base_url: str | None = None
|
||||||
|
key_env_var: str | None = None
|
||||||
|
connectivity_check: str | None = None
|
||||||
|
group: str = "Cloud"
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDERS: list[Provider] = [
|
||||||
|
# ── Local ────────────────────────────────────────────────────────────────
|
||||||
|
Provider(
|
||||||
|
id="lmstudio",
|
||||||
|
display_name="LM Studio (local)",
|
||||||
|
requires_key=False,
|
||||||
|
default_model="local-model",
|
||||||
|
litellm_prefix="openai/",
|
||||||
|
base_url="http://localhost:1234/v1",
|
||||||
|
connectivity_check="http://localhost:1234/v1/models",
|
||||||
|
group="Local",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
id="ollama",
|
||||||
|
display_name="Ollama (local)",
|
||||||
|
requires_key=False,
|
||||||
|
default_model="llama3",
|
||||||
|
litellm_prefix="ollama/",
|
||||||
|
base_url="http://localhost:11434",
|
||||||
|
connectivity_check="http://localhost:11434/api/tags",
|
||||||
|
group="Local",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
id="llamacpp",
|
||||||
|
display_name="llama.cpp server (local)",
|
||||||
|
requires_key=False,
|
||||||
|
default_model="local-model",
|
||||||
|
litellm_prefix="openai/",
|
||||||
|
base_url="http://localhost:8080/v1",
|
||||||
|
connectivity_check="http://localhost:8080/v1/models",
|
||||||
|
group="Local",
|
||||||
|
),
|
||||||
|
# ── Cloud ────────────────────────────────────────────────────────────────
|
||||||
|
Provider(
|
||||||
|
id="anthropic",
|
||||||
|
display_name="Anthropic (Claude)",
|
||||||
|
requires_key=True,
|
||||||
|
default_model="claude-sonnet-4-6",
|
||||||
|
litellm_prefix="anthropic/",
|
||||||
|
key_env_var="ANTHROPIC_API_KEY",
|
||||||
|
group="Cloud",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
id="openai",
|
||||||
|
display_name="OpenAI (GPT)",
|
||||||
|
requires_key=True,
|
||||||
|
default_model="gpt-4o",
|
||||||
|
litellm_prefix="openai/",
|
||||||
|
key_env_var="OPENAI_API_KEY",
|
||||||
|
group="Cloud",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
id="gemini",
|
||||||
|
display_name="Google (Gemini)",
|
||||||
|
requires_key=True,
|
||||||
|
default_model="gemini/gemini-2.0-flash",
|
||||||
|
litellm_prefix="gemini/",
|
||||||
|
key_env_var="GEMINI_API_KEY",
|
||||||
|
group="Cloud",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
id="deepseek",
|
||||||
|
display_name="DeepSeek",
|
||||||
|
requires_key=True,
|
||||||
|
default_model="deepseek/deepseek-chat",
|
||||||
|
litellm_prefix="deepseek/",
|
||||||
|
key_env_var="DEEPSEEK_API_KEY",
|
||||||
|
group="Cloud",
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
id="qwen",
|
||||||
|
display_name="Qwen (Alibaba)",
|
||||||
|
requires_key=True,
|
||||||
|
default_model="openai/qwen-plus",
|
||||||
|
litellm_prefix="openai/",
|
||||||
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
key_env_var="DASHSCOPE_API_KEY",
|
||||||
|
group="Cloud",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
PROVIDERS_BY_ID: dict[str, Provider] = {p.id: p for p in PROVIDERS}
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(provider_id: str) -> Provider:
|
||||||
|
if provider_id not in PROVIDERS_BY_ID:
|
||||||
|
raise KeyError(f"Unknown provider: {provider_id!r}")
|
||||||
|
return PROVIDERS_BY_ID[provider_id]
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
import httpx
|
||||||
|
import questionary
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from pyra.config.manager import save_config
|
||||||
|
from pyra.config.schema import ProviderConfig, PyraConfig
|
||||||
|
from pyra.setup.providers import PROVIDERS, Provider, get_provider
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def run_setup() -> None:
|
||||||
|
console.print(Panel(
|
||||||
|
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
|
||||||
|
subtitle="Personal AI Assistant",
|
||||||
|
border_style="cyan",
|
||||||
|
))
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
provider = _choose_provider()
|
||||||
|
model = _choose_model(provider)
|
||||||
|
|
||||||
|
if provider.requires_key:
|
||||||
|
_collect_api_key(provider)
|
||||||
|
|
||||||
|
_test_connection(provider, model)
|
||||||
|
|
||||||
|
cfg = PyraConfig(
|
||||||
|
ai=ProviderConfig(
|
||||||
|
provider_id=provider.id,
|
||||||
|
model=model,
|
||||||
|
base_url=provider.base_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
save_config(cfg)
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(Panel(
|
||||||
|
f"[green]Setup complete![/green]\n\n"
|
||||||
|
f"Provider: [bold]{provider.display_name}[/bold]\n"
|
||||||
|
f"Model: [bold]{model}[/bold]\n\n"
|
||||||
|
"Run [bold cyan]pyra chat[/bold cyan] to start talking.",
|
||||||
|
border_style="green",
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def _choose_provider() -> Provider:
|
||||||
|
local = [p for p in PROVIDERS if p.group == "Local"]
|
||||||
|
cloud = [p for p in PROVIDERS if p.group == "Cloud"]
|
||||||
|
|
||||||
|
choices = (
|
||||||
|
[questionary.Choice("── Local ──────────────────", disabled=True)]
|
||||||
|
+ [questionary.Choice(p.display_name, value=p.id) for p in local]
|
||||||
|
+ [questionary.Choice("── Cloud ──────────────────", disabled=True)]
|
||||||
|
+ [questionary.Choice(p.display_name, value=p.id) for p in cloud]
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_id = questionary.select(
|
||||||
|
"Choose your AI provider:",
|
||||||
|
choices=choices,
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if provider_id is None:
|
||||||
|
raise SystemExit(0)
|
||||||
|
|
||||||
|
provider = get_provider(provider_id)
|
||||||
|
|
||||||
|
if provider.connectivity_check:
|
||||||
|
_check_local_server(provider)
|
||||||
|
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _check_local_server(provider: Provider) -> None:
|
||||||
|
console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ")
|
||||||
|
try:
|
||||||
|
resp = httpx.get(provider.connectivity_check, timeout=3.0)
|
||||||
|
resp.raise_for_status()
|
||||||
|
console.print("[green]✓[/green]")
|
||||||
|
except Exception:
|
||||||
|
console.print("[yellow]✗ (server not reachable)[/yellow]")
|
||||||
|
console.print(
|
||||||
|
f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n"
|
||||||
|
f" Make sure {provider.display_name} is running before using Pyra."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _choose_model(provider: Provider) -> str:
|
||||||
|
model = questionary.text(
|
||||||
|
"Model name:",
|
||||||
|
default=provider.default_model,
|
||||||
|
).ask()
|
||||||
|
if model is None:
|
||||||
|
raise SystemExit(0)
|
||||||
|
return model.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_api_key(provider: Provider) -> None:
|
||||||
|
from pyra.vault.writer import set_key
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"\n [dim]API key will be stored in the encrypted vault — never in config.yaml[/dim]"
|
||||||
|
)
|
||||||
|
key = questionary.password(f"Enter your {provider.display_name} API key:").ask()
|
||||||
|
if key is None:
|
||||||
|
raise SystemExit(0)
|
||||||
|
key = key.strip()
|
||||||
|
if not key:
|
||||||
|
console.print("[red]No key entered — skipping.[/red]")
|
||||||
|
return
|
||||||
|
set_key(provider.id, key)
|
||||||
|
console.print(" [green]✓ Key stored in vault[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
def _test_connection(provider: Provider, model: str) -> None:
|
||||||
|
from pyra.vault.reader import get_key
|
||||||
|
|
||||||
|
console.print("\n Running test call...", end=" ")
|
||||||
|
try:
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
api_key = get_key(provider.id) if provider.requires_key else "no-key"
|
||||||
|
kwargs: dict = {
|
||||||
|
"model": f"{provider.litellm_prefix}{model}",
|
||||||
|
"messages": [{"role": "user", "content": "Reply with exactly: OK"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
}
|
||||||
|
if provider.base_url:
|
||||||
|
kwargs["api_base"] = provider.base_url
|
||||||
|
if api_key and api_key != "no-key":
|
||||||
|
kwargs["api_key"] = api_key
|
||||||
|
|
||||||
|
litellm.completion(**kwargs)
|
||||||
|
console.print("[green]✓ Connection OK[/green]")
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]")
|
||||||
|
console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]")
|
||||||
Reference in New Issue
Block a user