diff --git a/src/pyra/setup/wizard.py b/src/pyra/setup/wizard.py index 4001d66..a41f239 100644 --- a/src/pyra/setup/wizard.py +++ b/src/pyra/setup/wizard.py @@ -139,18 +139,115 @@ def _choose_provider() -> Provider: return provider +def _classify_error(exc: Exception) -> tuple[str, str]: + """Return (short_label, resolution_hint) for a provider or network error.""" + name = type(exc).__name__ + module = type(exc).__module__ or "" + is_llm = "litellm" in module or "openai" in module + + if is_llm: + if "AuthenticationError" in name: + return ( + "Invalid API key", + "The provider rejected your API key.\n" + "Double-check it on the provider dashboard and re-enter it.", + ) + if "NotFoundError" in name: + return ( + "Model not found", + "The model name doesn't exist for this provider.\n" + "Check the exact model identifier on the provider's model list.", + ) + if "RateLimitError" in name: + return ( + "Rate limit reached", + "You've exceeded the provider's request rate.\n" + "Wait a few seconds and retry.", + ) + if "ServiceUnavailable" in name: + return ( + "Service temporarily unavailable", + "The provider's servers returned a 5xx error.\n" + "This is usually transient — wait a minute and retry.", + ) + if "APIConnectionError" in name or "ConnectError" in name: + return ( + "Cannot reach provider", + "A network error prevented the connection.\n" + "Check your internet connection and firewall settings.", + ) + if "Timeout" in name: + return ( + "Request timed out", + "The provider did not respond in time.\n" + "The service may be overloaded — retry in a moment.", + ) + if "BadRequestError" in name or "InvalidRequest" in name: + return ( + "Bad request", + "The request was rejected — the model name or parameters may be wrong.\n" + "Verify the exact model identifier.", + ) + return ("Provider error", str(exc)[:300]) + + if "httpx" in module: + if "ConnectError" in name or "ConnectTimeout" in name: + return ( + "Server not reachable", + "Could not connect to the local server.\n" + "Make sure it is running and listening on the expected address.", + ) + if "Timeout" in name: + return ( + "Connection timed out", + "The local server did not respond in time.\n" + "It may still be starting up — wait a moment and retry.", + ) + if "HTTPStatusError" in name: + code = getattr(getattr(exc, "response", None), "status_code", 0) + if code == 401: + return ("Unauthorized (401)", "The server requires credentials that were not provided.") + if code == 404: + return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.") + return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).") + + return ("Unexpected error", str(exc)[:300]) + + def _check_local_server(provider: Provider) -> None: - console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ") - try: - resp = httpx.get(provider.connectivity_check, timeout=3.0) - resp.raise_for_status() - console.print("[green]✓[/green]") - except Exception: - console.print("[yellow]✗ (server not reachable)[/yellow]") + while True: console.print( - f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n" - f" Make sure {provider.display_name} is running before using Pyra." + f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" " ) + try: + resp = httpx.get(provider.connectivity_check, timeout=3.0) + resp.raise_for_status() + console.print("[green]✓[/green]") + return + except Exception as exc: + label, hint = _classify_error(exc) + console.print("[yellow]✗[/yellow]") + console.print() + console.print(Panel( + f"[bold yellow]{label}[/bold yellow]\n\n{hint}", + title="Connection problem", + border_style="yellow", + )) + action = questionary.select( + "How would you like to proceed?", + choices=[ + questionary.Choice("Retry", value="retry"), + questionary.Choice( + "Continue anyway (model list may be unavailable)", value="continue" + ), + questionary.Choice("Abort setup", value="abort"), + ], + ).ask() + if action is None or action == "abort": + raise SystemExit(0) + if action == "continue": + return + # "retry" → loop def _fetch_local_models(provider: Provider) -> list[str]: @@ -267,23 +364,56 @@ def _collect_api_key(provider: Provider) -> None: def _test_connection(provider: Provider, model: str) -> None: from pyra.vault.reader import get_key - console.print("\n Running test call...", end=" ") - try: - import litellm + while True: + console.print("\n Running connection test...", end=" ") + try: + import litellm - # Local providers don't need a real key but litellm still requires the field - api_key = get_key(provider.id) if provider.requires_key else "local" - kwargs: dict = { - "model": f"{provider.litellm_prefix}{model}", - "messages": [{"role": "user", "content": "Reply with exactly: OK"}], - "max_tokens": 10, - "api_key": api_key, - } - if provider.base_url: - kwargs["api_base"] = provider.base_url + api_key = get_key(provider.id) if provider.requires_key else "local" + kwargs: dict = { + "model": f"{provider.litellm_prefix}{model}", + "messages": [{"role": "user", "content": "Reply with exactly: OK"}], + "max_tokens": 10, + "api_key": api_key, + } + if provider.base_url: + kwargs["api_base"] = provider.base_url - litellm.completion(**kwargs) - console.print("[green]✓ Connection OK[/green]") - except Exception as exc: - console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]") - console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]") + litellm.completion(**kwargs) + console.print("[green]✓ Connection OK[/green]") + return + + except Exception as exc: + label, hint = _classify_error(exc) + console.print("[red]✗[/red]") + console.print() + console.print(Panel( + f"[bold red]{label}[/bold red]\n\n{hint}", + title="Test call failed", + border_style="red", + )) + + is_auth_error = "AuthenticationError" in type(exc).__name__ + choices = [questionary.Choice("Retry", value="retry")] + if provider.requires_key and is_auth_error: + choices.append(questionary.Choice("Re-enter API key", value="rekey")) + choices += [ + questionary.Choice("Skip test and continue setup", value="skip"), + questionary.Choice("Abort setup", value="abort"), + ] + + action = questionary.select( + "How would you like to proceed?", + choices=choices, + ).ask() + + if action is None or action == "abort": + raise SystemExit(0) + if action == "skip": + console.print( + " [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]" + ) + return + if action == "rekey": + _collect_api_key(provider) + # "retry" or after "rekey" → loop