From 833d1445f03b7a54163591d06edb010b91637353 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Tue, 19 May 2026 14:46:54 +0200 Subject: [PATCH] feat(setup,chat): detect actually-loaded local model via provider-specific API For Ollama, /api/tags returns all installed models, not running ones. Add fetch_loaded_models() using /api/ps for Ollama (and /v1/models for LM Studio/llama.cpp, which already return only loaded models). _show_local_model_status() now calls fetch_loaded_models() so the setup wizard correctly shows only in-memory models for Ollama. At chat session startup, local providers warn when the configured model is not currently loaded, or when nothing is loaded at all. Co-Authored-By: Claude Sonnet 4.6 --- src/pyra/chat/session.py | 11 +++++++ src/pyra/setup/wizard.py | 19 +++++++++++- tests/unit/test_setup_wizard.py | 51 +++++++++++++++++++++++++++++++-- 3 files changed, 77 insertions(+), 4 deletions(-) diff --git a/src/pyra/chat/session.py b/src/pyra/chat/session.py index b217d5c..2ea7f5d 100644 --- a/src/pyra/chat/session.py +++ b/src/pyra/chat/session.py @@ -25,6 +25,7 @@ from pyra.plugins.executor import ToolExecutor from pyra.plugins.registry import PluginRegistry from pyra.security.injection import scan_response from pyra.setup.providers import get_provider +from pyra.setup.wizard import fetch_loaded_models from pyra.utils.paths import pyra_home _HISTORY_FILE = pyra_home() / ".chat_history" @@ -176,6 +177,16 @@ def start_chat() -> None: "[dim]Type /help for commands, /quit to exit.[/dim]" ) + if provider.group == "Local": + loaded = fetch_loaded_models(provider) + if not loaded: + render_info(f"No model currently loaded in {provider.display_name}.") + elif cfg.ai.model not in loaded: + render_info( + f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. " + f"Loaded: {', '.join(loaded)}" + ) + _flags: dict = {"use_tools": True} while True: diff --git a/src/pyra/setup/wizard.py b/src/pyra/setup/wizard.py index 5c2c4e1..da66a0b 100644 --- a/src/pyra/setup/wizard.py +++ b/src/pyra/setup/wizard.py @@ -377,9 +377,26 @@ def _check_local_server(provider: Provider) -> None: # "retry" → loop +def fetch_loaded_models(provider: Provider) -> list[str]: + """Return models currently loaded in RAM from a local provider's API.""" + if not provider.base_url: + return [] + try: + if provider.id == "ollama": + resp = httpx.get(f"{provider.base_url}/api/ps", timeout=3.0) + resp.raise_for_status() + return [m["name"] for m in resp.json().get("models", [])] + else: + resp = httpx.get(f"{provider.base_url}/models", timeout=3.0) + resp.raise_for_status() + return [m["id"] for m in resp.json().get("data", [])] + except Exception: + return [] + + def _show_local_model_status(provider: Provider) -> None: """Print a one-line status showing which models are currently loaded.""" - models = _fetch_local_models(provider) + models = fetch_loaded_models(provider) if not models: console.print(" [yellow]No model currently loaded[/yellow]") elif len(models) == 1: diff --git a/tests/unit/test_setup_wizard.py b/tests/unit/test_setup_wizard.py index 9eceb27..89fc25f 100644 --- a/tests/unit/test_setup_wizard.py +++ b/tests/unit/test_setup_wizard.py @@ -377,12 +377,57 @@ def test_draft_file_has_correct_permissions(tmp_pyra_home): assert mode == "600" +# ── fetch_loaded_models ─────────────────────────────────────────────────────── + +def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + mock_resp = MagicMock() + mock_resp.json.return_value = {"models": [{"name": "llama3:latest"}, {"name": "mistral"}]} + mock_resp.raise_for_status = lambda: None + calls = [] + monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1]) + result = wiz.fetch_loaded_models(get_provider("ollama")) + assert result == ["llama3:latest", "mistral"] + assert any("/api/ps" in u for u in calls) + + +def test_fetch_loaded_models_lmstudio_uses_models_endpoint(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + mock_resp = MagicMock() + mock_resp.json.return_value = {"data": [{"id": "gemma-4b"}]} + mock_resp.raise_for_status = lambda: None + calls = [] + monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1]) + result = wiz.fetch_loaded_models(get_provider("lmstudio")) + assert result == ["gemma-4b"] + assert any("/models" in u for u in calls) + + +def test_fetch_loaded_models_returns_empty_on_error(monkeypatch): + import pyra.setup.wizard as wiz + from pyra.setup.providers import get_provider + monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused"))) + assert wiz.fetch_loaded_models(get_provider("ollama")) == [] + + +def test_fetch_loaded_models_returns_empty_when_no_base_url(): + import pyra.setup.wizard as wiz + from pyra.setup.providers import Provider + provider = Provider( + id="test", display_name="Test", requires_key=False, + default_model="x", litellm_prefix="openai/", group="Local", + ) + assert wiz.fetch_loaded_models(provider) == [] + + # ── _show_local_model_status ────────────────────────────────────────────────── def test_show_local_model_status_one_model(monkeypatch): import pyra.setup.wizard as wiz from pyra.setup.providers import get_provider - monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["gemma-4b"]) + monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"]) printed = [] monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a))) wiz._show_local_model_status(get_provider("lmstudio")) @@ -392,7 +437,7 @@ def test_show_local_model_status_one_model(monkeypatch): def test_show_local_model_status_none(monkeypatch): import pyra.setup.wizard as wiz from pyra.setup.providers import get_provider - monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: []) + monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: []) printed = [] monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a))) wiz._show_local_model_status(get_provider("lmstudio")) @@ -402,7 +447,7 @@ def test_show_local_model_status_none(monkeypatch): def test_show_local_model_status_multiple(monkeypatch): import pyra.setup.wizard as wiz from pyra.setup.providers import get_provider - monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["a", "b", "c"]) + monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["a", "b", "c"]) printed = [] monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a))) wiz._show_local_model_status(get_provider("lmstudio"))