test: add tests for draft persistence, model status, and model re-entry

10 new tests covering:
- _save_draft / _load_draft / _delete_draft / _mark_step_done helpers
- draft file permissions (chmod 600)
- _show_local_model_status with zero, one, and multiple loaded models
- _test_connection change_model path (model error → change → retry succeeds)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
curo1305
2026-05-19 13:43:53 +02:00
parent 019e8044a9
commit b3851a2715
+113
View File
@@ -329,3 +329,116 @@ def test_check_local_server_continue_returns(monkeypatch):
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
# ── draft persistence ─────────────────────────────────────────────────────────
def test_save_and_load_draft(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {"completed_steps": ["profile"], "user_name": "Alice"}
wiz._save_draft(state)
loaded = wiz._load_draft()
assert loaded == state
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
assert wiz._load_draft() is None
def test_delete_draft_removes_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
wiz._delete_draft()
assert wiz._load_draft() is None
def test_delete_draft_is_idempotent(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._delete_draft()
wiz._delete_draft()
def test_mark_step_done_appends_once(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {}
wiz._mark_step_done(state, "profile")
wiz._mark_step_done(state, "profile")
assert state["completed_steps"].count("profile") == 1
def test_draft_file_has_correct_permissions(tmp_pyra_home):
import os
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
path = wiz._draft_path()
if os.name != "nt":
mode = oct(path.stat().st_mode)[-3:]
assert mode == "600"
# ── _show_local_model_status ──────────────────────────────────────────────────
def test_show_local_model_status_one_model(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["gemma-4b"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("gemma-4b" in s for s in printed)
def test_show_local_model_status_none(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("No model" in s for s in printed)
def test_show_local_model_status_multiple(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["a", "b", "c"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
combined = " ".join(printed)
assert "3" in combined or ("a" in combined and "b" in combined)
# ── _test_connection model re-entry ───────────────────────────────────────────
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
call_count = {"n": 0}
class FakeNotFound(Exception):
pass
FakeNotFound.__module__ = "litellm.exceptions"
FakeNotFound.__name__ = "NotFoundError"
def fake_completion(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
raise FakeNotFound("model not found")
import litellm
monkeypatch.setattr(litellm, "completion", fake_completion)
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "change_model"))
monkeypatch.setattr(wiz.questionary, "text",
lambda *a, **kw: MagicMock(ask=lambda: "new-model"))
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")
provider = get_provider("anthropic")
result = wiz._test_connection(provider, "old-model")
assert result == "new-model"
assert call_count["n"] == 2