feat(setup): error controller with retry loops in setup wizard
Add _classify_error() that maps litellm/httpx exceptions to human-readable labels and resolution hints without requiring a top-level litellm import. _check_local_server() now loops with Retry / Continue anyway / Abort instead of printing a one-shot warning and silently continuing. _test_connection() now loops with Retry / Re-enter API key (auth errors only) / Skip test / Abort instead of printing the raw exception string and falling through. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+157
-27
@@ -139,18 +139,115 @@ def _choose_provider() -> Provider:
|
||||
return provider
|
||||
|
||||
|
||||
def _classify_error(exc: Exception) -> tuple[str, str]:
|
||||
"""Return (short_label, resolution_hint) for a provider or network error."""
|
||||
name = type(exc).__name__
|
||||
module = type(exc).__module__ or ""
|
||||
is_llm = "litellm" in module or "openai" in module
|
||||
|
||||
if is_llm:
|
||||
if "AuthenticationError" in name:
|
||||
return (
|
||||
"Invalid API key",
|
||||
"The provider rejected your API key.\n"
|
||||
"Double-check it on the provider dashboard and re-enter it.",
|
||||
)
|
||||
if "NotFoundError" in name:
|
||||
return (
|
||||
"Model not found",
|
||||
"The model name doesn't exist for this provider.\n"
|
||||
"Check the exact model identifier on the provider's model list.",
|
||||
)
|
||||
if "RateLimitError" in name:
|
||||
return (
|
||||
"Rate limit reached",
|
||||
"You've exceeded the provider's request rate.\n"
|
||||
"Wait a few seconds and retry.",
|
||||
)
|
||||
if "ServiceUnavailable" in name:
|
||||
return (
|
||||
"Service temporarily unavailable",
|
||||
"The provider's servers returned a 5xx error.\n"
|
||||
"This is usually transient — wait a minute and retry.",
|
||||
)
|
||||
if "APIConnectionError" in name or "ConnectError" in name:
|
||||
return (
|
||||
"Cannot reach provider",
|
||||
"A network error prevented the connection.\n"
|
||||
"Check your internet connection and firewall settings.",
|
||||
)
|
||||
if "Timeout" in name:
|
||||
return (
|
||||
"Request timed out",
|
||||
"The provider did not respond in time.\n"
|
||||
"The service may be overloaded — retry in a moment.",
|
||||
)
|
||||
if "BadRequestError" in name or "InvalidRequest" in name:
|
||||
return (
|
||||
"Bad request",
|
||||
"The request was rejected — the model name or parameters may be wrong.\n"
|
||||
"Verify the exact model identifier.",
|
||||
)
|
||||
return ("Provider error", str(exc)[:300])
|
||||
|
||||
if "httpx" in module:
|
||||
if "ConnectError" in name or "ConnectTimeout" in name:
|
||||
return (
|
||||
"Server not reachable",
|
||||
"Could not connect to the local server.\n"
|
||||
"Make sure it is running and listening on the expected address.",
|
||||
)
|
||||
if "Timeout" in name:
|
||||
return (
|
||||
"Connection timed out",
|
||||
"The local server did not respond in time.\n"
|
||||
"It may still be starting up — wait a moment and retry.",
|
||||
)
|
||||
if "HTTPStatusError" in name:
|
||||
code = getattr(getattr(exc, "response", None), "status_code", 0)
|
||||
if code == 401:
|
||||
return ("Unauthorized (401)", "The server requires credentials that were not provided.")
|
||||
if code == 404:
|
||||
return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.")
|
||||
return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).")
|
||||
|
||||
return ("Unexpected error", str(exc)[:300])
|
||||
|
||||
|
||||
def _check_local_server(provider: Provider) -> None:
|
||||
console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ")
|
||||
try:
|
||||
resp = httpx.get(provider.connectivity_check, timeout=3.0)
|
||||
resp.raise_for_status()
|
||||
console.print("[green]✓[/green]")
|
||||
except Exception:
|
||||
console.print("[yellow]✗ (server not reachable)[/yellow]")
|
||||
while True:
|
||||
console.print(
|
||||
f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n"
|
||||
f" Make sure {provider.display_name} is running before using Pyra."
|
||||
f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" "
|
||||
)
|
||||
try:
|
||||
resp = httpx.get(provider.connectivity_check, timeout=3.0)
|
||||
resp.raise_for_status()
|
||||
console.print("[green]✓[/green]")
|
||||
return
|
||||
except Exception as exc:
|
||||
label, hint = _classify_error(exc)
|
||||
console.print("[yellow]✗[/yellow]")
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[bold yellow]{label}[/bold yellow]\n\n{hint}",
|
||||
title="Connection problem",
|
||||
border_style="yellow",
|
||||
))
|
||||
action = questionary.select(
|
||||
"How would you like to proceed?",
|
||||
choices=[
|
||||
questionary.Choice("Retry", value="retry"),
|
||||
questionary.Choice(
|
||||
"Continue anyway (model list may be unavailable)", value="continue"
|
||||
),
|
||||
questionary.Choice("Abort setup", value="abort"),
|
||||
],
|
||||
).ask()
|
||||
if action is None or action == "abort":
|
||||
raise SystemExit(0)
|
||||
if action == "continue":
|
||||
return
|
||||
# "retry" → loop
|
||||
|
||||
|
||||
def _fetch_local_models(provider: Provider) -> list[str]:
|
||||
@@ -267,23 +364,56 @@ def _collect_api_key(provider: Provider) -> None:
|
||||
def _test_connection(provider: Provider, model: str) -> None:
|
||||
from pyra.vault.reader import get_key
|
||||
|
||||
console.print("\n Running test call...", end=" ")
|
||||
try:
|
||||
import litellm
|
||||
while True:
|
||||
console.print("\n Running connection test...", end=" ")
|
||||
try:
|
||||
import litellm
|
||||
|
||||
# Local providers don't need a real key but litellm still requires the field
|
||||
api_key = get_key(provider.id) if provider.requires_key else "local"
|
||||
kwargs: dict = {
|
||||
"model": f"{provider.litellm_prefix}{model}",
|
||||
"messages": [{"role": "user", "content": "Reply with exactly: OK"}],
|
||||
"max_tokens": 10,
|
||||
"api_key": api_key,
|
||||
}
|
||||
if provider.base_url:
|
||||
kwargs["api_base"] = provider.base_url
|
||||
api_key = get_key(provider.id) if provider.requires_key else "local"
|
||||
kwargs: dict = {
|
||||
"model": f"{provider.litellm_prefix}{model}",
|
||||
"messages": [{"role": "user", "content": "Reply with exactly: OK"}],
|
||||
"max_tokens": 10,
|
||||
"api_key": api_key,
|
||||
}
|
||||
if provider.base_url:
|
||||
kwargs["api_base"] = provider.base_url
|
||||
|
||||
litellm.completion(**kwargs)
|
||||
console.print("[green]✓ Connection OK[/green]")
|
||||
except Exception as exc:
|
||||
console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]")
|
||||
console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]")
|
||||
litellm.completion(**kwargs)
|
||||
console.print("[green]✓ Connection OK[/green]")
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
label, hint = _classify_error(exc)
|
||||
console.print("[red]✗[/red]")
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[bold red]{label}[/bold red]\n\n{hint}",
|
||||
title="Test call failed",
|
||||
border_style="red",
|
||||
))
|
||||
|
||||
is_auth_error = "AuthenticationError" in type(exc).__name__
|
||||
choices = [questionary.Choice("Retry", value="retry")]
|
||||
if provider.requires_key and is_auth_error:
|
||||
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
|
||||
choices += [
|
||||
questionary.Choice("Skip test and continue setup", value="skip"),
|
||||
questionary.Choice("Abort setup", value="abort"),
|
||||
]
|
||||
|
||||
action = questionary.select(
|
||||
"How would you like to proceed?",
|
||||
choices=choices,
|
||||
).ask()
|
||||
|
||||
if action is None or action == "abort":
|
||||
raise SystemExit(0)
|
||||
if action == "skip":
|
||||
console.print(
|
||||
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
|
||||
)
|
||||
return
|
||||
if action == "rekey":
|
||||
_collect_api_key(provider)
|
||||
# "retry" or after "rekey" → loop
|
||||
|
||||
Reference in New Issue
Block a user