Files
Pyra/tests/unit/test_setup_wizard.py
T
curo1305 0e052c4992 fix(setup): correct LM Studio loaded state value to "loaded" not "loaded_instance"
Querying the live /api/v0/models endpoint shows LM Studio uses state="loaded"
for in-memory models (not "loaded_instance"), so the filter never matched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:00:01 +02:00

530 lines
20 KiB
Python

"""Tests for setup wizard personalization and model-discovery helpers."""
from unittest.mock import MagicMock
import pytest
def test_use_case_plugin_mapping_all_categories_have_entries():
    """Every use-case category must suggest at least one plugin."""
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    # Guard against a vacuous pass: all()/any-style checks over an empty
    # mapping succeed trivially.
    assert _USE_CASE_PLUGINS, "_USE_CASE_PLUGINS must not be empty"
    empty = [name for name, plugins in _USE_CASE_PLUGINS.items() if not plugins]
    assert not empty, f"categories with no plugins: {empty}"
def test_use_case_plugin_mapping_has_expected_categories():
    """The headline use-case categories are offered by the wizard."""
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    for category in ("Email", "Development & servers", "Research & web"):
        assert category in _USE_CASE_PLUGINS, f"missing category: {category!r}"
def test_use_case_email_contains_email_plugin():
    """Choosing the "Email" use case suggests the email plugin."""
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    email_plugins = _USE_CASE_PLUGINS["Email"]
    assert "email" in email_plugins
def test_use_case_dev_contains_ssh_and_docker():
    """The development use case suggests both the SSH and Docker tools."""
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    dev_plugins = _USE_CASE_PLUGINS["Development & servers"]
    for plugin in ("ssh_tool", "docker_tool"):
        assert plugin in dev_plugins, plugin
def test_use_case_file_management_contains_cloud_stores():
    """File management suggests every supported cloud-storage plugin."""
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    suggested = _USE_CASE_PLUGINS["File management"]
    for cloud_plugin in ("gdrive", "onedrive", "dropbox_tool"):
        assert cloud_plugin in suggested, cloud_plugin
def test_suggest_plugins_empty_use_cases_returns_early(monkeypatch):
    """With no use cases selected, nothing is printed."""
    import pyra.setup.wizard as wiz

    printed_args = []

    def record_print(*args, **kwargs):
        printed_args.append(args)

    monkeypatch.setattr(wiz.console, "print", record_print)
    wiz._suggest_plugins([])
    assert not printed_args
def test_suggest_plugins_unknown_use_case_returns_early(monkeypatch):
    """An unrecognised use-case name produces no output at all."""
    import pyra.setup.wizard as wiz

    printed_args = []

    def record_print(*args, **kwargs):
        printed_args.append(args)

    monkeypatch.setattr(wiz.console, "print", record_print)
    wiz._suggest_plugins(["Not a real category"])
    assert not printed_args
def test_suggest_plugins_valid_use_case_calls_print(monkeypatch):
    """A known use case produces at least one console print."""
    import pyra.setup.wizard as wiz

    rendered = []

    def record_print(*args, **kwargs):
        rendered.append(str(args))

    monkeypatch.setattr(wiz.console, "print", record_print)
    wiz._suggest_plugins(["Email"])
    assert rendered
def test_suggest_plugins_panel_text_contains_plugin_name(monkeypatch):
    """The suggestion panel printed for "Email" mentions the email plugin."""
    import pyra.setup.wizard as wiz
    from rich.panel import Panel

    panel_bodies = []

    def capture_panels(*args, **kwargs):
        # Only rich Panel arguments matter; keep their inner renderable.
        panel_bodies.extend(arg.renderable for arg in args if isinstance(arg, Panel))

    monkeypatch.setattr(wiz.console, "print", capture_panels)
    wiz._suggest_plugins(["Email"])
    assert any("email" in str(body) for body in panel_bodies)
def test_suggest_plugins_multiple_categories(monkeypatch):
    """Suggestions for every selected use case appear in the printed panels."""
    import pyra.setup.wizard as wiz
    from rich.panel import Panel

    panel_bodies = []

    def capture_panels(*args, **kwargs):
        panel_bodies.extend(arg.renderable for arg in args if isinstance(arg, Panel))

    monkeypatch.setattr(wiz.console, "print", capture_panels)
    wiz._suggest_plugins(["Email", "Development & servers"])
    panel_text = " ".join(map(str, panel_bodies))
    for plugin in ("email", "ssh_tool"):
        assert plugin in panel_text, plugin
# ── _fetch_local_models ────────────────────────────────────────────────────────
def test_fetch_local_models_lmstudio_returns_loaded_model_ids(monkeypatch):
    """LM Studio discovery keeps only models whose state is "loaded"."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    payload = {
        "data": [
            {"id": "gemma-4b", "state": "loaded"},
            {"id": "llama3", "state": "not_loaded"},
        ]
    }
    fake_response = MagicMock()
    fake_response.json.return_value = payload
    fake_response.raise_for_status = lambda: None

    def fake_get(*args, **kwargs):
        return fake_response

    monkeypatch.setattr(wiz.httpx, "get", fake_get)
    lmstudio = get_provider("lmstudio")
    assert wiz._fetch_local_models(lmstudio) == ["gemma-4b"]
def test_fetch_local_models_lmstudio_filters_unloaded(monkeypatch):
    """No LM Studio model is returned when none is in the "loaded" state."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    fake_response = MagicMock()
    fake_response.json.return_value = {
        "data": [
            {"id": model_id, "state": "not_loaded"}
            for model_id in ("model-a", "model-b")
        ]
    }
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *args, **kwargs: fake_response)

    assert wiz._fetch_local_models(get_provider("lmstudio")) == []
def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
    """Ollama discovery returns every model name, in response order."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    expected = ["llama3:latest", "mistral"]
    fake_response = MagicMock()
    fake_response.json.return_value = {"models": [{"name": name} for name in expected]}
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *args, **kwargs: fake_response)

    assert wiz._fetch_local_models(get_provider("ollama")) == expected
def test_fetch_local_models_returns_empty_on_connection_error(monkeypatch):
    """A failing HTTP request yields an empty model list instead of raising."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    failing_get = MagicMock(side_effect=Exception("conn refused"))
    monkeypatch.setattr(wiz.httpx, "get", failing_get)
    assert wiz._fetch_local_models(get_provider("lmstudio")) == []
def test_fetch_local_models_returns_empty_when_no_base_url():
    """A provider constructed without a base URL yields no models."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import Provider

    urlless = Provider(
        id="test",
        display_name="Test",
        requires_key=False,
        default_model="x",
        litellm_prefix="openai/",
        group="Local",
    )
    assert wiz._fetch_local_models(urlless) == []
# ── _fetch_lmstudio_available_models ──────────────────────────────────────────
def test_fetch_lmstudio_available_models_returns_ids(monkeypatch):
    """Every model id reported by LM Studio is returned, in order."""
    import pyra.setup.wizard as wiz

    model_ids = ["model-a", "model-b"]
    fake_response = MagicMock()
    fake_response.json.return_value = {"data": [{"id": mid} for mid in model_ids]}
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *args, **kwargs: fake_response)

    assert wiz._fetch_lmstudio_available_models() == model_ids
def test_fetch_lmstudio_available_models_returns_empty_on_error(monkeypatch):
    """An HTTP failure degrades to an empty list of available models."""
    import pyra.setup.wizard as wiz

    failing_get = MagicMock(side_effect=Exception("not found"))
    monkeypatch.setattr(wiz.httpx, "get", failing_get)
    assert wiz._fetch_lmstudio_available_models() == []
def test_fetch_lmstudio_available_models_empty_data(monkeypatch):
    """An empty "data" list from LM Studio maps to an empty result."""
    import pyra.setup.wizard as wiz

    fake_response = MagicMock()
    fake_response.json.return_value = {"data": []}
    fake_response.raise_for_status = lambda: None

    def fake_get(*args, **kwargs):
        return fake_response

    monkeypatch.setattr(wiz.httpx, "get", fake_get)
    assert wiz._fetch_lmstudio_available_models() == []
# ── _load_lmstudio_model ──────────────────────────────────────────────────────
def test_load_lmstudio_model_returns_true_on_success(monkeypatch):
    """A successful load request reports True."""
    import pyra.setup.wizard as wiz

    ok_response = MagicMock(is_success=True)
    monkeypatch.setattr(wiz.httpx, "post", lambda *args, **kwargs: ok_response)
    assert wiz._load_lmstudio_model("gemma-4b") is True
def test_load_lmstudio_model_returns_false_on_api_failure(monkeypatch):
    """A non-success HTTP response from the load endpoint reports False."""
    import pyra.setup.wizard as wiz

    rejected_response = MagicMock(is_success=False)
    monkeypatch.setattr(wiz.httpx, "post", lambda *args, **kwargs: rejected_response)
    assert wiz._load_lmstudio_model("gemma-4b") is False
def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
    """An exception while posting the load request reports False."""
    import pyra.setup.wizard as wiz

    exploding_post = MagicMock(side_effect=Exception("timeout"))
    monkeypatch.setattr(wiz.httpx, "post", exploding_post)
    assert wiz._load_lmstudio_model("gemma-4b") is False
# ── _classify_error ───────────────────────────────────────────────────────────
def _fake_llm_exc(name: str) -> Exception:
"""Create a fake litellm exception with the given class name."""
cls = type(name, (Exception,), {"__module__": "litellm.exceptions"})
return cls("test error")
def test_classify_auth_error():
    """Authentication failures point the user at their API key."""
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("AuthenticationError"))
    hint_lc = hint.lower()
    assert "key" in label.lower()
    assert any(word in hint_lc for word in ("provider", "dashboard"))
def test_classify_not_found_error():
    """A missing model is both labelled and explained in terms of the model."""
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("NotFoundError"))
    for text in (label, hint):
        assert "model" in text.lower()
def test_classify_rate_limit_error():
    """Rate-limit errors are labelled as such."""
    from pyra.setup.wizard import _classify_error

    exc = _fake_llm_exc("RateLimitError")
    label, _hint = _classify_error(exc)
    assert "rate" in label.lower()
def test_classify_service_unavailable():
    """A ServiceUnavailableError is labelled as the service being unavailable."""
    from pyra.setup.wizard import _classify_error

    exc = _fake_llm_exc("ServiceUnavailableError")
    label, _hint = _classify_error(exc)
    assert "unavailable" in label.lower()
def test_classify_api_connection_error():
    """Connection failures mention reachability and suggest checking the network."""
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("APIConnectionError"))
    label_lc, hint_lc = label.lower(), hint.lower()
    assert any(word in label_lc for word in ("reach", "network", "connect"))
    assert any(word in hint_lc for word in ("internet", "network"))
def test_classify_timeout_error():
    """A litellm-style TimeoutError is labelled as a time-out."""
    from pyra.setup.wizard import _classify_error

    exc = _fake_llm_exc("TimeoutError")
    label, _hint = _classify_error(exc)
    assert "time" in label.lower()
def test_classify_bad_request_error():
    """A bad request is labelled as such and the hint mentions the model."""
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("BadRequestError"))
    label_lc = label.lower()
    assert "request" in label_lc or "bad" in label_lc
    assert "model" in hint.lower()
def test_classify_httpx_connect_error():
    """An httpx connect failure says the server can't be reached and why."""
    import httpx
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(httpx.ConnectError("refused"))
    hint_lc = hint.lower()
    # "reach" also matches "reachable"/"unreachable", so one check covers both.
    assert "reach" in label.lower()
    assert "running" in hint_lc or "listening" in hint_lc
def test_classify_httpx_timeout():
    """An httpx timeout is labelled as a time-out."""
    import httpx
    from pyra.setup.wizard import _classify_error

    exc = httpx.TimeoutException("timeout")
    label, _hint = _classify_error(exc)
    assert "time" in label.lower()
def test_classify_generic_error():
    """Unrecognised exceptions still get a label; the hint carries the message."""
    from pyra.setup.wizard import _classify_error

    message = "something went wrong"
    label, hint = _classify_error(ValueError(message))
    assert label
    assert message in hint
def test_classify_error_always_returns_two_strings():
    """Any exception yields a (label, hint) pair of non-empty strings."""
    from pyra.setup.wizard import _classify_error

    samples = [RuntimeError("boom"), OSError("disk"), KeyError("k")]
    for label, hint in map(_classify_error, samples):
        for text in (label, hint):
            assert isinstance(text, str) and text
# ── _check_local_server retry behaviour ──────────────────────────────────────
def _silent_print(*a, **kw):
pass
def test_check_local_server_success(monkeypatch):
    """A responsive local server passes the check without raising."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    healthy_response = MagicMock()
    healthy_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *args, **kwargs: healthy_response)
    monkeypatch.setattr(wiz.console, "print", _silent_print)

    wiz._check_local_server(get_provider("lmstudio"))
def test_check_local_server_retry_then_success(monkeypatch):
    """After one refused probe, choosing "retry" probes again and succeeds."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    attempts = []

    def flaky_get(*args, **kwargs):
        attempts.append(args)
        if len(attempts) == 1:
            raise ConnectionError("refused")
        healthy = MagicMock()
        healthy.raise_for_status = lambda: None
        return healthy

    def choose_retry(*args, **kwargs):
        return MagicMock(ask=lambda: "retry")

    monkeypatch.setattr(wiz.httpx, "get", flaky_get)
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", choose_retry)

    wiz._check_local_server(get_provider("lmstudio"))
    assert len(attempts) == 2
def test_check_local_server_abort_raises_system_exit(monkeypatch):
    """Choosing "abort" after a refused probe exits via SystemExit."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    def choose_abort(*args, **kwargs):
        return MagicMock(ask=lambda: "abort")

    refused_get = MagicMock(side_effect=ConnectionError("refused"))
    monkeypatch.setattr(wiz.httpx, "get", refused_get)
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", choose_abort)

    lmstudio = get_provider("lmstudio")
    with pytest.raises(SystemExit):
        wiz._check_local_server(lmstudio)
def test_check_local_server_continue_returns(monkeypatch):
    """Choosing "continue" after a refused probe returns normally."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    def choose_continue(*args, **kwargs):
        return MagicMock(ask=lambda: "continue")

    refused_get = MagicMock(side_effect=ConnectionError("refused"))
    monkeypatch.setattr(wiz.httpx, "get", refused_get)
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", choose_continue)

    # Reaching the end of the test without an exception is the assertion.
    wiz._check_local_server(get_provider("lmstudio"))
# ── draft persistence ─────────────────────────────────────────────────────────
def test_save_and_load_draft(tmp_pyra_home):
    """A saved draft round-trips unchanged through _load_draft."""
    import pyra.setup.wizard as wiz

    draft = {"completed_steps": ["profile"], "user_name": "Alice"}
    wiz._save_draft(draft)
    assert wiz._load_draft() == draft
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
    """Without a saved draft, loading yields None."""
    import pyra.setup.wizard as setup_wizard

    loaded = setup_wizard._load_draft()
    assert loaded is None
def test_delete_draft_removes_file(tmp_pyra_home):
    """After _delete_draft, no draft can be loaded any more."""
    import pyra.setup.wizard as wiz

    wiz._save_draft({"completed_steps": []})
    wiz._delete_draft()
    assert wiz._load_draft() is None, "draft should be gone after _delete_draft()"
def test_delete_draft_is_idempotent(tmp_pyra_home):
    """Deleting a draft that does not exist is a no-op, even when repeated."""
    import pyra.setup.wizard as wiz

    wiz._delete_draft()
    wiz._delete_draft()  # a second call must not raise either
    # Make the post-condition explicit instead of relying only on "no exception".
    assert wiz._load_draft() is None
def test_mark_step_done_appends_once(tmp_pyra_home):
    """Marking the same step twice records it only once."""
    import pyra.setup.wizard as wiz

    state = {}
    for _ in range(2):
        wiz._mark_step_done(state, "profile")
    assert state["completed_steps"].count("profile") == 1
def test_draft_file_has_correct_permissions(tmp_pyra_home):
    """The draft file is readable and writable by its owner only (0o600).

    POSIX permission bits are not enforced on Windows. There the check is
    reported as skipped rather than silently passing.
    """
    import os
    import stat

    import pyra.setup.wizard as wiz

    wiz._save_draft({"completed_steps": []})
    path = wiz._draft_path()
    if os.name == "nt":
        pytest.skip("POSIX file permissions are not enforced on Windows")
    # S_IMODE drops the file-type bits but keeps setuid/setgid/sticky, so an
    # unexpected 0o4600 fails here, whereas oct(mode)[-3:] would accept it.
    assert stat.S_IMODE(path.stat().st_mode) == 0o600
# ── fetch_loaded_models ───────────────────────────────────────────────────────
def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch):
    """Ollama's loaded models are read from the /api/ps endpoint."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    fake_response = MagicMock()
    fake_response.json.return_value = {
        "models": [{"name": "llama3:latest"}, {"name": "mistral"}]
    }
    fake_response.raise_for_status = lambda: None
    requested_urls = []

    def recording_get(url, **kwargs):
        requested_urls.append(url)
        return fake_response

    monkeypatch.setattr(wiz.httpx, "get", recording_get)
    loaded = wiz.fetch_loaded_models(get_provider("ollama"))
    assert loaded == ["llama3:latest", "mistral"]
    assert any("/api/ps" in url for url in requested_urls)
def test_fetch_loaded_models_lmstudio_uses_beta_api_and_filters(monkeypatch):
    """LM Studio lookup queries /api/v0/models and keeps only "loaded" models."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    fake_response = MagicMock()
    fake_response.json.return_value = {
        "data": [
            {"id": "gemma-4b", "state": "loaded"},
            {"id": "llama3", "state": "not_loaded"},
        ]
    }
    fake_response.raise_for_status = lambda: None
    requested_urls = []

    def recording_get(url, **kwargs):
        requested_urls.append(url)
        return fake_response

    monkeypatch.setattr(wiz.httpx, "get", recording_get)
    loaded = wiz.fetch_loaded_models(get_provider("lmstudio"))
    assert loaded == ["gemma-4b"]
    assert any("/api/v0/models" in url for url in requested_urls)
def test_fetch_loaded_models_lmstudio_filters_unloaded(monkeypatch):
    """With every LM Studio model in "not_loaded" state, nothing is returned."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    fake_response = MagicMock()
    fake_response.json.return_value = {
        "data": [
            {"id": model_id, "state": "not_loaded"}
            for model_id in ("model-a", "model-b")
        ]
    }
    fake_response.raise_for_status = lambda: None

    def fake_get(url, **kwargs):
        return fake_response

    monkeypatch.setattr(wiz.httpx, "get", fake_get)
    assert wiz.fetch_loaded_models(get_provider("lmstudio")) == []
def test_fetch_loaded_models_returns_empty_on_error(monkeypatch):
    """An HTTP failure while listing loaded models yields an empty list."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    failing_get = MagicMock(side_effect=Exception("conn refused"))
    monkeypatch.setattr(wiz.httpx, "get", failing_get)
    assert wiz.fetch_loaded_models(get_provider("ollama")) == []
def test_fetch_loaded_models_returns_empty_when_no_base_url():
    """A provider constructed without a base URL reports no loaded models."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import Provider

    urlless = Provider(
        id="test",
        display_name="Test",
        requires_key=False,
        default_model="x",
        litellm_prefix="openai/",
        group="Local",
    )
    assert wiz.fetch_loaded_models(urlless) == []
# ── _show_local_model_status ──────────────────────────────────────────────────
def test_show_local_model_status_one_model(monkeypatch):
    """A single loaded model is named in the status output."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    output = []

    def record_print(*args, **kwargs):
        output.append(str(args))

    monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"])
    monkeypatch.setattr(wiz.console, "print", record_print)
    wiz._show_local_model_status(get_provider("lmstudio"))
    assert any("gemma-4b" in line for line in output)
def test_show_local_model_status_none(monkeypatch):
    """With nothing loaded, the status output says "No model"."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    output = []

    def record_print(*args, **kwargs):
        output.append(str(args))

    monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: [])
    monkeypatch.setattr(wiz.console, "print", record_print)
    wiz._show_local_model_status(get_provider("lmstudio"))
    assert any("No model" in line for line in output)
def test_show_local_model_status_multiple(monkeypatch):
    """Several loaded models are reported, either by count or by name."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    # Distinctive names: single letters like "a"/"b" occur in almost any
    # printed text, which made the original name check nearly vacuous.
    models = ["alpha-model", "beta-model", "gamma-model"]
    monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: models)
    printed = []
    monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
    wiz._show_local_model_status(get_provider("lmstudio"))
    combined = " ".join(printed)
    assert "3" in combined or all(name in combined for name in models)
# ── _test_connection model re-entry ───────────────────────────────────────────
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
    """After a NotFound failure, "change_model" retries with the new model name."""
    import litellm
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    completion_calls = []

    class FakeNotFound(Exception):
        pass

    # Masquerade as litellm.exceptions.NotFoundError. The wizard's error
    # classification presumably inspects the module and class name (compare
    # _fake_llm_exc above).
    FakeNotFound.__module__ = "litellm.exceptions"
    FakeNotFound.__name__ = "NotFoundError"

    def fake_completion(**kwargs):
        completion_calls.append(kwargs)
        if len(completion_calls) == 1:
            raise FakeNotFound("model not found")

    def choose_change_model(*args, **kwargs):
        return MagicMock(ask=lambda: "change_model")

    def type_new_model(*args, **kwargs):
        return MagicMock(ask=lambda: "new-model")

    def fake_get_key(k):
        return "sk-test"

    monkeypatch.setattr(litellm, "completion", fake_completion)
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", choose_change_model)
    monkeypatch.setattr(wiz.questionary, "text", type_new_model)
    monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
    monkeypatch.setattr("pyra.vault.reader.get_key", fake_get_key)

    chosen = wiz._test_connection(get_provider("anthropic"), "old-model")
    assert chosen == "new-model"
    assert len(completion_calls) == 2