"""Interactive setup wizard for Pyra.

Collects a user profile, an AI provider, a model, an optional API key and runs
a live connection test. Progress is written to a draft file after each step so
an interrupted setup can be resumed on the next ``pyra setup``.
"""

import contextlib
import json

import httpx
import questionary
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.text import Text

from pyra.config.manager import save_config
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
from pyra.setup.providers import PROVIDERS, Provider, get_provider
from pyra.utils.paths import pyra_home, safe_chmod

console = Console()

# Use-case labels offered in the profile step, mapped to the plugins that are
# suggested for them once setup finishes.
_USE_CASE_PLUGINS: dict[str, list[str]] = {
    "Research & web": ["websearch", "headless_browser"],
    "Development & servers": ["server_manager", "ssh_tool", "docker_tool"],
    "File management": ["gdrive", "onedrive", "dropbox_tool"],
    "Communication bots": ["matrix_bot", "telegram_bot", "signal_bot"],
    "Email": ["email"],
    "Productivity & calendars": ["nextcloud"],
}

_DRAFT_FILE = "setup.draft.json"

# Wizard steps in execution order. Later answers depend on earlier ones (the
# model and the API key belong to the chosen provider), so re-running a step
# must invalidate every step after it.
_STEP_ORDER: tuple[str, ...] = ("profile", "provider", "model", "api_key", "connection")

# Draft keys that must be present for a completed step to be restored on resume.
_STEP_REQUIRED_KEYS: dict[str, tuple[str, ...]] = {
    "profile": ("user_name", "purpose", "use_cases"),
    "provider": ("provider_id",),
    "model": ("model",),
}

# Base URL of LM Studio's beta REST API, used to list and load models.
_LMSTUDIO_API = "http://localhost:1234/api/v0"

# Value of the "enter manually" entry appended to model pick lists.
_MANUAL = "__manual__"


# ── Draft persistence ─────────────────────────────────────────────────────────


def _draft_path():
    """Return the location of the resumable setup draft inside the Pyra home."""
    return pyra_home() / _DRAFT_FILE


def _save_draft(state: dict) -> None:
    """Write *state* to the draft file, then restrict it to mode 0o600."""
    path = _draft_path()
    # NOTE(review): it is not visible here whether pyra_home() creates the
    # directory; make sure it exists so the very first save cannot fail.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(state, indent=2), encoding="utf-8")
    safe_chmod(path, 0o600)


def _load_draft() -> dict | None:
    """Return the saved draft, or None if it is missing, unreadable or malformed."""
    path = _draft_path()
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
        return None
    # A truncated or hand-edited draft must not crash the wizard later on.
    if not isinstance(data, dict):
        return None
    if not isinstance(data.get("completed_steps", []), list):
        return None
    return data


def _delete_draft() -> None:
    """Remove the draft file if it exists."""
    with contextlib.suppress(FileNotFoundError):
        _draft_path().unlink()


def _mark_step_done(state: dict, step: str) -> None:
    """Record *step* as completed and forget every step that follows it.

    Steps normally complete in order, so nothing is dropped. When an earlier
    step is re-run (e.g. its draft data was missing), the later answers may no
    longer match and must be asked again.
    """
    completed = state.setdefault("completed_steps", [])
    if step in _STEP_ORDER:
        later = set(_STEP_ORDER[_STEP_ORDER.index(step) + 1:])
        completed[:] = [s for s in completed if s not in later]
    if step not in completed:
        completed.append(step)


def _is_step_done(state: dict, step: str) -> bool:
    """Return True if *step* is recorded as done and its saved data is present."""
    if step not in state.get("completed_steps", []):
        return False
    return all(key in state for key in _STEP_REQUIRED_KEYS.get(step, ()))


def _offer_resume(draft: dict) -> bool:
    """Show a summary of the incomplete setup and ask Resume / Start fresh.

    Returns True to resume and False to start fresh. Raises SystemExit(0) if
    the prompt is cancelled.
    """
    completed = draft.get("completed_steps", [])
    # Drafts written before "api_key_required" existed default to the old label.
    if draft.get("api_key_required", True):
        api_key_label = "API key: stored in vault"
    else:
        api_key_label = "API key: not required"
    step_display = {
        # Draft values are user-controlled: escape them before embedding in markup.
        "profile": f"Profile: {escape(str(draft.get('user_name', '?')))}",
        "provider": f"Provider: {escape(str(draft.get('provider_id', '?')))}",
        "model": f"Model: {escape(str(draft.get('model', '?')))}",
        "api_key": api_key_label,
        "connection": "Connection: test passed",
    }
    lines = ["[bold]An incomplete setup was found.[/bold]\n"]
    for step, label in step_display.items():
        if step in completed:
            lines.append(f" [green]✓[/green] {label}")
        else:
            lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]")
    console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow"))
    console.print()

    action = questionary.select(
        "What would you like to do?",
        choices=[
            questionary.Choice("Resume from where you left off", value="resume"),
            questionary.Choice("Start fresh", value="fresh"),
        ],
    ).ask()
    if action is None:
        raise SystemExit(0)
    return action == "resume"


# ── Entry point ───────────────────────────────────────────────────────────────


def run_setup() -> None:
    """Run the interactive setup wizard and save the resulting configuration.

    Each completed step is saved to the draft so a cancelled run (SystemExit)
    can be resumed. The draft is deleted once the config has been saved.
    """
    console.print(Panel(
        Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
        subtitle="Personal AI Assistant",
        border_style="cyan",
    ))
    console.print()

    state: dict = {}
    draft = _load_draft()
    if draft:
        if _offer_resume(draft):
            state = draft
        else:
            _delete_draft()

    try:
        # ── Step 1: profile ────────────────────────────────────────────────
        if _is_step_done(state, "profile"):
            user_name = state["user_name"]
            purpose = state["purpose"]
            use_cases = state["use_cases"]
            console.print(f" [dim]✓ Profile: {escape(str(user_name))}[/dim]")
        else:
            user_name, purpose, use_cases = _collect_user_profile()
            state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
            _mark_step_done(state, "profile")
            _save_draft(state)

        # ── Step 2: provider ───────────────────────────────────────────────
        if _is_step_done(state, "provider"):
            provider = get_provider(state["provider_id"])
            console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
        else:
            provider = _choose_provider()
            state.update(provider_id=provider.id)
            _mark_step_done(state, "provider")
            _save_draft(state)

        # ── Step 3: model ──────────────────────────────────────────────────
        if _is_step_done(state, "model"):
            model = state["model"]
            console.print(f" [dim]✓ Model: {escape(str(model))}[/dim]")
        else:
            model = _choose_model(provider)
            state.update(model=model)
            _mark_step_done(state, "model")
            _save_draft(state)

        # ── Step 4: API key ────────────────────────────────────────────────
        if not _is_step_done(state, "api_key"):
            if provider.requires_key:
                from pyra.vault.reader import get_key as _get_key

                if not _get_key(provider.id):
                    _collect_api_key(provider)
            # Marked for keyless providers too, so the resume summary can show
            # "not required" instead of a step that stays pending forever.
            state["api_key_required"] = bool(provider.requires_key)
            _mark_step_done(state, "api_key")
            _save_draft(state)

        # ── Step 5: connection test ────────────────────────────────────────
        if not _is_step_done(state, "connection"):
            # The user may switch model while fixing a failed test.
            model = _test_connection(provider, model)
            state["model"] = model
            _mark_step_done(state, "connection")
            _save_draft(state)

        # ── Finalise ───────────────────────────────────────────────────────
        cfg = PyraConfig(
            ai=ProviderConfig(
                provider_id=provider.id,
                model=model,
                base_url=provider.base_url,
            ),
            general=GeneralConfig(user_name=user_name, purpose=purpose),
        )
        save_config(cfg)
        _delete_draft()

        _suggest_plugins(use_cases)

        console.print()
        console.print(Panel(
            f"[green]Setup complete![/green]\n\n"
            f"Provider: [bold]{provider.display_name}[/bold]\n"
            f"Model: [bold]{escape(str(model))}[/bold]\n\n"
            "Run [bold cyan]pyra chat[/bold cyan] to start talking.",
            border_style="green",
        ))
    except SystemExit:
        # Every cancel path raises SystemExit; the draft is kept for resuming.
        if state.get("completed_steps"):
            console.print()
            console.print(
                " [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
            )
        raise


# ── Step 1: profile ───────────────────────────────────────────────────────────


def _collect_user_profile() -> tuple[str, str, list[str]]:
    """Ask for the user's name, purpose and areas of interest.

    Returns (name, purpose, use_cases). Raises SystemExit(0) on cancel.
    """
    console.print("[bold]Let's personalise your setup.[/bold]")
    console.print()

    name = questionary.text("What should Pyra call you?", default="User").ask()
    if name is None:
        raise SystemExit(0)
    name = name.strip() or "User"

    purpose = questionary.text(
        "In one sentence, what will you mainly use Pyra for? (optional)",
    ).ask()
    if purpose is None:
        raise SystemExit(0)
    purpose = purpose.strip()

    use_cases = questionary.checkbox(
        "Which areas interest you? (Space to select, Enter to confirm)",
        choices=list(_USE_CASE_PLUGINS.keys()),
    ).ask()
    if use_cases is None:
        raise SystemExit(0)

    console.print()
    return name, purpose, use_cases or []


def _suggest_plugins(use_cases: list[str]) -> None:
    """Print install commands for the plugins matching the chosen use cases."""
    if not use_cases:
        return

    lines: list[str] = []
    for uc in use_cases:
        plugins = _USE_CASE_PLUGINS.get(uc, [])
        if plugins:
            lines.append(f"[bold]{uc}[/bold]")
            lines.extend(f" pyra plugin install {p}" for p in plugins)

    if not lines:
        return

    lines.append("")
    lines.append("[dim]All listed plugins are in development — install when available.[/dim]")

    console.print()
    console.print(Panel(
        "\n".join(lines),
        title="Suggested plugins",
        border_style="dim cyan",
    ))


# ── Step 2: provider ──────────────────────────────────────────────────────────


def _choose_provider() -> Provider:
    """Let the user pick a provider, grouped into Local and Cloud.

    For providers with a connectivity check, the local server is probed and
    its loaded models are shown. Raises SystemExit(0) on cancel or abort.
    """
    choices: list = []
    for group in ("Local", "Cloud"):
        members = [p for p in PROVIDERS if p.group == group]
        if not members:
            continue  # don't show a header for an empty group
        # Separator is questionary's non-selectable heading entry.
        choices.append(questionary.Separator(f"── {group} ──────────────────"))
        choices.extend(questionary.Choice(p.display_name, value=p.id) for p in members)

    provider_id = questionary.select(
        "Choose your AI provider:",
        choices=choices,
    ).ask()
    if provider_id is None:
        raise SystemExit(0)

    provider = get_provider(provider_id)

    if provider.connectivity_check:
        _check_local_server(provider)
        _show_local_model_status(provider)

    return provider


def _classify_error(exc: Exception) -> tuple[str, str]:
    """Return (short_label, resolution_hint) for a provider or network error.

    Classification is by exception class name and defining module, so the
    litellm/openai/httpx packages need not be imported here.
    """
    name = type(exc).__name__
    module = type(exc).__module__ or ""
    is_llm = "litellm" in module or "openai" in module

    if is_llm:
        if "AuthenticationError" in name:
            return (
                "Invalid API key",
                "The provider rejected your API key.\n"
                "Double-check it on the provider dashboard and re-enter it.",
            )
        if "NotFoundError" in name:
            return (
                "Model not found",
                "The model name doesn't exist for this provider.\n"
                "Check the exact model identifier on the provider's model list.",
            )
        if "RateLimitError" in name:
            return (
                "Rate limit reached",
                "You've exceeded the provider's request rate.\n"
                "Wait a few seconds and retry.",
            )
        if "ServiceUnavailable" in name or "InternalServerError" in name:
            return (
                "Service temporarily unavailable",
                "The provider's servers returned a 5xx error.\n"
                "This is usually transient — wait a minute and retry.",
            )
        if "APIConnectionError" in name or "ConnectError" in name:
            return (
                "Cannot reach provider",
                "A network error prevented the connection.\n"
                "Check your internet connection and firewall settings.",
            )
        if "Timeout" in name:
            return (
                "Request timed out",
                "The provider did not respond in time.\n"
                "The service may be overloaded — retry in a moment.",
            )
        if "BadRequestError" in name or "InvalidRequest" in name:
            return (
                "Bad request",
                "The request was rejected — the model name or parameters may be wrong.\n"
                "Verify the exact model identifier.",
            )
        return ("Provider error", str(exc)[:300])

    if "httpx" in module:
        if "ConnectError" in name or "ConnectTimeout" in name:
            return (
                "Server not reachable",
                "Could not connect to the local server.\n"
                "Make sure it is running and listening on the expected address.",
            )
        if "Timeout" in name:
            return (
                "Connection timed out",
                "The local server did not respond in time.\n"
                "It may still be starting up — wait a moment and retry.",
            )
        if "HTTPStatusError" in name:
            code = getattr(getattr(exc, "response", None), "status_code", 0)
            if code == 401:
                return ("Unauthorized (401)", "The server requires credentials that were not provided.")
            if code == 404:
                return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.")
            return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).")

    return ("Unexpected error", str(exc)[:300])


def _check_local_server(provider: Provider) -> None:
    """Probe the provider's connectivity URL until it answers or the user moves on.

    On failure the user can retry, continue anyway, or abort. Abort and cancel
    raise SystemExit(0).
    """
    while True:
        console.print(
            f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" "
        )
        try:
            resp = httpx.get(provider.connectivity_check, timeout=3.0)
            resp.raise_for_status()
        except Exception as exc:  # every failure is reported, never fatal
            label, hint = _classify_error(exc)
            console.print("[yellow]✗[/yellow]")
            console.print()
            console.print(Panel(
                f"[bold yellow]{label}[/bold yellow]\n\n{hint}",
                title="Connection problem",
                border_style="yellow",
            ))
            action = questionary.select(
                "How would you like to proceed?",
                choices=[
                    questionary.Choice("Retry", value="retry"),
                    questionary.Choice(
                        "Continue anyway (model list may be unavailable)", value="continue"
                    ),
                    questionary.Choice("Abort setup", value="abort"),
                ],
            ).ask()
            if action is None or action == "abort":
                raise SystemExit(0)
            if action == "continue":
                return
            # "retry" → loop
        else:
            console.print("[green]✓[/green]")
            return


def _show_local_model_status(provider: Provider) -> None:
    """Print a one-line status showing which models are currently loaded."""
    models = [escape(m) for m in _fetch_local_models(provider)]
    if not models:
        console.print(" [yellow]No model currently loaded[/yellow]")
    elif len(models) == 1:
        console.print(f" [green]Loaded model:[/green] {models[0]}")
    else:
        names = ", ".join(models)
        console.print(f" [green]{len(models)} models loaded:[/green] {names}")


def _fetch_local_models(provider: Provider) -> list[str]:
    """Return currently loaded/available models from a local provider's API.

    Best effort: any network or response-shape problem yields an empty list.
    """
    if not provider.base_url:
        return []
    try:
        if provider.id == "ollama":
            resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0)
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
        resp.raise_for_status()
        return [m["id"] for m in resp.json().get("data", [])]
    except Exception:
        return []


def _fetch_lmstudio_available_models() -> list[str]:
    """Return all downloaded (not necessarily loaded) models from LM Studio's beta API.

    Best effort: returns an empty list on any failure.
    """
    try:
        resp = httpx.get(f"{_LMSTUDIO_API}/models", timeout=3.0)
        resp.raise_for_status()
        return [m["id"] for m in resp.json().get("data", [])]
    except Exception:
        return []


def _load_lmstudio_model(model_id: str) -> bool:
    """Attempt to load a model via LM Studio's beta API. Returns True on success."""
    try:
        resp = httpx.post(
            f"{_LMSTUDIO_API}/models/load",
            json={"identifier": model_id},
            timeout=60.0,  # loading weights can take a while
        )
        return resp.is_success
    except Exception:
        return False


# ── Step 3: model ─────────────────────────────────────────────────────────────


def _prompt_model_name(provider: Provider) -> str:
    """Ask for a model identifier as free text; empty input is rejected.

    Raises SystemExit(0) on cancel.
    """
    model = questionary.text(
        "Model name:",
        default=provider.default_model,
        validate=lambda text: bool(text.strip()) or "Model name cannot be empty",
    ).ask()
    if model is None:
        raise SystemExit(0)
    return model.strip()


def _select_with_manual(message: str, options: list[str]) -> str:
    """Let the user pick one of *options* or the manual-entry sentinel ``_MANUAL``.

    Raises SystemExit(0) on cancel.
    """
    choices = [*options, questionary.Choice("── Enter manually ──", value=_MANUAL)]
    selected = questionary.select(message, choices=choices).ask()
    if selected is None:
        raise SystemExit(0)
    return selected


def _choose_model(provider: Provider) -> str:
    """Return the model identifier to use with *provider*.

    Cloud providers get a free-text prompt. Local providers are offered their
    loaded models first. For LM Studio, downloaded models are offered (and a
    load attempted) when none is loaded. Every path can fall back to manual
    entry.
    """
    if provider.group != "Local":
        return _prompt_model_name(provider)

    loaded = _fetch_local_models(provider)
    if loaded:
        selected = _select_with_manual("Select model:", loaded)
        if selected != _MANUAL:
            return selected
    elif provider.id == "lmstudio":
        console.print(" [yellow]No model currently loaded in LM Studio.[/yellow]")
        available = _fetch_lmstudio_available_models()
        if available:
            selected = _select_with_manual("Select a downloaded model to load:", available)
            if selected != _MANUAL:
                console.print(f" Loading [bold]{escape(selected)}[/bold]...", end=" ")
                if _load_lmstudio_model(selected):
                    console.print("[green]✓ Loaded[/green]")
                else:
                    console.print(
                        "[yellow]Could not load via API — "
                        "please load the model manually in LM Studio.[/yellow]"
                    )
                # Kept even if loading failed: the user may load it by hand.
                return selected
        else:
            console.print(Panel(
                "No models are loaded or downloaded in LM Studio.\n"
                "Open LM Studio → Local Server tab → load a model, then re-run setup.",
                border_style="yellow",
            ))
    else:
        console.print(f" [yellow]No models found at {provider.base_url}.[/yellow]")

    return _prompt_model_name(provider)


# ── Step 4: API key ───────────────────────────────────────────────────────────


def _collect_api_key(provider: Provider) -> None:
    """Prompt for the provider's API key and store it in the vault.

    An empty entry stores nothing. Raises SystemExit(0) on cancel.
    """
    from pyra.vault.writer import set_key

    console.print(
        f"\n [dim]API key will be stored in the encrypted vault — never in config.yaml[/dim]"
    )
    key = questionary.password(f"Enter your {provider.display_name} API key:").ask()
    if key is None:
        raise SystemExit(0)
    key = key.strip()
    if not key:
        console.print("[red]No key entered — skipping.[/red]")
        return
    set_key(provider.id, key)
    console.print(" [green]✓ Key stored in vault[/green]")


# ── Step 5: connection test ───────────────────────────────────────────────────


def _ask_after_test_failure(provider: Provider, exc: Exception) -> str:
    """Report a failed test call and return the user's chosen next action.

    Returns one of "retry", "change_model", "rekey" or "skip". "change_model"
    is offered only for model-related errors, and "rekey" only for auth errors
    on providers that require a key. Abort or cancel raises SystemExit(0).
    """
    label, hint = _classify_error(exc)
    console.print("[red]✗[/red]")
    console.print()
    console.print(Panel(
        f"[bold red]{label}[/bold red]\n\n{hint}",
        title="Test call failed",
        border_style="red",
    ))

    exc_name = type(exc).__name__
    is_auth_error = "AuthenticationError" in exc_name
    is_model_error = any(
        kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest")
    )

    choices = [questionary.Choice("Retry", value="retry")]
    if is_model_error:
        choices.append(questionary.Choice("Change model", value="change_model"))
    if provider.requires_key and is_auth_error:
        choices.append(questionary.Choice("Re-enter API key", value="rekey"))
    choices += [
        questionary.Choice("Skip test and continue setup", value="skip"),
        questionary.Choice("Abort setup", value="abort"),
    ]

    action = questionary.select(
        "How would you like to proceed?",
        choices=choices,
    ).ask()
    if action is None or action == "abort":
        raise SystemExit(0)
    return action


def _test_connection(provider: Provider, model: str) -> str:
    """Send a tiny completion through litellm and let the user fix failures.

    Returns the model that ends up configured, which differs from *model* if
    the user chose "Change model". Raises SystemExit(0) on abort or cancel.
    """
    from pyra.vault.reader import get_key

    while True:
        console.print("\n Running connection test...", end=" ")
        try:
            import litellm

            # litellm needs some api_key value even for local servers.
            api_key = get_key(provider.id) if provider.requires_key else "local"
            kwargs: dict = {
                "model": f"{provider.litellm_prefix}{model}",
                "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
                "max_tokens": 10,
                "api_key": api_key,
            }
            if provider.base_url:
                kwargs["api_base"] = provider.base_url
            litellm.completion(**kwargs)
        except Exception as exc:  # any failure is shown to the user, never fatal
            action = _ask_after_test_failure(provider, exc)
            if action == "skip":
                console.print(
                    " [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
                )
                return model
            if action == "change_model":
                model = _choose_model(provider)
            elif action == "rekey":
                _collect_api_key(provider)
            # loop → retry (with possibly new model or key)
        else:
            console.print("[green]✓ Connection OK[/green]")
            return model