feat(setup,chat): detect actually-loaded local model via provider-specific API
For Ollama, /api/tags returns all installed models, not running ones. Add fetch_loaded_models() using /api/ps for Ollama (and /v1/models for LM Studio/llama.cpp, which already return only loaded models). _show_local_model_status() now calls fetch_loaded_models() so the setup wizard correctly shows only in-memory models for Ollama. At chat session startup, local providers warn when the configured model is not currently loaded, or when nothing is loaded at all. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from pyra.plugins.executor import ToolExecutor
|
||||
from pyra.plugins.registry import PluginRegistry
|
||||
from pyra.security.injection import scan_response
|
||||
from pyra.setup.providers import get_provider
|
||||
from pyra.setup.wizard import fetch_loaded_models
|
||||
from pyra.utils.paths import pyra_home
|
||||
|
||||
_HISTORY_FILE = pyra_home() / ".chat_history"
|
||||
@@ -176,6 +177,16 @@ def start_chat() -> None:
|
||||
"[dim]Type /help for commands, /quit to exit.[/dim]"
|
||||
)
|
||||
|
||||
if provider.group == "Local":
|
||||
loaded = fetch_loaded_models(provider)
|
||||
if not loaded:
|
||||
render_info(f"No model currently loaded in {provider.display_name}.")
|
||||
elif cfg.ai.model not in loaded:
|
||||
render_info(
|
||||
f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
|
||||
f"Loaded: {', '.join(loaded)}"
|
||||
)
|
||||
|
||||
_flags: dict = {"use_tools": True}
|
||||
|
||||
while True:
|
||||
|
||||
@@ -377,9 +377,26 @@ def _check_local_server(provider: Provider) -> None:
|
||||
# "retry" → loop
|
||||
|
||||
|
||||
def fetch_loaded_models(provider: Provider) -> list[str]:
|
||||
"""Return models currently loaded in RAM from a local provider's API."""
|
||||
if not provider.base_url:
|
||||
return []
|
||||
try:
|
||||
if provider.id == "ollama":
|
||||
resp = httpx.get(f"{provider.base_url}/api/ps", timeout=3.0)
|
||||
resp.raise_for_status()
|
||||
return [m["name"] for m in resp.json().get("models", [])]
|
||||
else:
|
||||
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
|
||||
resp.raise_for_status()
|
||||
return [m["id"] for m in resp.json().get("data", [])]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _show_local_model_status(provider: Provider) -> None:
|
||||
"""Print a one-line status showing which models are currently loaded."""
|
||||
models = _fetch_local_models(provider)
|
||||
models = fetch_loaded_models(provider)
|
||||
if not models:
|
||||
console.print(" [yellow]No model currently loaded[/yellow]")
|
||||
elif len(models) == 1:
|
||||
|
||||
@@ -377,12 +377,57 @@ def test_draft_file_has_correct_permissions(tmp_pyra_home):
|
||||
assert mode == "600"
|
||||
|
||||
|
||||
# ── fetch_loaded_models ───────────────────────────────────────────────────────
|
||||
|
||||
def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {"models": [{"name": "llama3:latest"}, {"name": "mistral"}]}
|
||||
mock_resp.raise_for_status = lambda: None
|
||||
calls = []
|
||||
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
|
||||
result = wiz.fetch_loaded_models(get_provider("ollama"))
|
||||
assert result == ["llama3:latest", "mistral"]
|
||||
assert any("/api/ps" in u for u in calls)
|
||||
|
||||
|
||||
def test_fetch_loaded_models_lmstudio_uses_models_endpoint(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {"data": [{"id": "gemma-4b"}]}
|
||||
mock_resp.raise_for_status = lambda: None
|
||||
calls = []
|
||||
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
|
||||
result = wiz.fetch_loaded_models(get_provider("lmstudio"))
|
||||
assert result == ["gemma-4b"]
|
||||
assert any("/models" in u for u in calls)
|
||||
|
||||
|
||||
def test_fetch_loaded_models_returns_empty_on_error(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused")))
|
||||
assert wiz.fetch_loaded_models(get_provider("ollama")) == []
|
||||
|
||||
|
||||
def test_fetch_loaded_models_returns_empty_when_no_base_url():
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import Provider
|
||||
provider = Provider(
|
||||
id="test", display_name="Test", requires_key=False,
|
||||
default_model="x", litellm_prefix="openai/", group="Local",
|
||||
)
|
||||
assert wiz.fetch_loaded_models(provider) == []
|
||||
|
||||
|
||||
# ── _show_local_model_status ──────────────────────────────────────────────────
|
||||
|
||||
def test_show_local_model_status_one_model(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["gemma-4b"])
|
||||
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"])
|
||||
printed = []
|
||||
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||
@@ -392,7 +437,7 @@ def test_show_local_model_status_one_model(monkeypatch):
|
||||
def test_show_local_model_status_none(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
|
||||
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: [])
|
||||
printed = []
|
||||
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||
@@ -402,7 +447,7 @@ def test_show_local_model_status_none(monkeypatch):
|
||||
def test_show_local_model_status_multiple(monkeypatch):
|
||||
import pyra.setup.wizard as wiz
|
||||
from pyra.setup.providers import get_provider
|
||||
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: ["a", "b", "c"])
|
||||
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["a", "b", "c"])
|
||||
printed = []
|
||||
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||
|
||||
Reference in New Issue
Block a user