feat(setup): model re-entry, status indicator, and resumable setup wizard
- _test_connection() now returns the (possibly changed) model name and offers a "Change model" option when the error is model-related - _show_local_model_status() prints which models are currently loaded immediately after selecting a local provider - Draft persistence: each completed wizard step is saved to ~/.pyra/setup.draft.json (chmod 600); on the next run a yellow panel summarises progress and offers [Resume / Start fresh]; draft is deleted on successful completion or Ctrl-C with no completed steps Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+178
-30
@@ -1,3 +1,6 @@
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import questionary
|
||||
from rich.console import Console
|
||||
@@ -7,6 +10,7 @@ from rich.text import Text
|
||||
from pyra.config.manager import save_config
|
||||
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
|
||||
from pyra.setup.providers import PROVIDERS, Provider, get_provider
|
||||
from pyra.utils.paths import pyra_home, safe_chmod
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -20,6 +24,72 @@ _USE_CASE_PLUGINS: dict[str, list[str]] = {
|
||||
}
|
||||
|
||||
|
||||
_DRAFT_FILE = "setup.draft.json"
|
||||
|
||||
|
||||
def _draft_path():
|
||||
return pyra_home() / _DRAFT_FILE
|
||||
|
||||
|
||||
def _save_draft(state: dict) -> None:
|
||||
path = _draft_path()
|
||||
path.write_text(json.dumps(state, indent=2))
|
||||
safe_chmod(path, 0o600)
|
||||
|
||||
|
||||
def _load_draft() -> dict | None:
|
||||
path = _draft_path()
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _delete_draft() -> None:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
_draft_path().unlink()
|
||||
|
||||
|
||||
def _mark_step_done(state: dict, step: str) -> None:
|
||||
state.setdefault("completed_steps", [])
|
||||
if step not in state["completed_steps"]:
|
||||
state["completed_steps"].append(step)
|
||||
|
||||
|
||||
def _offer_resume(draft: dict) -> bool:
|
||||
"""Show a summary of the incomplete setup and ask Resume / Start fresh."""
|
||||
completed = draft.get("completed_steps", [])
|
||||
step_display = {
|
||||
"profile": f"Profile: {draft.get('user_name', '?')}",
|
||||
"provider": f"Provider: {draft.get('provider_id', '?')}",
|
||||
"model": f"Model: {draft.get('model', '?')}",
|
||||
"api_key": "API key: stored in vault",
|
||||
"connection": "Connection: test passed",
|
||||
}
|
||||
lines = ["[bold]An incomplete setup was found.[/bold]\n"]
|
||||
for step, label in step_display.items():
|
||||
if step in completed:
|
||||
lines.append(f" [green]✓[/green] {label}")
|
||||
else:
|
||||
lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]")
|
||||
|
||||
console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow"))
|
||||
console.print()
|
||||
|
||||
action = questionary.select(
|
||||
"What would you like to do?",
|
||||
choices=[
|
||||
questionary.Choice("Resume from where you left off", value="resume"),
|
||||
questionary.Choice("Start fresh", value="fresh"),
|
||||
],
|
||||
).ask()
|
||||
if action is None:
|
||||
raise SystemExit(0)
|
||||
return action == "resume"
|
||||
|
||||
|
||||
def run_setup() -> None:
|
||||
console.print(Panel(
|
||||
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
|
||||
@@ -28,36 +98,92 @@ def run_setup() -> None:
|
||||
))
|
||||
console.print()
|
||||
|
||||
user_name, purpose, use_cases = _collect_user_profile()
|
||||
state: dict = {}
|
||||
draft = _load_draft()
|
||||
if draft:
|
||||
if _offer_resume(draft):
|
||||
state = draft
|
||||
else:
|
||||
_delete_draft()
|
||||
|
||||
provider = _choose_provider()
|
||||
model = _choose_model(provider)
|
||||
try:
|
||||
# ── Step 1: profile ────────────────────────────────────────────────
|
||||
if "profile" in state.get("completed_steps", []):
|
||||
user_name = state["user_name"]
|
||||
purpose = state["purpose"]
|
||||
use_cases = state["use_cases"]
|
||||
console.print(f" [dim]✓ Profile: {user_name}[/dim]")
|
||||
else:
|
||||
user_name, purpose, use_cases = _collect_user_profile()
|
||||
state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
|
||||
_mark_step_done(state, "profile")
|
||||
_save_draft(state)
|
||||
|
||||
if provider.requires_key:
|
||||
_collect_api_key(provider)
|
||||
# ── Step 2: provider ───────────────────────────────────────────────
|
||||
if "provider" in state.get("completed_steps", []):
|
||||
provider = get_provider(state["provider_id"])
|
||||
console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
|
||||
else:
|
||||
provider = _choose_provider()
|
||||
state.update(provider_id=provider.id)
|
||||
_mark_step_done(state, "provider")
|
||||
_save_draft(state)
|
||||
|
||||
_test_connection(provider, model)
|
||||
# ── Step 3: model ──────────────────────────────────────────────────
|
||||
if "model" in state.get("completed_steps", []):
|
||||
model = state["model"]
|
||||
console.print(f" [dim]✓ Model: {model}[/dim]")
|
||||
else:
|
||||
model = _choose_model(provider)
|
||||
state.update(model=model)
|
||||
_mark_step_done(state, "model")
|
||||
_save_draft(state)
|
||||
|
||||
cfg = PyraConfig(
|
||||
ai=ProviderConfig(
|
||||
provider_id=provider.id,
|
||||
model=model,
|
||||
base_url=provider.base_url,
|
||||
),
|
||||
general=GeneralConfig(user_name=user_name, purpose=purpose),
|
||||
)
|
||||
save_config(cfg)
|
||||
# ── Step 4: API key ────────────────────────────────────────────────
|
||||
if "api_key" not in state.get("completed_steps", []) and provider.requires_key:
|
||||
from pyra.vault.reader import get_key as _get_key
|
||||
if not _get_key(provider.id):
|
||||
_collect_api_key(provider)
|
||||
_mark_step_done(state, "api_key")
|
||||
_save_draft(state)
|
||||
|
||||
_suggest_plugins(use_cases)
|
||||
# ── Step 5: connection test ────────────────────────────────────────
|
||||
if "connection" not in state.get("completed_steps", []):
|
||||
model = _test_connection(provider, model)
|
||||
state["model"] = model
|
||||
_mark_step_done(state, "connection")
|
||||
_save_draft(state)
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[green]Setup complete![/green]\n\n"
|
||||
f"Provider: [bold]{provider.display_name}[/bold]\n"
|
||||
f"Model: [bold]{model}[/bold]\n\n"
|
||||
"Run [bold cyan]pyra chat[/bold cyan] to start talking.",
|
||||
border_style="green",
|
||||
))
|
||||
# ── Finalise ───────────────────────────────────────────────────────
|
||||
cfg = PyraConfig(
|
||||
ai=ProviderConfig(
|
||||
provider_id=provider.id,
|
||||
model=model,
|
||||
base_url=provider.base_url,
|
||||
),
|
||||
general=GeneralConfig(user_name=user_name, purpose=purpose),
|
||||
)
|
||||
save_config(cfg)
|
||||
_delete_draft()
|
||||
|
||||
_suggest_plugins(use_cases)
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[green]Setup complete![/green]\n\n"
|
||||
f"Provider: [bold]{provider.display_name}[/bold]\n"
|
||||
f"Model: [bold]{model}[/bold]\n\n"
|
||||
"Run [bold cyan]pyra chat[/bold cyan] to start talking.",
|
||||
border_style="green",
|
||||
))
|
||||
|
||||
except SystemExit:
|
||||
if state.get("completed_steps"):
|
||||
console.print()
|
||||
console.print(
|
||||
" [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _collect_user_profile() -> tuple[str, str, list[str]]:
|
||||
@@ -135,6 +261,7 @@ def _choose_provider() -> Provider:
|
||||
|
||||
if provider.connectivity_check:
|
||||
_check_local_server(provider)
|
||||
_show_local_model_status(provider)
|
||||
|
||||
return provider
|
||||
|
||||
@@ -250,6 +377,18 @@ def _check_local_server(provider: Provider) -> None:
|
||||
# "retry" → loop
|
||||
|
||||
|
||||
def _show_local_model_status(provider: Provider) -> None:
|
||||
"""Print a one-line status showing which models are currently loaded."""
|
||||
models = _fetch_local_models(provider)
|
||||
if not models:
|
||||
console.print(" [yellow]No model currently loaded[/yellow]")
|
||||
elif len(models) == 1:
|
||||
console.print(f" [green]Loaded model:[/green] {models[0]}")
|
||||
else:
|
||||
names = ", ".join(models)
|
||||
console.print(f" [green]{len(models)} models loaded:[/green] {names}")
|
||||
|
||||
|
||||
def _fetch_local_models(provider: Provider) -> list[str]:
|
||||
"""Return currently loaded/available models from a local provider's API."""
|
||||
if not provider.base_url:
|
||||
@@ -361,7 +500,7 @@ def _collect_api_key(provider: Provider) -> None:
|
||||
console.print(" [green]✓ Key stored in vault[/green]")
|
||||
|
||||
|
||||
def _test_connection(provider: Provider, model: str) -> None:
|
||||
def _test_connection(provider: Provider, model: str) -> str:
|
||||
from pyra.vault.reader import get_key
|
||||
|
||||
while True:
|
||||
@@ -381,7 +520,7 @@ def _test_connection(provider: Provider, model: str) -> None:
|
||||
|
||||
litellm.completion(**kwargs)
|
||||
console.print("[green]✓ Connection OK[/green]")
|
||||
return
|
||||
return model
|
||||
|
||||
except Exception as exc:
|
||||
label, hint = _classify_error(exc)
|
||||
@@ -393,8 +532,15 @@ def _test_connection(provider: Provider, model: str) -> None:
|
||||
border_style="red",
|
||||
))
|
||||
|
||||
is_auth_error = "AuthenticationError" in type(exc).__name__
|
||||
exc_name = type(exc).__name__
|
||||
is_auth_error = "AuthenticationError" in exc_name
|
||||
is_model_error = any(
|
||||
kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest")
|
||||
)
|
||||
|
||||
choices = [questionary.Choice("Retry", value="retry")]
|
||||
if is_model_error:
|
||||
choices.append(questionary.Choice("Change model", value="change_model"))
|
||||
if provider.requires_key and is_auth_error:
|
||||
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
|
||||
choices += [
|
||||
@@ -413,7 +559,9 @@ def _test_connection(provider: Provider, model: str) -> None:
|
||||
console.print(
|
||||
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
|
||||
)
|
||||
return
|
||||
if action == "rekey":
|
||||
return model
|
||||
if action == "change_model":
|
||||
model = _choose_model(provider)
|
||||
elif action == "rekey":
|
||||
_collect_api_key(provider)
|
||||
# "retry" or after "rekey" → loop
|
||||
# loop → retry (with possibly new model or key)
|
||||
|
||||
Reference in New Issue
Block a user