From efc589cc56db40c0b504afe03130d55472355955 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Tue, 19 May 2026 13:27:04 +0200 Subject: [PATCH] test: add tests for _classify_error and _check_local_server retry behaviour MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _classify_error: covers all litellm error types (auth, not-found, rate-limit, service-unavailable, connection, timeout, bad-request), httpx connect and timeout errors, and generic fallback — using dynamically constructed fake exception classes to avoid importing litellm in tests. _check_local_server: covers success, retry-then-success, abort (SystemExit), and continue-anyway paths via monkeypatched httpx and questionary. Co-Authored-By: Claude Sonnet 4.6 --- tests/unit/test_setup_wizard.py | 148 ++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/tests/unit/test_setup_wizard.py b/tests/unit/test_setup_wizard.py index 88e46c9..397d962 100644 --- a/tests/unit/test_setup_wizard.py +++ b/tests/unit/test_setup_wizard.py @@ -181,3 +181,151 @@ def test_load_lmstudio_model_returns_false_on_exception(monkeypatch): import pyra.setup.wizard as wiz monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout"))) assert wiz._load_lmstudio_model("gemma-4b") is False + + +# ── _classify_error ─────────────────────────────────────────────────────────── + +def _fake_llm_exc(name: str) -> Exception: + """Create a fake litellm exception with the given class name.""" + cls = type(name, (Exception,), {"__module__": "litellm.exceptions"}) + return cls("test error") + + +def test_classify_auth_error(): + from pyra.setup.wizard import _classify_error + label, hint = _classify_error(_fake_llm_exc("AuthenticationError")) + assert "key" in label.lower() + assert "provider" in hint.lower() or "dashboard" in hint.lower() + + +def test_classify_not_found_error(): + from pyra.setup.wizard import _classify_error + label, hint = _classify_error(_fake_llm_exc("NotFoundError")) + assert "model" in label.lower() + assert "model" in hint.lower() + + +def test_classify_rate_limit_error(): + from pyra.setup.wizard import _classify_error + label, _ = _classify_error(_fake_llm_exc("RateLimitError")) + assert "rate" in label.lower() + + +def test_classify_service_unavailable(): + from pyra.setup.wizard import _classify_error + label, _ = _classify_error(_fake_llm_exc("ServiceUnavailableError")) + assert "unavailable" in label.lower() + + +def test_classify_api_connection_error(): + from pyra.setup.wizard import _classify_error + label, hint = _classify_error(_fake_llm_exc("APIConnectionError")) + assert "reach" in label.lower() or "network" in label.lower() or "connect" in label.lower() + assert "internet" in hint.lower() or "network" in hint.lower() + + +def test_classify_timeout_error(): + from pyra.setup.wizard import _classify_error + label, _ = _classify_error(_fake_llm_exc("TimeoutError")) + assert "time" in label.lower() + + +def test_classify_bad_request_error(): + from pyra.setup.wizard import _classify_error + label, hint = _classify_error(_fake_llm_exc("BadRequestError")) + assert "request" in label.lower() or "bad" in label.lower() + assert "model" in hint.lower() + + +def test_classify_httpx_connect_error(): + import httpx + from pyra.setup.wizard import _classify_error + label, hint = _classify_error(httpx.ConnectError("refused")) + assert "reach" in label.lower() or "reachable" in label.lower() + assert "running" in hint.lower() or "listening" in hint.lower() + + +def test_classify_httpx_timeout(): + import httpx + from pyra.setup.wizard import _classify_error + label, _ = _classify_error(httpx.TimeoutException("timeout")) + assert "time" in label.lower() + + +def test_classify_generic_error(): + from pyra.setup.wizard import _classify_error + label, hint = _classify_error(ValueError("something went wrong")) + assert len(label) > 0 + assert "something went wrong" in hint + + +def test_classify_error_always_returns_two_strings(): + from pyra.setup.wizard import _classify_error + for exc in (RuntimeError("boom"), OSError("disk"), KeyError("k")): + label, hint = _classify_error(exc) + assert isinstance(label, str) and len(label) > 0 + assert isinstance(hint, str) and len(hint) > 0 + + +# ── _check_local_server retry behaviour ────────────────────────────────────── + +def _silent_print(*a, **kw): + pass + + +def test_check_local_server_success(monkeypatch): + import pyra.setup.wizard as wiz + mock_resp = MagicMock() + mock_resp.raise_for_status = lambda: None + monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp) + monkeypatch.setattr(wiz.console, "print", _silent_print) + from pyra.setup.providers import get_provider + wiz._check_local_server(get_provider("lmstudio")) # must not raise + + +def test_check_local_server_retry_then_success(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + + call_count = {"n": 0} + + def flaky_get(*a, **kw): + call_count["n"] += 1 + if call_count["n"] == 1: + raise ConnectionError("refused") + m = MagicMock() + m.raise_for_status = lambda: None + return m + + monkeypatch.setattr(wiz.httpx, "get", flaky_get) + monkeypatch.setattr(wiz.console, "print", _silent_print) + monkeypatch.setattr(wiz.questionary, "select", + lambda *a, **kw: MagicMock(ask=lambda: "retry")) + + wiz._check_local_server(get_provider("lmstudio")) + assert call_count["n"] == 2 + + +def test_check_local_server_abort_raises_system_exit(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + + monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused"))) + monkeypatch.setattr(wiz.console, "print", _silent_print) + monkeypatch.setattr(wiz.questionary, "select", + lambda *a, **kw: MagicMock(ask=lambda: "abort")) + + with pytest.raises(SystemExit): + wiz._check_local_server(get_provider("lmstudio")) + + +def test_check_local_server_continue_returns(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + + monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused"))) + monkeypatch.setattr(wiz.console, "print", _silent_print) + monkeypatch.setattr(wiz.questionary, "select", + lambda *a, **kw: MagicMock(ask=lambda: "continue")) + + wiz._check_local_server(get_provider("lmstudio")) # must return without raising