test: add tests for _classify_error and _check_local_server retry behaviour
_classify_error: covers all litellm error types (auth, not-found, rate-limit, service-unavailable, connection, timeout, bad-request), httpx connect and timeout errors, and generic fallback — using dynamically constructed fake exception classes to avoid importing litellm in tests. _check_local_server: covers success, retry-then-success, abort (SystemExit), and continue-anyway paths via monkeypatched httpx and questionary. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -181,3 +181,151 @@ def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout")))
|
||||
assert wiz._load_lmstudio_model("gemma-4b") is False
|
||||
|
||||
|
||||
# ── _classify_error ───────────────────────────────────────────────────────────
|
||||
|
||||
def _fake_llm_exc(name: str) -> Exception:
|
||||
"""Create a fake litellm exception with the given class name."""
|
||||
cls = type(name, (Exception,), {"__module__": "litellm.exceptions"})
|
||||
return cls("test error")
|
||||
|
||||
|
||||
def test_classify_auth_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, hint = _classify_error(_fake_llm_exc("AuthenticationError"))
|
||||
assert "key" in label.lower()
|
||||
assert "provider" in hint.lower() or "dashboard" in hint.lower()
|
||||
|
||||
|
||||
def test_classify_not_found_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, hint = _classify_error(_fake_llm_exc("NotFoundError"))
|
||||
assert "model" in label.lower()
|
||||
assert "model" in hint.lower()
|
||||
|
||||
|
||||
def test_classify_rate_limit_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, _ = _classify_error(_fake_llm_exc("RateLimitError"))
|
||||
assert "rate" in label.lower()
|
||||
|
||||
|
||||
def test_classify_service_unavailable():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, _ = _classify_error(_fake_llm_exc("ServiceUnavailableError"))
|
||||
assert "unavailable" in label.lower()
|
||||
|
||||
|
||||
def test_classify_api_connection_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, hint = _classify_error(_fake_llm_exc("APIConnectionError"))
|
||||
assert "reach" in label.lower() or "network" in label.lower() or "connect" in label.lower()
|
||||
assert "internet" in hint.lower() or "network" in hint.lower()
|
||||
|
||||
|
||||
def test_classify_timeout_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, _ = _classify_error(_fake_llm_exc("TimeoutError"))
|
||||
assert "time" in label.lower()
|
||||
|
||||
|
||||
def test_classify_bad_request_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, hint = _classify_error(_fake_llm_exc("BadRequestError"))
|
||||
assert "request" in label.lower() or "bad" in label.lower()
|
||||
assert "model" in hint.lower()
|
||||
|
||||
|
||||
def test_classify_httpx_connect_error():
|
||||
import httpx
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, hint = _classify_error(httpx.ConnectError("refused"))
|
||||
assert "reach" in label.lower() or "reachable" in label.lower()
|
||||
assert "running" in hint.lower() or "listening" in hint.lower()
|
||||
|
||||
|
||||
def test_classify_httpx_timeout():
|
||||
import httpx
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, _ = _classify_error(httpx.TimeoutException("timeout"))
|
||||
assert "time" in label.lower()
|
||||
|
||||
|
||||
def test_classify_generic_error():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
label, hint = _classify_error(ValueError("something went wrong"))
|
||||
assert len(label) > 0
|
||||
assert "something went wrong" in hint
|
||||
|
||||
|
||||
def test_classify_error_always_returns_two_strings():
|
||||
from pyra.setup.wizard import _classify_error
|
||||
for exc in (RuntimeError("boom"), OSError("disk"), KeyError("k")):
|
||||
label, hint = _classify_error(exc)
|
||||
assert isinstance(label, str) and len(label) > 0
|
||||
assert isinstance(hint, str) and len(hint) > 0
|
||||
|
||||
|
||||
# ── _check_local_server retry behaviour ──────────────────────────────────────
|
||||
|
||||
def _silent_print(*a, **kw):
|
||||
pass
|
||||
|
||||
|
||||
def test_check_local_server_success(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = lambda: None
|
||||
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
|
||||
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||
from pyra.setup.providers import get_provider
|
||||
wiz._check_local_server(get_provider("lmstudio")) # must not raise
|
||||
|
||||
|
||||
def test_check_local_server_retry_then_success(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def flaky_get(*a, **kw):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise ConnectionError("refused")
|
||||
m = MagicMock()
|
||||
m.raise_for_status = lambda: None
|
||||
return m
|
||||
|
||||
monkeypatch.setattr(wiz.httpx, "get", flaky_get)
|
||||
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||
monkeypatch.setattr(wiz.questionary, "select",
|
||||
lambda *a, **kw: MagicMock(ask=lambda: "retry"))
|
||||
|
||||
wiz._check_local_server(get_provider("lmstudio"))
|
||||
assert call_count["n"] == 2
|
||||
|
||||
|
||||
def test_check_local_server_abort_raises_system_exit(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
|
||||
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
|
||||
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||
monkeypatch.setattr(wiz.questionary, "select",
|
||||
lambda *a, **kw: MagicMock(ask=lambda: "abort"))
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
wiz._check_local_server(get_provider("lmstudio"))
|
||||
|
||||
|
||||
def test_check_local_server_continue_returns(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
|
||||
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
|
||||
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||
monkeypatch.setattr(wiz.questionary, "select",
|
||||
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
|
||||
|
||||
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
|
||||
|
||||
Reference in New Issue
Block a user