Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| efc589cc56 | |||
| 9a392410e7 |
+141
-11
@@ -139,18 +139,115 @@ def _choose_provider() -> Provider:
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_error(exc: Exception) -> tuple[str, str]:
|
||||||
|
"""Return (short_label, resolution_hint) for a provider or network error."""
|
||||||
|
name = type(exc).__name__
|
||||||
|
module = type(exc).__module__ or ""
|
||||||
|
is_llm = "litellm" in module or "openai" in module
|
||||||
|
|
||||||
|
if is_llm:
|
||||||
|
if "AuthenticationError" in name:
|
||||||
|
return (
|
||||||
|
"Invalid API key",
|
||||||
|
"The provider rejected your API key.\n"
|
||||||
|
"Double-check it on the provider dashboard and re-enter it.",
|
||||||
|
)
|
||||||
|
if "NotFoundError" in name:
|
||||||
|
return (
|
||||||
|
"Model not found",
|
||||||
|
"The model name doesn't exist for this provider.\n"
|
||||||
|
"Check the exact model identifier on the provider's model list.",
|
||||||
|
)
|
||||||
|
if "RateLimitError" in name:
|
||||||
|
return (
|
||||||
|
"Rate limit reached",
|
||||||
|
"You've exceeded the provider's request rate.\n"
|
||||||
|
"Wait a few seconds and retry.",
|
||||||
|
)
|
||||||
|
if "ServiceUnavailable" in name:
|
||||||
|
return (
|
||||||
|
"Service temporarily unavailable",
|
||||||
|
"The provider's servers returned a 5xx error.\n"
|
||||||
|
"This is usually transient — wait a minute and retry.",
|
||||||
|
)
|
||||||
|
if "APIConnectionError" in name or "ConnectError" in name:
|
||||||
|
return (
|
||||||
|
"Cannot reach provider",
|
||||||
|
"A network error prevented the connection.\n"
|
||||||
|
"Check your internet connection and firewall settings.",
|
||||||
|
)
|
||||||
|
if "Timeout" in name:
|
||||||
|
return (
|
||||||
|
"Request timed out",
|
||||||
|
"The provider did not respond in time.\n"
|
||||||
|
"The service may be overloaded — retry in a moment.",
|
||||||
|
)
|
||||||
|
if "BadRequestError" in name or "InvalidRequest" in name:
|
||||||
|
return (
|
||||||
|
"Bad request",
|
||||||
|
"The request was rejected — the model name or parameters may be wrong.\n"
|
||||||
|
"Verify the exact model identifier.",
|
||||||
|
)
|
||||||
|
return ("Provider error", str(exc)[:300])
|
||||||
|
|
||||||
|
if "httpx" in module:
|
||||||
|
if "ConnectError" in name or "ConnectTimeout" in name:
|
||||||
|
return (
|
||||||
|
"Server not reachable",
|
||||||
|
"Could not connect to the local server.\n"
|
||||||
|
"Make sure it is running and listening on the expected address.",
|
||||||
|
)
|
||||||
|
if "Timeout" in name:
|
||||||
|
return (
|
||||||
|
"Connection timed out",
|
||||||
|
"The local server did not respond in time.\n"
|
||||||
|
"It may still be starting up — wait a moment and retry.",
|
||||||
|
)
|
||||||
|
if "HTTPStatusError" in name:
|
||||||
|
code = getattr(getattr(exc, "response", None), "status_code", 0)
|
||||||
|
if code == 401:
|
||||||
|
return ("Unauthorized (401)", "The server requires credentials that were not provided.")
|
||||||
|
if code == 404:
|
||||||
|
return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.")
|
||||||
|
return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).")
|
||||||
|
|
||||||
|
return ("Unexpected error", str(exc)[:300])
|
||||||
|
|
||||||
|
|
||||||
def _check_local_server(provider: Provider) -> None:
|
def _check_local_server(provider: Provider) -> None:
|
||||||
console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ")
|
while True:
|
||||||
|
console.print(
|
||||||
|
f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" "
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
resp = httpx.get(provider.connectivity_check, timeout=3.0)
|
resp = httpx.get(provider.connectivity_check, timeout=3.0)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
console.print("[green]✓[/green]")
|
console.print("[green]✓[/green]")
|
||||||
except Exception:
|
return
|
||||||
console.print("[yellow]✗ (server not reachable)[/yellow]")
|
except Exception as exc:
|
||||||
console.print(
|
label, hint = _classify_error(exc)
|
||||||
f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n"
|
console.print("[yellow]✗[/yellow]")
|
||||||
f" Make sure {provider.display_name} is running before using Pyra."
|
console.print()
|
||||||
)
|
console.print(Panel(
|
||||||
|
f"[bold yellow]{label}[/bold yellow]\n\n{hint}",
|
||||||
|
title="Connection problem",
|
||||||
|
border_style="yellow",
|
||||||
|
))
|
||||||
|
action = questionary.select(
|
||||||
|
"How would you like to proceed?",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("Retry", value="retry"),
|
||||||
|
questionary.Choice(
|
||||||
|
"Continue anyway (model list may be unavailable)", value="continue"
|
||||||
|
),
|
||||||
|
questionary.Choice("Abort setup", value="abort"),
|
||||||
|
],
|
||||||
|
).ask()
|
||||||
|
if action is None or action == "abort":
|
||||||
|
raise SystemExit(0)
|
||||||
|
if action == "continue":
|
||||||
|
return
|
||||||
|
# "retry" → loop
|
||||||
|
|
||||||
|
|
||||||
def _fetch_local_models(provider: Provider) -> list[str]:
|
def _fetch_local_models(provider: Provider) -> list[str]:
|
||||||
@@ -267,11 +364,11 @@ def _collect_api_key(provider: Provider) -> None:
|
|||||||
def _test_connection(provider: Provider, model: str) -> None:
|
def _test_connection(provider: Provider, model: str) -> None:
|
||||||
from pyra.vault.reader import get_key
|
from pyra.vault.reader import get_key
|
||||||
|
|
||||||
console.print("\n Running test call...", end=" ")
|
while True:
|
||||||
|
console.print("\n Running connection test...", end=" ")
|
||||||
try:
|
try:
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
# Local providers don't need a real key but litellm still requires the field
|
|
||||||
api_key = get_key(provider.id) if provider.requires_key else "local"
|
api_key = get_key(provider.id) if provider.requires_key else "local"
|
||||||
kwargs: dict = {
|
kwargs: dict = {
|
||||||
"model": f"{provider.litellm_prefix}{model}",
|
"model": f"{provider.litellm_prefix}{model}",
|
||||||
@@ -284,6 +381,39 @@ def _test_connection(provider: Provider, model: str) -> None:
|
|||||||
|
|
||||||
litellm.completion(**kwargs)
|
litellm.completion(**kwargs)
|
||||||
console.print("[green]✓ Connection OK[/green]")
|
console.print("[green]✓ Connection OK[/green]")
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]")
|
label, hint = _classify_error(exc)
|
||||||
console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]")
|
console.print("[red]✗[/red]")
|
||||||
|
console.print()
|
||||||
|
console.print(Panel(
|
||||||
|
f"[bold red]{label}[/bold red]\n\n{hint}",
|
||||||
|
title="Test call failed",
|
||||||
|
border_style="red",
|
||||||
|
))
|
||||||
|
|
||||||
|
is_auth_error = "AuthenticationError" in type(exc).__name__
|
||||||
|
choices = [questionary.Choice("Retry", value="retry")]
|
||||||
|
if provider.requires_key and is_auth_error:
|
||||||
|
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
|
||||||
|
choices += [
|
||||||
|
questionary.Choice("Skip test and continue setup", value="skip"),
|
||||||
|
questionary.Choice("Abort setup", value="abort"),
|
||||||
|
]
|
||||||
|
|
||||||
|
action = questionary.select(
|
||||||
|
"How would you like to proceed?",
|
||||||
|
choices=choices,
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if action is None or action == "abort":
|
||||||
|
raise SystemExit(0)
|
||||||
|
if action == "skip":
|
||||||
|
console.print(
|
||||||
|
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if action == "rekey":
|
||||||
|
_collect_api_key(provider)
|
||||||
|
# "retry" or after "rekey" → loop
|
||||||
|
|||||||
@@ -181,3 +181,151 @@ def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
|
|||||||
import pyra.setup.wizard as wiz
|
import pyra.setup.wizard as wiz
|
||||||
monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout")))
|
monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout")))
|
||||||
assert wiz._load_lmstudio_model("gemma-4b") is False
|
assert wiz._load_lmstudio_model("gemma-4b") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── _classify_error ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fake_llm_exc(name: str) -> Exception:
|
||||||
|
"""Create a fake litellm exception with the given class name."""
|
||||||
|
cls = type(name, (Exception,), {"__module__": "litellm.exceptions"})
|
||||||
|
return cls("test error")
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_auth_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, hint = _classify_error(_fake_llm_exc("AuthenticationError"))
|
||||||
|
assert "key" in label.lower()
|
||||||
|
assert "provider" in hint.lower() or "dashboard" in hint.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_not_found_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, hint = _classify_error(_fake_llm_exc("NotFoundError"))
|
||||||
|
assert "model" in label.lower()
|
||||||
|
assert "model" in hint.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_rate_limit_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, _ = _classify_error(_fake_llm_exc("RateLimitError"))
|
||||||
|
assert "rate" in label.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_service_unavailable():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, _ = _classify_error(_fake_llm_exc("ServiceUnavailableError"))
|
||||||
|
assert "unavailable" in label.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_api_connection_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, hint = _classify_error(_fake_llm_exc("APIConnectionError"))
|
||||||
|
assert "reach" in label.lower() or "network" in label.lower() or "connect" in label.lower()
|
||||||
|
assert "internet" in hint.lower() or "network" in hint.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_timeout_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, _ = _classify_error(_fake_llm_exc("TimeoutError"))
|
||||||
|
assert "time" in label.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_bad_request_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, hint = _classify_error(_fake_llm_exc("BadRequestError"))
|
||||||
|
assert "request" in label.lower() or "bad" in label.lower()
|
||||||
|
assert "model" in hint.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_httpx_connect_error():
|
||||||
|
import httpx
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, hint = _classify_error(httpx.ConnectError("refused"))
|
||||||
|
assert "reach" in label.lower() or "reachable" in label.lower()
|
||||||
|
assert "running" in hint.lower() or "listening" in hint.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_httpx_timeout():
|
||||||
|
import httpx
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, _ = _classify_error(httpx.TimeoutException("timeout"))
|
||||||
|
assert "time" in label.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_generic_error():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
label, hint = _classify_error(ValueError("something went wrong"))
|
||||||
|
assert len(label) > 0
|
||||||
|
assert "something went wrong" in hint
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_error_always_returns_two_strings():
|
||||||
|
from pyra.setup.wizard import _classify_error
|
||||||
|
for exc in (RuntimeError("boom"), OSError("disk"), KeyError("k")):
|
||||||
|
label, hint = _classify_error(exc)
|
||||||
|
assert isinstance(label, str) and len(label) > 0
|
||||||
|
assert isinstance(hint, str) and len(hint) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── _check_local_server retry behaviour ──────────────────────────────────────
|
||||||
|
|
||||||
|
def _silent_print(*a, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_local_server_success(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.raise_for_status = lambda: None
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
|
||||||
|
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
wiz._check_local_server(get_provider("lmstudio")) # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_local_server_retry_then_success(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
def flaky_get(*a, **kw):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
raise ConnectionError("refused")
|
||||||
|
m = MagicMock()
|
||||||
|
m.raise_for_status = lambda: None
|
||||||
|
return m
|
||||||
|
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", flaky_get)
|
||||||
|
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||||
|
monkeypatch.setattr(wiz.questionary, "select",
|
||||||
|
lambda *a, **kw: MagicMock(ask=lambda: "retry"))
|
||||||
|
|
||||||
|
wiz._check_local_server(get_provider("lmstudio"))
|
||||||
|
assert call_count["n"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_local_server_abort_raises_system_exit(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
|
||||||
|
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||||
|
monkeypatch.setattr(wiz.questionary, "select",
|
||||||
|
lambda *a, **kw: MagicMock(ask=lambda: "abort"))
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
wiz._check_local_server(get_provider("lmstudio"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_local_server_continue_returns(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
|
||||||
|
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||||
|
monkeypatch.setattr(wiz.questionary, "select",
|
||||||
|
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
|
||||||
|
|
||||||
|
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
|
||||||
|
|||||||
Reference in New Issue
Block a user