019e8044a9
- _test_connection() now returns the (possibly changed) model name and offers a "Change model" option when the error is model-related - _show_local_model_status() prints which models are currently loaded immediately after selecting a local provider - Draft persistence: each completed wizard step is saved to ~/.pyra/setup.draft.json (chmod 600); on the next run a yellow panel summarises progress and offers [Resume / Start fresh]; draft is deleted on successful completion or Ctrl-C with no completed steps Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
568 lines
20 KiB
Python
568 lines
20 KiB
Python
import contextlib
|
|
import json
|
|
|
|
import httpx
|
|
import questionary
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.text import Text
|
|
|
|
from pyra.config.manager import save_config
|
|
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
|
|
from pyra.setup.providers import PROVIDERS, Provider, get_provider
|
|
from pyra.utils.paths import pyra_home, safe_chmod
|
|
|
|
# Shared Rich console used for all output printed by this module.
console = Console()

# Maps each use-case label offered in the profile step to the plugin names
# that are suggested for it at the end of setup.
_USE_CASE_PLUGINS: dict[str, list[str]] = {
    "Research & web": ["websearch", "headless_browser"],
    "Development & servers": ["server_manager", "ssh_tool", "docker_tool"],
    "File management": ["gdrive", "onedrive", "dropbox_tool"],
    "Communication bots": ["matrix_bot", "telegram_bot", "signal_bot"],
    "Email": ["email"],
    "Productivity & calendars": ["nextcloud"],
}


# File name of the in-progress wizard state; resolved relative to pyra_home().
_DRAFT_FILE = "setup.draft.json"
|
|
|
|
|
|
def _draft_path():
    """Return the location of the wizard draft file inside the Pyra home directory."""
    home = pyra_home()
    return home / _DRAFT_FILE
|
|
|
|
|
|
def _save_draft(state: dict) -> None:
    """Write the wizard state to the draft file as indented JSON, then chmod it to 0o600."""
    target = _draft_path()
    serialized = json.dumps(state, indent=2)
    target.write_text(serialized)
    safe_chmod(target, 0o600)
|
|
|
|
|
|
def _load_draft() -> dict | None:
    """Load the saved wizard draft, if a usable one exists.

    Returns:
        The draft state dictionary, or ``None`` when there is no draft file,
        the file cannot be read or parsed, or its top-level JSON value is not
        an object.
    """
    path = _draft_path()
    try:
        data = json.loads(path.read_text())
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; ValueError covers both
        # invalid JSON (JSONDecodeError) and undecodable bytes.
        return None
    # Callers use dict operations (.get, key lookup) on the draft, so any
    # other JSON value (list, string, number, ...) is treated as "no draft".
    return data if isinstance(data, dict) else None
|
|
|
|
|
|
def _delete_draft() -> None:
    """Remove the draft file; a draft that is already gone is not an error."""
    try:
        _draft_path().unlink()
    except FileNotFoundError:
        pass
|
|
|
|
|
|
def _mark_step_done(state: dict, step: str) -> None:
|
|
state.setdefault("completed_steps", [])
|
|
if step not in state["completed_steps"]:
|
|
state["completed_steps"].append(step)
|
|
|
|
|
|
def _offer_resume(draft: dict) -> bool:
    """Summarise a saved draft and ask whether to resume it.

    Prints a yellow panel listing every wizard step as done or pending, then
    prompts for "Resume" or "Start fresh".

    Returns:
        True when the user chooses to resume, False for a fresh start.

    Raises:
        SystemExit: with code 0 when the prompt is cancelled.
    """
    finished = draft.get("completed_steps", [])
    labels = {
        "profile": f"Profile: {draft.get('user_name', '?')}",
        "provider": f"Provider: {draft.get('provider_id', '?')}",
        "model": f"Model: {draft.get('model', '?')}",
        "api_key": "API key: stored in vault",
        "connection": "Connection: test passed",
    }

    summary = ["[bold]An incomplete setup was found.[/bold]\n"]
    # Pending steps show only the heading (text before the first colon).
    summary.extend(
        f" [green]✓[/green] {text}"
        if key in finished
        else f" [dim]○ {text.split(':')[0].strip()}: pending[/dim]"
        for key, text in labels.items()
    )

    console.print(Panel("\n".join(summary), title="Incomplete setup", border_style="yellow"))
    console.print()

    options = [
        questionary.Choice("Resume from where you left off", value="resume"),
        questionary.Choice("Start fresh", value="fresh"),
    ]
    decision = questionary.select("What would you like to do?", choices=options).ask()
    if decision is None:
        raise SystemExit(0)
    return decision == "resume"
|
|
|
|
|
|
def run_setup() -> None:
    """Run the interactive setup wizard.

    Steps, in order: user profile, provider, model, API key (only for
    providers that require one) and a connection test; afterwards the
    configuration is saved and plugin suggestions are shown.  After every
    completed step the wizard state is written to the draft file so an
    interrupted run can be resumed; the draft is deleted once the
    configuration has been saved.

    Raises:
        SystemExit: re-raised when a prompt is cancelled or the user aborts.
    """
    console.print(Panel(
        Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
        subtitle="Personal AI Assistant",
        border_style="cyan",
    ))
    console.print()

    state: dict = {}
    draft = _load_draft()
    if draft:
        if _offer_resume(draft):
            state = draft
        else:
            _delete_draft()

    def _is_done(step: str) -> bool:
        # "or []" tolerates a draft whose completed_steps is null.
        return step in (state.get("completed_steps") or [])

    try:
        # ── Step 1: profile ────────────────────────────────────────────────
        if _is_done("profile"):
            user_name = state["user_name"]
            purpose = state["purpose"]
            use_cases = state["use_cases"]
            console.print(f" [dim]✓ Profile: {user_name}[/dim]")
        else:
            user_name, purpose, use_cases = _collect_user_profile()
            state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
            _mark_step_done(state, "profile")
            _save_draft(state)

        # ── Step 2: provider ───────────────────────────────────────────────
        if _is_done("provider"):
            provider = get_provider(state["provider_id"])
            console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
        else:
            provider = _choose_provider()
            state.update(provider_id=provider.id)
            _mark_step_done(state, "provider")
            _save_draft(state)

        # ── Step 3: model ──────────────────────────────────────────────────
        if _is_done("model"):
            model = state["model"]
            console.print(f" [dim]✓ Model: {model}[/dim]")
        else:
            model = _choose_model(provider)
            state.update(model=model)
            _mark_step_done(state, "model")
            _save_draft(state)

        # ── Step 4: API key ────────────────────────────────────────────────
        if not _is_done("api_key") and provider.requires_key:
            from pyra.vault.reader import get_key as _get_key

            key_present = bool(_get_key(provider.id))
            if not key_present:
                _collect_api_key(provider)
                # _collect_api_key stores nothing when the user submits a
                # blank key, so re-check the vault before recording the step.
                key_present = bool(_get_key(provider.id))
            if key_present:
                _mark_step_done(state, "api_key")
                _save_draft(state)

        # ── Step 5: connection test ────────────────────────────────────────
        if not _is_done("connection"):
            # The test may return a different model if the user changed it.
            model = _test_connection(provider, model)
            state["model"] = model
            _mark_step_done(state, "connection")
            _save_draft(state)

        # ── Finalise ───────────────────────────────────────────────────────
        cfg = PyraConfig(
            ai=ProviderConfig(
                provider_id=provider.id,
                model=model,
                base_url=provider.base_url,
            ),
            general=GeneralConfig(user_name=user_name, purpose=purpose),
        )
        save_config(cfg)
        _delete_draft()

        _suggest_plugins(use_cases)

        console.print()
        console.print(Panel(
            f"[green]Setup complete![/green]\n\n"
            f"Provider: [bold]{provider.display_name}[/bold]\n"
            f"Model: [bold]{model}[/bold]\n\n"
            "Run [bold cyan]pyra chat[/bold cyan] to start talking.",
            border_style="green",
        ))

    except SystemExit:
        if state.get("completed_steps"):
            console.print()
            console.print(
                " [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
            )
        else:
            # Nothing worth resuming: do not leave a stale draft behind.
            _delete_draft()
        raise
|
|
|
|
|
|
def _collect_user_profile() -> tuple[str, str, list[str]]:
    """Prompt for the user's name, main purpose and areas of interest.

    Returns:
        ``(name, purpose, use_cases)``; a blank name becomes "User" and an
        empty selection becomes an empty list.

    Raises:
        SystemExit: with code 0 as soon as any prompt is cancelled.
    """

    def _answer(question):
        # questionary returns None when the prompt is cancelled.
        reply = question.ask()
        if reply is None:
            raise SystemExit(0)
        return reply

    console.print("[bold]Let's personalise your setup.[/bold]")
    console.print()

    typed_name = _answer(questionary.text("What should Pyra call you?", default="User"))
    display_name = typed_name.strip() or "User"

    typed_purpose = _answer(
        questionary.text(
            "In one sentence, what will you mainly use Pyra for? (optional)",
        )
    )
    goal = typed_purpose.strip()

    interests = _answer(
        questionary.checkbox(
            "Which areas interest you? (Space to select, Enter to confirm)",
            choices=list(_USE_CASE_PLUGINS),
        )
    )

    console.print()
    return display_name, goal, interests or []
|
|
|
|
|
|
def _suggest_plugins(use_cases: list[str]) -> None:
    """Print a panel of install commands for plugins matching the chosen use cases.

    Nothing is printed when no use case was chosen or none of them maps to
    any plugin in ``_USE_CASE_PLUGINS``.
    """
    if not use_cases:
        return

    body: list[str] = []
    for area in use_cases:
        plugin_names = _USE_CASE_PLUGINS.get(area, [])
        if not plugin_names:
            continue
        body.append(f"[bold]{area}[/bold]")
        body.extend(f" pyra plugin install {plugin}" for plugin in plugin_names)

    if not body:
        return

    body += [
        "",
        "[dim]All listed plugins are in development — install when available.[/dim]",
    ]

    console.print()
    console.print(Panel(
        "\n".join(body),
        title="Suggested plugins",
        border_style="dim cyan",
    ))
|
|
|
|
|
|
def _choose_provider() -> Provider:
    """Let the user pick a provider from a menu grouped into Local and Cloud.

    For a provider that defines ``connectivity_check`` the server is probed
    and the currently loaded models are printed before returning.

    Raises:
        SystemExit: with code 0 when the selection is cancelled.
    """

    def _entries(group: str) -> list:
        return [
            questionary.Choice(p.display_name, value=p.id)
            for p in PROVIDERS
            if p.group == group
        ]

    menu = [
        questionary.Choice("── Local ──────────────────", disabled=True),
        *_entries("Local"),
        questionary.Choice("── Cloud ──────────────────", disabled=True),
        *_entries("Cloud"),
    ]

    picked_id = questionary.select("Choose your AI provider:", choices=menu).ask()
    if picked_id is None:
        raise SystemExit(0)

    chosen = get_provider(picked_id)
    if not chosen.connectivity_check:
        return chosen

    _check_local_server(chosen)
    _show_local_model_status(chosen)
    return chosen
|
|
|
|
|
|
def _classify_error(exc: Exception) -> tuple[str, str]:
|
|
"""Return (short_label, resolution_hint) for a provider or network error."""
|
|
name = type(exc).__name__
|
|
module = type(exc).__module__ or ""
|
|
is_llm = "litellm" in module or "openai" in module
|
|
|
|
if is_llm:
|
|
if "AuthenticationError" in name:
|
|
return (
|
|
"Invalid API key",
|
|
"The provider rejected your API key.\n"
|
|
"Double-check it on the provider dashboard and re-enter it.",
|
|
)
|
|
if "NotFoundError" in name:
|
|
return (
|
|
"Model not found",
|
|
"The model name doesn't exist for this provider.\n"
|
|
"Check the exact model identifier on the provider's model list.",
|
|
)
|
|
if "RateLimitError" in name:
|
|
return (
|
|
"Rate limit reached",
|
|
"You've exceeded the provider's request rate.\n"
|
|
"Wait a few seconds and retry.",
|
|
)
|
|
if "ServiceUnavailable" in name:
|
|
return (
|
|
"Service temporarily unavailable",
|
|
"The provider's servers returned a 5xx error.\n"
|
|
"This is usually transient — wait a minute and retry.",
|
|
)
|
|
if "APIConnectionError" in name or "ConnectError" in name:
|
|
return (
|
|
"Cannot reach provider",
|
|
"A network error prevented the connection.\n"
|
|
"Check your internet connection and firewall settings.",
|
|
)
|
|
if "Timeout" in name:
|
|
return (
|
|
"Request timed out",
|
|
"The provider did not respond in time.\n"
|
|
"The service may be overloaded — retry in a moment.",
|
|
)
|
|
if "BadRequestError" in name or "InvalidRequest" in name:
|
|
return (
|
|
"Bad request",
|
|
"The request was rejected — the model name or parameters may be wrong.\n"
|
|
"Verify the exact model identifier.",
|
|
)
|
|
return ("Provider error", str(exc)[:300])
|
|
|
|
if "httpx" in module:
|
|
if "ConnectError" in name or "ConnectTimeout" in name:
|
|
return (
|
|
"Server not reachable",
|
|
"Could not connect to the local server.\n"
|
|
"Make sure it is running and listening on the expected address.",
|
|
)
|
|
if "Timeout" in name:
|
|
return (
|
|
"Connection timed out",
|
|
"The local server did not respond in time.\n"
|
|
"It may still be starting up — wait a moment and retry.",
|
|
)
|
|
if "HTTPStatusError" in name:
|
|
code = getattr(getattr(exc, "response", None), "status_code", 0)
|
|
if code == 401:
|
|
return ("Unauthorized (401)", "The server requires credentials that were not provided.")
|
|
if code == 404:
|
|
return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.")
|
|
return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).")
|
|
|
|
return ("Unexpected error", str(exc)[:300])
|
|
|
|
|
|
def _check_local_server(provider: Provider) -> None:
    """Probe the provider's ``connectivity_check`` URL until it answers or the user gives up.

    Each attempt is a GET with a 3-second timeout.  On failure the classified
    error is shown and the user may retry, continue without a working
    connection, or abort.

    Raises:
        SystemExit: with code 0 when the user aborts or cancels the prompt.
    """
    while True:
        console.print(
            f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" "
        )
        try:
            httpx.get(provider.connectivity_check, timeout=3.0).raise_for_status()
            console.print("[green]✓[/green]")
            return
        except Exception as err:
            title, advice = _classify_error(err)
            console.print("[yellow]✗[/yellow]")
            console.print()
            console.print(Panel(
                f"[bold yellow]{title}[/bold yellow]\n\n{advice}",
                title="Connection problem",
                border_style="yellow",
            ))
            options = [
                questionary.Choice("Retry", value="retry"),
                questionary.Choice(
                    "Continue anyway (model list may be unavailable)", value="continue"
                ),
                questionary.Choice("Abort setup", value="abort"),
            ]
            decision = questionary.select(
                "How would you like to proceed?",
                choices=options,
            ).ask()
            if decision in (None, "abort"):
                raise SystemExit(0)
            if decision == "continue":
                return
            # Only "retry" is left: fall through to the next loop iteration.
|
|
|
|
|
|
def _show_local_model_status(provider: Provider) -> None:
    """Print one status line naming the models the local provider currently reports."""
    found = _fetch_local_models(provider)
    count = len(found)
    if count == 0:
        message = " [yellow]No model currently loaded[/yellow]"
    elif count == 1:
        message = f" [green]Loaded model:[/green] {found[0]}"
    else:
        message = f" [green]{count} models loaded:[/green] {', '.join(found)}"
    console.print(message)
|
|
|
|
|
|
def _fetch_local_models(provider: Provider) -> list[str]:
    """Query a local provider's HTTP API for its model names.

    Ollama is asked via ``/api/tags`` (names under ``models[].name``); every
    other provider via ``/models`` (ids under ``data[].id``).  Returns an
    empty list when the provider has no base URL or anything goes wrong.
    """
    if not provider.base_url:
        return []
    try:
        if provider.id == "ollama":
            endpoint, list_key, name_key = f"{provider.base_url}/api/tags", "models", "name"
        else:
            endpoint, list_key, name_key = f"{provider.base_url}/models", "data", "id"
        reply = httpx.get(endpoint, timeout=3.0)
        reply.raise_for_status()
        return [entry[name_key] for entry in reply.json().get(list_key, [])]
    except Exception:
        # Best-effort lookup: any failure simply means "no models known".
        return []
|
|
|
|
|
|
def _fetch_lmstudio_available_models() -> list[str]:
    """List the ids of all models LM Studio has downloaded (loaded or not) via its beta API.

    Returns an empty list on any error.
    """
    try:
        reply = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
        reply.raise_for_status()
        entries = reply.json().get("data", [])
        return [entry["id"] for entry in entries]
    except Exception:
        return []
|
|
|
|
|
|
def _load_lmstudio_model(model_id: str) -> bool:
    """Ask LM Studio's beta API to load *model_id*.

    Returns True when the request gets a successful response, False on a
    non-success status or any exception.
    """
    payload = {"identifier": model_id}
    try:
        return httpx.post(
            "http://localhost:1234/api/v0/models/load",
            json=payload,
            timeout=60.0,
        ).is_success
    except Exception:
        return False
|
|
|
|
|
|
def _prompt_model_name(provider: Provider) -> str:
    """Ask for a model name as free text, pre-filled with the provider's default."""
    entered = questionary.text("Model name:", default=provider.default_model).ask()
    if entered is None:
        raise SystemExit(0)
    # A blank answer falls back to the provider default rather than yielding
    # an empty model identifier.
    return entered.strip() or provider.default_model


def _choose_model(provider: Provider) -> str:
    """Determine which model to use for *provider*.

    Non-local providers get a free-text prompt.  Local providers offer the
    models reported by the server; for LM Studio with nothing loaded, the
    downloaded models are offered and the chosen one is loaded through the
    API.  Every menu has an "Enter manually" entry, and the free-text prompt
    is also the fallback when no models can be listed.

    Returns:
        The selected or typed model name.

    Raises:
        SystemExit: with code 0 when any prompt is cancelled.
    """
    if provider.group != "Local":
        return _prompt_model_name(provider)

    _MANUAL = "__manual__"
    loaded = _fetch_local_models(provider)

    if loaded:
        choices = loaded + [questionary.Choice("── Enter manually ──", value=_MANUAL)]
        selected = questionary.select("Select model:", choices=choices).ask()
        if selected is None:
            raise SystemExit(0)
        if selected != _MANUAL:
            return selected

    elif provider.id == "lmstudio":
        console.print(" [yellow]No model currently loaded in LM Studio.[/yellow]")
        available = _fetch_lmstudio_available_models()
        if available:
            choices = available + [questionary.Choice("── Enter manually ──", value=_MANUAL)]
            selected = questionary.select(
                "Select a downloaded model to load:", choices=choices
            ).ask()
            if selected is None:
                raise SystemExit(0)
            if selected != _MANUAL:
                console.print(f" Loading [bold]{selected}[/bold]...", end=" ")
                if _load_lmstudio_model(selected):
                    console.print("[green]✓ Loaded[/green]")
                else:
                    console.print(
                        "[yellow]Could not load via API — "
                        "please load the model manually in LM Studio.[/yellow]"
                    )
                # The selection is returned even if loading failed; the user
                # was told to load it manually.
                return selected
        else:
            console.print(Panel(
                "No models are loaded or downloaded in LM Studio.\n"
                "Open LM Studio → Local Server tab → load a model, then re-run setup.",
                border_style="yellow",
            ))

    else:
        console.print(f" [yellow]No models found at {provider.base_url}.[/yellow]")

    # Reached for "Enter manually" and whenever no model list was available.
    return _prompt_model_name(provider)
|
|
|
|
|
|
def _collect_api_key(provider: Provider) -> None:
    """Prompt for the provider's API key and store it in the vault.

    A blank entry stores nothing and only prints a notice.

    Raises:
        SystemExit: with code 0 when the prompt is cancelled.
    """
    from pyra.vault.writer import set_key

    console.print(
        "\n [dim]API key will be stored in the encrypted vault — never in config.yaml[/dim]"
    )
    entered = questionary.password(f"Enter your {provider.display_name} API key:").ask()
    if entered is None:
        raise SystemExit(0)

    secret = entered.strip()
    if secret:
        set_key(provider.id, secret)
        console.print(" [green]✓ Key stored in vault[/green]")
    else:
        console.print("[red]No key entered — skipping.[/red]")
|
|
|
|
|
|
def _test_connection(provider: Provider, model: str) -> str:
    """Send a tiny completion request to verify the provider/model/key combination.

    On failure the classified error is shown and the user may retry, skip the
    test, or abort; "Change model" is added for model-related errors and
    "Re-enter API key" for authentication errors on key-based providers.

    Returns:
        The model name in effect when the test passed or was skipped (it may
        differ from the argument if the user changed it).

    Raises:
        SystemExit: with code 0 when the user aborts or cancels the prompt.
    """
    from pyra.vault.reader import get_key

    while True:
        console.print("\n Running connection test...", end=" ")
        try:
            import litellm

            # Providers that need no key get the placeholder "local".
            credential = get_key(provider.id) if provider.requires_key else "local"
            request: dict = {
                "model": f"{provider.litellm_prefix}{model}",
                "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
                "max_tokens": 10,
                "api_key": credential,
            }
            if provider.base_url:
                request["api_base"] = provider.base_url

            litellm.completion(**request)
            console.print("[green]✓ Connection OK[/green]")
            return model

        except Exception as err:
            title, advice = _classify_error(err)
            console.print("[red]✗[/red]")
            console.print()
            console.print(Panel(
                f"[bold red]{title}[/bold red]\n\n{advice}",
                title="Test call failed",
                border_style="red",
            ))

            kind = type(err).__name__
            key_problem = "AuthenticationError" in kind
            model_problem = (
                "NotFoundError" in kind
                or "BadRequestError" in kind
                or "InvalidRequest" in kind
            )

            menu = [questionary.Choice("Retry", value="retry")]
            if model_problem:
                menu.append(questionary.Choice("Change model", value="change_model"))
            if provider.requires_key and key_problem:
                menu.append(questionary.Choice("Re-enter API key", value="rekey"))
            menu.append(questionary.Choice("Skip test and continue setup", value="skip"))
            menu.append(questionary.Choice("Abort setup", value="abort"))

            decision = questionary.select(
                "How would you like to proceed?",
                choices=menu,
            ).ask()

            if decision in (None, "abort"):
                raise SystemExit(0)
            if decision == "skip":
                console.print(
                    " [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
                )
                return model
            if decision == "change_model":
                model = _choose_model(provider)
            elif decision == "rekey":
                _collect_api_key(provider)
            # Fall through: the loop retries with the possibly updated model or key.
|