From b3851a2715bc09d84e0c9625d97cabde16beb457 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Tue, 19 May 2026 13:43:53 +0200 Subject: [PATCH] test: add tests for draft persistence, model status, and model re-entry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 10 new tests covering: - _save_draft / _load_draft / _delete_draft / _mark_step_done helpers - draft file permissions (chmod 600) - _show_local_model_status with zero, one, and multiple loaded models - _test_connection change_model path (model error → change → retry succeeds) Co-Authored-By: Claude Sonnet 4.6 --- tests/unit/test_setup_wizard.py | 113 ++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/unit/test_setup_wizard.py b/tests/unit/test_setup_wizard.py index 397d962..9eceb27 100644 --- a/tests/unit/test_setup_wizard.py +++ b/tests/unit/test_setup_wizard.py @@ -329,3 +329,116 @@ def test_check_local_server_continue_returns(monkeypatch): lambda *a, **kw: MagicMock(ask=lambda: "continue")) wiz._check_local_server(get_provider("lmstudio")) # must return without raising + + +# ── draft persistence ───────────────────────────────────────────────────────── + +def test_save_and_load_draft(tmp_pyra_home): + import pyra.setup.wizard as wiz + state = {"completed_steps": ["profile"], "user_name": "Alice"} + wiz._save_draft(state) + loaded = wiz._load_draft() + assert loaded == state + + +def test_load_draft_returns_none_when_no_file(tmp_pyra_home): + import pyra.setup.wizard as wiz + assert wiz._load_draft() is None + + +def test_delete_draft_removes_file(tmp_pyra_home): + import pyra.setup.wizard as wiz + wiz._save_draft({"completed_steps": []}) + wiz._delete_draft() + assert wiz._load_draft() is None + + +def test_delete_draft_is_idempotent(tmp_pyra_home): + import pyra.setup.wizard as wiz + wiz._delete_draft() + wiz._delete_draft() + + +def test_mark_step_done_appends_once(tmp_pyra_home): + import pyra.setup.wizard as wiz + state = {} + wiz._mark_step_done(state, "profile") + wiz._mark_step_done(state, "profile") + assert state["completed_steps"].count("profile") == 1 + + +def test_draft_file_has_correct_permissions(tmp_pyra_home): + import os + import pyra.setup.wizard as wiz + wiz._save_draft({"completed_steps": []}) + path = wiz._draft_path() + if os.name != "nt": + mode = oct(path.stat().st_mode)[-3:] + assert mode == "600" + + +# ── _show_local_model_status ────────────────────────────────────────────────── + +def test_show_local_model_status_one_model(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["gemma-4b"]) + printed = [] + monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a))) + wiz._show_local_model_status(get_provider("lmstudio")) + assert any("gemma-4b" in s for s in printed) + + +def test_show_local_model_status_none(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: []) + printed = [] + monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a))) + wiz._show_local_model_status(get_provider("lmstudio")) + assert any("No model" in s for s in printed) + + +def test_show_local_model_status_multiple(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["a", "b", "c"]) + printed = [] + monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a))) + wiz._show_local_model_status(get_provider("lmstudio")) + combined = " ".join(printed) + assert "3" in combined or ("a" in combined and "b" in combined) + + +# ── _test_connection model re-entry ─────────────────────────────────────────── + +def test_test_connection_change_model(tmp_pyra_home, monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + + call_count = {"n": 0} + + class FakeNotFound(Exception): + pass + FakeNotFound.__module__ = "litellm.exceptions" + FakeNotFound.__name__ = "NotFoundError" + + def fake_completion(**kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + raise FakeNotFound("model not found") + + import litellm + monkeypatch.setattr(litellm, "completion", fake_completion) + monkeypatch.setattr(wiz.console, "print", _silent_print) + monkeypatch.setattr(wiz.questionary, "select", + lambda *a, **kw: MagicMock(ask=lambda: "change_model")) + monkeypatch.setattr(wiz.questionary, "text", + lambda *a, **kw: MagicMock(ask=lambda: "new-model")) + monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: []) + monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test") + + provider = get_provider("anthropic") + result = wiz._test_connection(provider, "old-model") + assert result == "new-model" + assert call_count["n"] == 2