Compare commits

2 Commits

Author SHA1 Message Date
curo1305 b3851a2715 test: add tests for draft persistence, model status, and model re-entry
10 new tests covering:
- _save_draft / _load_draft / _delete_draft / _mark_step_done helpers
- draft file permissions (chmod 600)
- _show_local_model_status with zero, one, and multiple loaded models
- _test_connection change_model path (model error → change → retry succeeds)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:43:53 +02:00
curo1305 019e8044a9 feat(setup): model re-entry, status indicator, and resumable setup wizard
- _test_connection() now returns the (possibly changed) model name and
  offers a "Change model" option when the error is model-related
- _show_local_model_status() prints which models are currently loaded
  immediately after selecting a local provider
- Draft persistence: each completed wizard step is saved to
  ~/.pyra/setup.draft.json (chmod 600); on the next run a yellow panel
  summarises progress and offers [Resume / Start fresh]; draft is
  deleted on successful completion or Ctrl-C with no completed steps

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:43:49 +02:00
2 changed files with 291 additions and 30 deletions
+156 -8
View File
@@ -1,3 +1,6 @@
import contextlib
import json
import httpx
import questionary
from rich.console import Console
@@ -7,6 +10,7 @@ from rich.text import Text
from pyra.config.manager import save_config
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
from pyra.setup.providers import PROVIDERS, Provider, get_provider
from pyra.utils.paths import pyra_home, safe_chmod
console = Console()
@@ -20,6 +24,72 @@ _USE_CASE_PLUGINS: dict[str, list[str]] = {
}
_DRAFT_FILE = "setup.draft.json"
def _draft_path():
return pyra_home() / _DRAFT_FILE
def _save_draft(state: dict) -> None:
path = _draft_path()
path.write_text(json.dumps(state, indent=2))
safe_chmod(path, 0o600)
def _load_draft() -> dict | None:
path = _draft_path()
if not path.exists():
return None
try:
return json.loads(path.read_text())
except Exception:
return None
def _delete_draft() -> None:
with contextlib.suppress(FileNotFoundError):
_draft_path().unlink()
def _mark_step_done(state: dict, step: str) -> None:
state.setdefault("completed_steps", [])
if step not in state["completed_steps"]:
state["completed_steps"].append(step)
def _offer_resume(draft: dict) -> bool:
"""Show a summary of the incomplete setup and ask Resume / Start fresh."""
completed = draft.get("completed_steps", [])
step_display = {
"profile": f"Profile: {draft.get('user_name', '?')}",
"provider": f"Provider: {draft.get('provider_id', '?')}",
"model": f"Model: {draft.get('model', '?')}",
"api_key": "API key: stored in vault",
"connection": "Connection: test passed",
}
lines = ["[bold]An incomplete setup was found.[/bold]\n"]
for step, label in step_display.items():
if step in completed:
lines.append(f" [green]✓[/green] {label}")
else:
lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]")
console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow"))
console.print()
action = questionary.select(
"What would you like to do?",
choices=[
questionary.Choice("Resume from where you left off", value="resume"),
questionary.Choice("Start fresh", value="fresh"),
],
).ask()
if action is None:
raise SystemExit(0)
return action == "resume"
def run_setup() -> None:
console.print(Panel(
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
@@ -28,16 +98,63 @@ def run_setup() -> None:
))
console.print()
state: dict = {}
draft = _load_draft()
if draft:
if _offer_resume(draft):
state = draft
else:
_delete_draft()
try:
# ── Step 1: profile ────────────────────────────────────────────────
if "profile" in state.get("completed_steps", []):
user_name = state["user_name"]
purpose = state["purpose"]
use_cases = state["use_cases"]
console.print(f" [dim]✓ Profile: {user_name}[/dim]")
else:
user_name, purpose, use_cases = _collect_user_profile()
state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
_mark_step_done(state, "profile")
_save_draft(state)
# ── Step 2: provider ───────────────────────────────────────────────
if "provider" in state.get("completed_steps", []):
provider = get_provider(state["provider_id"])
console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
else:
provider = _choose_provider()
state.update(provider_id=provider.id)
_mark_step_done(state, "provider")
_save_draft(state)
# ── Step 3: model ──────────────────────────────────────────────────
if "model" in state.get("completed_steps", []):
model = state["model"]
console.print(f" [dim]✓ Model: {model}[/dim]")
else:
model = _choose_model(provider)
state.update(model=model)
_mark_step_done(state, "model")
_save_draft(state)
if provider.requires_key:
# ── Step 4: API key ────────────────────────────────────────────────
if "api_key" not in state.get("completed_steps", []) and provider.requires_key:
from pyra.vault.reader import get_key as _get_key
if not _get_key(provider.id):
_collect_api_key(provider)
_mark_step_done(state, "api_key")
_save_draft(state)
_test_connection(provider, model)
# ── Step 5: connection test ────────────────────────────────────────
if "connection" not in state.get("completed_steps", []):
model = _test_connection(provider, model)
state["model"] = model
_mark_step_done(state, "connection")
_save_draft(state)
# ── Finalise ───────────────────────────────────────────────────────
cfg = PyraConfig(
ai=ProviderConfig(
provider_id=provider.id,
@@ -47,6 +164,7 @@ def run_setup() -> None:
general=GeneralConfig(user_name=user_name, purpose=purpose),
)
save_config(cfg)
_delete_draft()
_suggest_plugins(use_cases)
@@ -59,6 +177,14 @@ def run_setup() -> None:
border_style="green",
))
except SystemExit:
if state.get("completed_steps"):
console.print()
console.print(
" [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
)
raise
def _collect_user_profile() -> tuple[str, str, list[str]]:
console.print("[bold]Let's personalise your setup.[/bold]")
@@ -135,6 +261,7 @@ def _choose_provider() -> Provider:
if provider.connectivity_check:
_check_local_server(provider)
_show_local_model_status(provider)
return provider
@@ -250,6 +377,18 @@ def _check_local_server(provider: Provider) -> None:
# "retry" → loop
def _show_local_model_status(provider: Provider) -> None:
"""Print a one-line status showing which models are currently loaded."""
models = _fetch_local_models(provider)
if not models:
console.print(" [yellow]No model currently loaded[/yellow]")
elif len(models) == 1:
console.print(f" [green]Loaded model:[/green] {models[0]}")
else:
names = ", ".join(models)
console.print(f" [green]{len(models)} models loaded:[/green] {names}")
def _fetch_local_models(provider: Provider) -> list[str]:
"""Return currently loaded/available models from a local provider's API."""
if not provider.base_url:
@@ -361,7 +500,7 @@ def _collect_api_key(provider: Provider) -> None:
console.print(" [green]✓ Key stored in vault[/green]")
def _test_connection(provider: Provider, model: str) -> None:
def _test_connection(provider: Provider, model: str) -> str:
from pyra.vault.reader import get_key
while True:
@@ -381,7 +520,7 @@ def _test_connection(provider: Provider, model: str) -> None:
litellm.completion(**kwargs)
console.print("[green]✓ Connection OK[/green]")
return
return model
except Exception as exc:
label, hint = _classify_error(exc)
@@ -393,8 +532,15 @@ def _test_connection(provider: Provider, model: str) -> None:
border_style="red",
))
is_auth_error = "AuthenticationError" in type(exc).__name__
exc_name = type(exc).__name__
is_auth_error = "AuthenticationError" in exc_name
is_model_error = any(
kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest")
)
choices = [questionary.Choice("Retry", value="retry")]
if is_model_error:
choices.append(questionary.Choice("Change model", value="change_model"))
if provider.requires_key and is_auth_error:
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
choices += [
@@ -413,7 +559,9 @@ def _test_connection(provider: Provider, model: str) -> None:
console.print(
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
)
return
if action == "rekey":
return model
if action == "change_model":
model = _choose_model(provider)
elif action == "rekey":
_collect_api_key(provider)
# "retry" or after "rekey" → loop
# loop → retry (with possibly new model or key)
+113
View File
@@ -329,3 +329,116 @@ def test_check_local_server_continue_returns(monkeypatch):
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
# ── draft persistence ─────────────────────────────────────────────────────────
def test_save_and_load_draft(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {"completed_steps": ["profile"], "user_name": "Alice"}
wiz._save_draft(state)
loaded = wiz._load_draft()
assert loaded == state
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
assert wiz._load_draft() is None
def test_delete_draft_removes_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
wiz._delete_draft()
assert wiz._load_draft() is None
def test_delete_draft_is_idempotent(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._delete_draft()
wiz._delete_draft()
def test_mark_step_done_appends_once(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {}
wiz._mark_step_done(state, "profile")
wiz._mark_step_done(state, "profile")
assert state["completed_steps"].count("profile") == 1
def test_draft_file_has_correct_permissions(tmp_pyra_home):
import os
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
path = wiz._draft_path()
if os.name != "nt":
mode = oct(path.stat().st_mode)[-3:]
assert mode == "600"
# ── _show_local_model_status ──────────────────────────────────────────────────
def test_show_local_model_status_one_model(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["gemma-4b"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("gemma-4b" in s for s in printed)
def test_show_local_model_status_none(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("No model" in s for s in printed)
def test_show_local_model_status_multiple(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["a", "b", "c"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
combined = " ".join(printed)
assert "3" in combined or ("a" in combined and "b" in combined)
# ── _test_connection model re-entry ───────────────────────────────────────────
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
call_count = {"n": 0}
class FakeNotFound(Exception):
pass
FakeNotFound.__module__ = "litellm.exceptions"
FakeNotFound.__name__ = "NotFoundError"
def fake_completion(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
raise FakeNotFound("model not found")
import litellm
monkeypatch.setattr(litellm, "completion", fake_completion)
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "change_model"))
monkeypatch.setattr(wiz.questionary, "text",
lambda *a, **kw: MagicMock(ask=lambda: "new-model"))
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")
provider = get_provider("anthropic")
result = wiz._test_connection(provider, "old-model")
assert result == "new-model"
assert call_count["n"] == 2