"""Tests for setup wizard personalization and model-discovery helpers.

Each test imports ``pyra.setup.wizard`` inside its own body. Where a test
touches HTTP (``wiz.httpx``), console output (``wiz.console``) or prompts
(``wiz.questionary``), those are replaced via ``monkeypatch`` so no real local
model server, terminal or LLM provider is contacted.
"""

import os
from unittest.mock import MagicMock

import pytest


# ── shared helpers ─────────────────────────────────────────────────────────────


def _silent_print(*a, **kw):
    """Drop-in replacement for ``console.print`` that discards all output."""


def _capture_prints(monkeypatch, wiz):
    """Patch ``wiz.console.print``; return a list receiving ``str(args)`` for each call."""
    printed = []
    monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
    return printed


def _capture_panels(monkeypatch, wiz):
    """Patch ``wiz.console.print``; return a list receiving each printed Panel's renderable."""
    from rich.panel import Panel

    panels = []

    def capture_print(*args, **kwargs):
        panels.extend(arg.renderable for arg in args if isinstance(arg, Panel))

    monkeypatch.setattr(wiz.console, "print", capture_print)
    return panels


def _ok_response(payload=None):
    """Build a fake HTTP response whose ``raise_for_status()`` is a no-op.

    When *payload* is given it becomes the return value of ``.json()``;
    otherwise ``.json()`` keeps MagicMock's default auto-generated behaviour.
    """
    resp = MagicMock()
    resp.raise_for_status = lambda: None
    if payload is not None:
        resp.json.return_value = payload
    return resp


def _patch_httpx_get(monkeypatch, wiz, response):
    """Make ``wiz.httpx.get`` always return *response*; return the list of requested URLs."""
    urls = []

    def fake_get(*args, **kwargs):
        # The URL is recorded whether it was passed positionally or as ``url=``.
        urls.append(args[0] if args else kwargs.get("url", ""))
        return response

    monkeypatch.setattr(wiz.httpx, "get", fake_get)
    return urls


def _answer(value):
    """Return a fake questionary factory whose prompt's ``.ask()`` returns *value*."""
    return lambda *a, **kw: MagicMock(ask=lambda: value)


def _provider_without_base_url():
    """Build a local ``Provider`` that is constructed without a ``base_url``."""
    from pyra.setup.providers import Provider

    return Provider(
        id="test",
        display_name="Test",
        requires_key=False,
        default_model="x",
        litellm_prefix="openai/",
        group="Local",
    )


def _fake_llm_exc(name: str, message: str = "test error") -> Exception:
    """Create a fake litellm exception with the given class name.

    The class's ``__module__`` is set to ``litellm.exceptions`` so it looks
    like a litellm error without importing litellm's real exception types.
    NOTE(review): this relies on the wizard classifying errors by class
    name/module rather than ``isinstance`` — confirm if that changes.
    """
    cls = type(name, (Exception,), {"__module__": "litellm.exceptions"})
    return cls(message)


# ── _USE_CASE_PLUGINS ──────────────────────────────────────────────────────────


def test_use_case_plugin_mapping_all_categories_have_entries():
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    assert all(len(v) > 0 for v in _USE_CASE_PLUGINS.values())


def test_use_case_plugin_mapping_has_expected_categories():
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    assert "Email" in _USE_CASE_PLUGINS
    assert "Development & servers" in _USE_CASE_PLUGINS
    assert "Research & web" in _USE_CASE_PLUGINS


def test_use_case_email_contains_email_plugin():
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    assert "email" in _USE_CASE_PLUGINS["Email"]


def test_use_case_dev_contains_ssh_and_docker():
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    plugins = _USE_CASE_PLUGINS["Development & servers"]
    assert "ssh_tool" in plugins
    assert "docker_tool" in plugins


def test_use_case_file_management_contains_cloud_stores():
    from pyra.setup.wizard import _USE_CASE_PLUGINS

    plugins = _USE_CASE_PLUGINS["File management"]
    assert "gdrive" in plugins
    assert "onedrive" in plugins
    assert "dropbox_tool" in plugins


# ── _suggest_plugins ───────────────────────────────────────────────────────────


def test_suggest_plugins_empty_use_cases_returns_early(monkeypatch):
    import pyra.setup.wizard as wiz

    printed = _capture_prints(monkeypatch, wiz)
    wiz._suggest_plugins([])
    assert printed == []


def test_suggest_plugins_unknown_use_case_returns_early(monkeypatch):
    import pyra.setup.wizard as wiz

    printed = _capture_prints(monkeypatch, wiz)
    wiz._suggest_plugins(["Not a real category"])
    assert printed == []


def test_suggest_plugins_valid_use_case_calls_print(monkeypatch):
    import pyra.setup.wizard as wiz

    printed = _capture_prints(monkeypatch, wiz)
    wiz._suggest_plugins(["Email"])
    assert len(printed) > 0


def test_suggest_plugins_panel_text_contains_plugin_name(monkeypatch):
    import pyra.setup.wizard as wiz

    panels = _capture_panels(monkeypatch, wiz)
    wiz._suggest_plugins(["Email"])
    assert any("email" in str(p) for p in panels)


def test_suggest_plugins_multiple_categories(monkeypatch):
    import pyra.setup.wizard as wiz

    panels = _capture_panels(monkeypatch, wiz)
    wiz._suggest_plugins(["Email", "Development & servers"])
    combined = " ".join(str(p) for p in panels)
    assert "email" in combined
    assert "ssh_tool" in combined


# ── _fetch_local_models ────────────────────────────────────────────────────────


def test_fetch_local_models_lmstudio_returns_loaded_model_ids(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    resp = _ok_response(
        {
            "data": [
                {"id": "gemma-4b", "state": "loaded"},
                {"id": "llama3", "state": "not_loaded"},
            ]
        }
    )
    _patch_httpx_get(monkeypatch, wiz, resp)
    assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b"]


def test_fetch_local_models_lmstudio_filters_unloaded(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    resp = _ok_response(
        {
            "data": [
                {"id": "model-a", "state": "not_loaded"},
                {"id": "model-b", "state": "not_loaded"},
            ]
        }
    )
    _patch_httpx_get(monkeypatch, wiz, resp)
    assert wiz._fetch_local_models(get_provider("lmstudio")) == []


def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    resp = _ok_response({"models": [{"name": "llama3:latest"}, {"name": "mistral"}]})
    _patch_httpx_get(monkeypatch, wiz, resp)
    assert wiz._fetch_local_models(get_provider("ollama")) == ["llama3:latest", "mistral"]


def test_fetch_local_models_returns_empty_on_connection_error(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused")))
    assert wiz._fetch_local_models(get_provider("lmstudio")) == []


def test_fetch_local_models_returns_empty_when_no_base_url():
    import pyra.setup.wizard as wiz

    assert wiz._fetch_local_models(_provider_without_base_url()) == []


# ── _fetch_lmstudio_available_models ───────────────────────────────────────────


def test_fetch_lmstudio_available_models_returns_ids(monkeypatch):
    import pyra.setup.wizard as wiz

    _patch_httpx_get(monkeypatch, wiz, _ok_response({"data": [{"id": "model-a"}, {"id": "model-b"}]}))
    assert wiz._fetch_lmstudio_available_models() == ["model-a", "model-b"]


def test_fetch_lmstudio_available_models_returns_empty_on_error(monkeypatch):
    import pyra.setup.wizard as wiz

    monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("not found")))
    assert wiz._fetch_lmstudio_available_models() == []


def test_fetch_lmstudio_available_models_empty_data(monkeypatch):
    import pyra.setup.wizard as wiz

    _patch_httpx_get(monkeypatch, wiz, _ok_response({"data": []}))
    assert wiz._fetch_lmstudio_available_models() == []


# ── _load_lmstudio_model ───────────────────────────────────────────────────────


def test_load_lmstudio_model_returns_true_on_success(monkeypatch):
    import pyra.setup.wizard as wiz

    monkeypatch.setattr(wiz.httpx, "post", lambda *a, **kw: MagicMock(is_success=True))
    assert wiz._load_lmstudio_model("gemma-4b") is True


def test_load_lmstudio_model_returns_false_on_api_failure(monkeypatch):
    import pyra.setup.wizard as wiz

    monkeypatch.setattr(wiz.httpx, "post", lambda *a, **kw: MagicMock(is_success=False))
    assert wiz._load_lmstudio_model("gemma-4b") is False


def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
    import pyra.setup.wizard as wiz

    monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout")))
    assert wiz._load_lmstudio_model("gemma-4b") is False


# ── _classify_error ────────────────────────────────────────────────────────────


def test_classify_auth_error():
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("AuthenticationError"))
    assert "key" in label.lower()
    assert "provider" in hint.lower() or "dashboard" in hint.lower()


def test_classify_not_found_error():
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("NotFoundError"))
    assert "model" in label.lower()
    assert "model" in hint.lower()


def test_classify_rate_limit_error():
    from pyra.setup.wizard import _classify_error

    label, _ = _classify_error(_fake_llm_exc("RateLimitError"))
    assert "rate" in label.lower()


def test_classify_service_unavailable():
    from pyra.setup.wizard import _classify_error

    label, _ = _classify_error(_fake_llm_exc("ServiceUnavailableError"))
    assert "unavailable" in label.lower()


def test_classify_api_connection_error():
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("APIConnectionError"))
    assert "reach" in label.lower() or "network" in label.lower() or "connect" in label.lower()
    assert "internet" in hint.lower() or "network" in hint.lower()


def test_classify_timeout_error():
    from pyra.setup.wizard import _classify_error

    label, _ = _classify_error(_fake_llm_exc("TimeoutError"))
    assert "time" in label.lower()


def test_classify_bad_request_error():
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(_fake_llm_exc("BadRequestError"))
    assert "request" in label.lower() or "bad" in label.lower()
    assert "model" in hint.lower()


def test_classify_httpx_connect_error():
    import httpx

    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(httpx.ConnectError("refused"))
    # "reach" also covers "reachable"/"unreachable".
    assert "reach" in label.lower()
    assert "running" in hint.lower() or "listening" in hint.lower()


def test_classify_httpx_timeout():
    import httpx

    from pyra.setup.wizard import _classify_error

    label, _ = _classify_error(httpx.TimeoutException("timeout"))
    assert "time" in label.lower()


def test_classify_generic_error():
    from pyra.setup.wizard import _classify_error

    label, hint = _classify_error(ValueError("something went wrong"))
    assert len(label) > 0
    assert "something went wrong" in hint


def test_classify_error_always_returns_two_strings():
    from pyra.setup.wizard import _classify_error

    for exc in (RuntimeError("boom"), OSError("disk"), KeyError("k")):
        label, hint = _classify_error(exc)
        assert isinstance(label, str) and len(label) > 0
        assert isinstance(hint, str) and len(hint) > 0


# ── _check_local_server retry behaviour ───────────────────────────────────────


def test_check_local_server_success(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    _patch_httpx_get(monkeypatch, wiz, _ok_response())
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    wiz._check_local_server(get_provider("lmstudio"))  # must not raise


def test_check_local_server_retry_then_success(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    attempts = {"n": 0}

    def flaky_get(*a, **kw):
        # First probe fails; every later probe succeeds.
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise ConnectionError("refused")
        return _ok_response()

    monkeypatch.setattr(wiz.httpx, "get", flaky_get)
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", _answer("retry"))
    wiz._check_local_server(get_provider("lmstudio"))
    assert attempts["n"] == 2


def test_check_local_server_abort_raises_system_exit(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", _answer("abort"))
    with pytest.raises(SystemExit):
        wiz._check_local_server(get_provider("lmstudio"))


def test_check_local_server_continue_returns(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", _answer("continue"))
    wiz._check_local_server(get_provider("lmstudio"))  # must return without raising


# ── draft persistence ──────────────────────────────────────────────────────────


def test_save_and_load_draft(tmp_pyra_home):
    import pyra.setup.wizard as wiz

    state = {"completed_steps": ["profile"], "user_name": "Alice"}
    wiz._save_draft(state)
    assert wiz._load_draft() == state


def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
    import pyra.setup.wizard as wiz

    assert wiz._load_draft() is None


def test_delete_draft_removes_file(tmp_pyra_home):
    import pyra.setup.wizard as wiz

    wiz._save_draft({"completed_steps": []})
    wiz._delete_draft()
    assert wiz._load_draft() is None


def test_delete_draft_is_idempotent(tmp_pyra_home):
    import pyra.setup.wizard as wiz

    # Deleting a missing draft twice must not raise.
    wiz._delete_draft()
    wiz._delete_draft()


def test_mark_step_done_appends_once(tmp_pyra_home):
    import pyra.setup.wizard as wiz

    state = {}
    wiz._mark_step_done(state, "profile")
    wiz._mark_step_done(state, "profile")
    assert state["completed_steps"].count("profile") == 1


@pytest.mark.skipif(os.name == "nt", reason="POSIX permission bits are not meaningful on Windows")
def test_draft_file_has_correct_permissions(tmp_pyra_home):
    import pyra.setup.wizard as wiz

    wiz._save_draft({"completed_steps": []})
    # Keep only the rwx bits for user/group/other.
    mode = wiz._draft_path().stat().st_mode & 0o777
    assert mode == 0o600, oct(mode)


# ── fetch_loaded_models ────────────────────────────────────────────────────────


def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    resp = _ok_response({"models": [{"name": "llama3:latest"}, {"name": "mistral"}]})
    urls = _patch_httpx_get(monkeypatch, wiz, resp)
    result = wiz.fetch_loaded_models(get_provider("ollama"))
    assert result == ["llama3:latest", "mistral"]
    assert any("/api/ps" in url for url in urls)


def test_fetch_loaded_models_lmstudio_uses_beta_api_and_filters(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    resp = _ok_response(
        {
            "data": [
                {"id": "gemma-4b", "state": "loaded"},
                {"id": "llama3", "state": "not_loaded"},
            ]
        }
    )
    urls = _patch_httpx_get(monkeypatch, wiz, resp)
    result = wiz.fetch_loaded_models(get_provider("lmstudio"))
    assert result == ["gemma-4b"]
    assert any("/api/v0/models" in url for url in urls)


def test_fetch_loaded_models_lmstudio_filters_unloaded(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    resp = _ok_response(
        {
            "data": [
                {"id": "model-a", "state": "not_loaded"},
                {"id": "model-b", "state": "not_loaded"},
            ]
        }
    )
    _patch_httpx_get(monkeypatch, wiz, resp)
    assert wiz.fetch_loaded_models(get_provider("lmstudio")) == []


def test_fetch_loaded_models_returns_empty_on_error(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused")))
    assert wiz.fetch_loaded_models(get_provider("ollama")) == []


def test_fetch_loaded_models_returns_empty_when_no_base_url():
    import pyra.setup.wizard as wiz

    assert wiz.fetch_loaded_models(_provider_without_base_url()) == []


# ── _show_local_model_status ───────────────────────────────────────────────────


def test_show_local_model_status_one_model(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"])
    printed = _capture_prints(monkeypatch, wiz)
    wiz._show_local_model_status(get_provider("lmstudio"))
    assert any("gemma-4b" in s for s in printed)


def test_show_local_model_status_none(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: [])
    printed = _capture_prints(monkeypatch, wiz)
    wiz._show_local_model_status(get_provider("lmstudio"))
    assert any("No model" in s for s in printed)


def test_show_local_model_status_multiple(monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    # Distinctive names: single letters like "a"/"b" would match almost any
    # printed sentence and make the assertion below vacuous.
    names = ["model-alpha", "model-beta", "model-gamma"]
    monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: names)
    printed = _capture_prints(monkeypatch, wiz)
    wiz._show_local_model_status(get_provider("lmstudio"))
    combined = " ".join(printed)
    # Accept either a model count or the listed model names.
    assert "3" in combined or (names[0] in combined and names[1] in combined)


# ── _test_connection model re-entry ────────────────────────────────────────────


def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import get_provider

    import litellm

    calls = {"n": 0}

    def fake_completion(**kwargs):
        # First attempt fails as "model not found"; the retry succeeds.
        calls["n"] += 1
        if calls["n"] == 1:
            raise _fake_llm_exc("NotFoundError", "model not found")

    monkeypatch.setattr(litellm, "completion", fake_completion)
    monkeypatch.setattr(wiz.console, "print", _silent_print)
    monkeypatch.setattr(wiz.questionary, "select", _answer("change_model"))
    monkeypatch.setattr(wiz.questionary, "text", _answer("new-model"))
    monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
    monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")

    result = wiz._test_connection(get_provider("anthropic"), "old-model")
    assert result == "new-model"
    assert calls["n"] == 2