From 019e8044a94f0f7fae5154e1edf468356ff3a012 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Tue, 19 May 2026 13:43:49 +0200 Subject: [PATCH] feat(setup): model re-entry, status indicator, and resumable setup wizard - _test_connection() now returns the (possibly changed) model name and offers a "Change model" option when the error is model-related - _show_local_model_status() prints which models are currently loaded immediately after selecting a local provider - Draft persistence: each completed wizard step is saved to ~/.pyra/setup.draft.json (chmod 600); on the next run a yellow panel summarises progress and offers [Resume / Start fresh]; draft is deleted on successful completion or Ctrl-C with no completed steps Co-Authored-By: Claude Sonnet 4.6 --- src/pyra/setup/wizard.py | 208 +++++++++++++++++++++++++++++++++------ 1 file changed, 178 insertions(+), 30 deletions(-) diff --git a/src/pyra/setup/wizard.py b/src/pyra/setup/wizard.py index a41f239..5c2c4e1 100644 --- a/src/pyra/setup/wizard.py +++ b/src/pyra/setup/wizard.py @@ -1,3 +1,6 @@ +import contextlib +import json + import httpx import questionary from rich.console import Console @@ -7,6 +10,7 @@ from rich.text import Text from pyra.config.manager import save_config from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig from pyra.setup.providers import PROVIDERS, Provider, get_provider +from pyra.utils.paths import pyra_home, safe_chmod console = Console() @@ -20,6 +24,72 @@ _USE_CASE_PLUGINS: dict[str, list[str]] = { } +_DRAFT_FILE = "setup.draft.json" + + +def _draft_path(): + return pyra_home() / _DRAFT_FILE + + +def _save_draft(state: dict) -> None: + path = _draft_path() + path.write_text(json.dumps(state, indent=2)) + safe_chmod(path, 0o600) + + +def _load_draft() -> dict | None: + path = _draft_path() + if not path.exists(): + return None + try: + return json.loads(path.read_text()) + except Exception: + return None + + +def _delete_draft() -> None: + with contextlib.suppress(FileNotFoundError): + _draft_path().unlink() + + +def _mark_step_done(state: dict, step: str) -> None: + state.setdefault("completed_steps", []) + if step not in state["completed_steps"]: + state["completed_steps"].append(step) + + +def _offer_resume(draft: dict) -> bool: + """Show a summary of the incomplete setup and ask Resume / Start fresh.""" + completed = draft.get("completed_steps", []) + step_display = { + "profile": f"Profile: {draft.get('user_name', '?')}", + "provider": f"Provider: {draft.get('provider_id', '?')}", + "model": f"Model: {draft.get('model', '?')}", + "api_key": "API key: stored in vault", + "connection": "Connection: test passed", + } + lines = ["[bold]An incomplete setup was found.[/bold]\n"] + for step, label in step_display.items(): + if step in completed: + lines.append(f" [green]✓[/green] {label}") + else: + lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]") + + console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow")) + console.print() + + action = questionary.select( + "What would you like to do?", + choices=[ + questionary.Choice("Resume from where you left off", value="resume"), + questionary.Choice("Start fresh", value="fresh"), + ], + ).ask() + if action is None: + raise SystemExit(0) + return action == "resume" + + def run_setup() -> None: console.print(Panel( Text("Welcome to Pyra Setup", justify="center", style="bold cyan"), @@ -28,36 +98,92 @@ def run_setup() -> None: )) console.print() - user_name, purpose, use_cases = _collect_user_profile() + state: dict = {} + draft = _load_draft() + if draft: + if _offer_resume(draft): + state = draft + else: + _delete_draft() - provider = _choose_provider() - model = _choose_model(provider) + try: + # ── Step 1: profile ──────────────────────────────────────────────── + if "profile" in state.get("completed_steps", []): + user_name = state["user_name"] + purpose = state["purpose"] + use_cases = state["use_cases"] + console.print(f" [dim]✓ Profile: {user_name}[/dim]") + else: + user_name, purpose, use_cases = _collect_user_profile() + state.update(user_name=user_name, purpose=purpose, use_cases=use_cases) + _mark_step_done(state, "profile") + _save_draft(state) - if provider.requires_key: - _collect_api_key(provider) + # ── Step 2: provider ─────────────────────────────────────────────── + if "provider" in state.get("completed_steps", []): + provider = get_provider(state["provider_id"]) + console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]") + else: + provider = _choose_provider() + state.update(provider_id=provider.id) + _mark_step_done(state, "provider") + _save_draft(state) - _test_connection(provider, model) + # ── Step 3: model ────────────────────────────────────────────────── + if "model" in state.get("completed_steps", []): + model = state["model"] + console.print(f" [dim]✓ Model: {model}[/dim]") + else: + model = _choose_model(provider) + state.update(model=model) + _mark_step_done(state, "model") + _save_draft(state) - cfg = PyraConfig( - ai=ProviderConfig( - provider_id=provider.id, - model=model, - base_url=provider.base_url, - ), - general=GeneralConfig(user_name=user_name, purpose=purpose), - ) - save_config(cfg) + # ── Step 4: API key ──────────────────────────────────────────────── + if "api_key" not in state.get("completed_steps", []) and provider.requires_key: + from pyra.vault.reader import get_key as _get_key + if not _get_key(provider.id): + _collect_api_key(provider) + _mark_step_done(state, "api_key") + _save_draft(state) - _suggest_plugins(use_cases) + # ── Step 5: connection test ──────────────────────────────────────── + if "connection" not in state.get("completed_steps", []): + model = _test_connection(provider, model) + state["model"] = model + _mark_step_done(state, "connection") + _save_draft(state) - console.print() - console.print(Panel( - f"[green]Setup complete![/green]\n\n" - f"Provider: [bold]{provider.display_name}[/bold]\n" - f"Model: [bold]{model}[/bold]\n\n" - "Run [bold cyan]pyra chat[/bold cyan] to start talking.", - border_style="green", - )) + # ── Finalise ─────────────────────────────────────────────────────── + cfg = PyraConfig( + ai=ProviderConfig( + provider_id=provider.id, + model=model, + base_url=provider.base_url, + ), + general=GeneralConfig(user_name=user_name, purpose=purpose), + ) + save_config(cfg) + _delete_draft() + + _suggest_plugins(use_cases) + + console.print() + console.print(Panel( + f"[green]Setup complete![/green]\n\n" + f"Provider: [bold]{provider.display_name}[/bold]\n" + f"Model: [bold]{model}[/bold]\n\n" + "Run [bold cyan]pyra chat[/bold cyan] to start talking.", + border_style="green", + )) + + except SystemExit: + if state.get("completed_steps"): + console.print() + console.print( + " [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]" + ) + raise def _collect_user_profile() -> tuple[str, str, list[str]]: @@ -135,6 +261,7 @@ def _choose_provider() -> Provider: if provider.connectivity_check: _check_local_server(provider) + _show_local_model_status(provider) return provider @@ -250,6 +377,18 @@ def _check_local_server(provider: Provider) -> None: # "retry" → loop +def _show_local_model_status(provider: Provider) -> None: + """Print a one-line status showing which models are currently loaded.""" + models = _fetch_local_models(provider) + if not models: + console.print(" [yellow]No model currently loaded[/yellow]") + elif len(models) == 1: + console.print(f" [green]Loaded model:[/green] {models[0]}") + else: + names = ", ".join(models) + console.print(f" [green]{len(models)} models loaded:[/green] {names}") + + def _fetch_local_models(provider: Provider) -> list[str]: """Return currently loaded/available models from a local provider's API.""" if not provider.base_url: @@ -361,7 +500,7 @@ def _collect_api_key(provider: Provider) -> None: console.print(" [green]✓ Key stored in vault[/green]") -def _test_connection(provider: Provider, model: str) -> None: +def _test_connection(provider: Provider, model: str) -> str: from pyra.vault.reader import get_key while True: @@ -381,7 +520,7 @@ def _test_connection(provider: Provider, model: str) -> None: litellm.completion(**kwargs) console.print("[green]✓ Connection OK[/green]") - return + return model except Exception as exc: label, hint = _classify_error(exc) @@ -393,8 +532,15 @@ def _test_connection(provider: Provider, model: str) -> None: border_style="red", )) - is_auth_error = "AuthenticationError" in type(exc).__name__ + exc_name = type(exc).__name__ + is_auth_error = "AuthenticationError" in exc_name + is_model_error = any( + kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest") + ) + choices = [questionary.Choice("Retry", value="retry")] + if is_model_error: + choices.append(questionary.Choice("Change model", value="change_model")) if provider.requires_key and is_auth_error: choices.append(questionary.Choice("Re-enter API key", value="rekey")) choices += [ @@ -413,7 +559,9 @@ def _test_connection(provider: Provider, model: str) -> None: console.print( " [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]" ) - return - if action == "rekey": + return model + if action == "change_model": + model = _choose_model(provider) + elif action == "rekey": _collect_api_key(provider) - # "retry" or after "rekey" → loop + # loop → retry (with possibly new model or key)