test: add tests for draft persistence, model status, and model re-entry
10 new tests covering: - _save_draft / _load_draft / _delete_draft / _mark_step_done helpers - draft file permissions (chmod 600) - _show_local_model_status with zero, one, and multiple loaded models - _test_connection change_model path (model error → change → retry succeeds) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -329,3 +329,116 @@ def test_check_local_server_continue_returns(monkeypatch):
|
||||
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
|
||||
|
||||
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
|
||||
|
||||
|
||||
# ── draft persistence ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_save_and_load_draft(tmp_pyra_home):
|
||||
import pyra.setup.wizard as wiz
|
||||
state = {"completed_steps": ["profile"], "user_name": "Alice"}
|
||||
wiz._save_draft(state)
|
||||
loaded = wiz._load_draft()
|
||||
assert loaded == state
|
||||
|
||||
|
||||
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
|
||||
import pyra.setup.wizard as wiz
|
||||
assert wiz._load_draft() is None
|
||||
|
||||
|
||||
def test_delete_draft_removes_file(tmp_pyra_home):
|
||||
import pyra.setup.wizard as wiz
|
||||
wiz._save_draft({"completed_steps": []})
|
||||
wiz._delete_draft()
|
||||
assert wiz._load_draft() is None
|
||||
|
||||
|
||||
def test_delete_draft_is_idempotent(tmp_pyra_home):
|
||||
import pyra.setup.wizard as wiz
|
||||
wiz._delete_draft()
|
||||
wiz._delete_draft()
|
||||
|
||||
|
||||
def test_mark_step_done_appends_once(tmp_pyra_home):
|
||||
import pyra.setup.wizard as wiz
|
||||
state = {}
|
||||
wiz._mark_step_done(state, "profile")
|
||||
wiz._mark_step_done(state, "profile")
|
||||
assert state["completed_steps"].count("profile") == 1
|
||||
|
||||
|
||||
def test_draft_file_has_correct_permissions(tmp_pyra_home):
|
||||
import os
|
||||
import pyra.setup.wizard as wiz
|
||||
wiz._save_draft({"completed_steps": []})
|
||||
path = wiz._draft_path()
|
||||
if os.name != "nt":
|
||||
mode = oct(path.stat().st_mode)[-3:]
|
||||
assert mode == "600"
|
||||
|
||||
|
||||
# ── _show_local_model_status ──────────────────────────────────────────────────
|
||||
|
||||
def test_show_local_model_status_one_model(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["gemma-4b"])
|
||||
printed = []
|
||||
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||
assert any("gemma-4b" in s for s in printed)
|
||||
|
||||
|
||||
def test_show_local_model_status_none(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
|
||||
printed = []
|
||||
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||
assert any("No model" in s for s in printed)
|
||||
|
||||
|
||||
def test_show_local_model_status_multiple(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["a", "b", "c"])
|
||||
printed = []
|
||||
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||
combined = " ".join(printed)
|
||||
assert "3" in combined or ("a" in combined and "b" in combined)
|
||||
|
||||
|
||||
# ── _test_connection model re-entry ───────────────────────────────────────────
|
||||
|
||||
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
class FakeNotFound(Exception):
|
||||
pass
|
||||
FakeNotFound.__module__ = "litellm.exceptions"
|
||||
FakeNotFound.__name__ = "NotFoundError"
|
||||
|
||||
def fake_completion(**kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise FakeNotFound("model not found")
|
||||
|
||||
import litellm
|
||||
monkeypatch.setattr(litellm, "completion", fake_completion)
|
||||
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||
monkeypatch.setattr(wiz.questionary, "select",
|
||||
lambda *a, **kw: MagicMock(ask=lambda: "change_model"))
|
||||
monkeypatch.setattr(wiz.questionary, "text",
|
||||
lambda *a, **kw: MagicMock(ask=lambda: "new-model"))
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
|
||||
monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")
|
||||
|
||||
provider = get_provider("anthropic")
|
||||
result = wiz._test_connection(provider, "old-model")
|
||||
assert result == "new-model"
|
||||
assert call_count["n"] == 2
|
||||
|
||||
Reference in New Issue
Block a user