From 5eb81404c205afb2e81e88d1bff60fe1f03c833b Mon Sep 17 00:00:00 2001 From: curo1305 Date: Tue, 19 May 2026 10:53:15 +0200 Subject: [PATCH] feat(setup): dynamic model discovery for local providers in wizard Replace the static model text prompt with live API queries: - _fetch_local_models(): queries /v1/models (LM Studio, llama.cpp) or /api/tags (Ollama) and returns a questionary.select list - _fetch_lmstudio_available_models(): queries LM Studio's beta /api/v0/models to list downloaded-but-not-loaded models - _load_lmstudio_model(): tries /api/v0/models/load to load a model in-place; falls back to telling the user to load manually - Cloud providers keep the existing text-input behaviour Also replace hardcoded LMSTUDIO_MODEL in integration tests with a lmstudio_model fixture that queries the API at runtime and uses whichever model is currently loaded (skips if none). Co-Authored-By: Claude Sonnet 4.6 --- src/pyra/setup/wizard.py | 92 ++++++++++++++++++++++++++++-- tests/integration/test_lmstudio.py | 61 +++++++++++--------- 2 files changed, 123 insertions(+), 30 deletions(-) diff --git a/src/pyra/setup/wizard.py b/src/pyra/setup/wizard.py index d322c12..4001d66 100644 --- a/src/pyra/setup/wizard.py +++ b/src/pyra/setup/wizard.py @@ -153,11 +153,95 @@ def _check_local_server(provider: Provider) -> None: ) +def _fetch_local_models(provider: Provider) -> list[str]: + """Return currently loaded/available models from a local provider's API.""" + if not provider.base_url: + return [] + try: + if provider.id == "ollama": + resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0) + resp.raise_for_status() + return [m["name"] for m in resp.json().get("models", [])] + else: + resp = httpx.get(f"{provider.base_url}/models", timeout=3.0) + resp.raise_for_status() + return [m["id"] for m in resp.json().get("data", [])] + except Exception: + return [] + + +def _fetch_lmstudio_available_models() -> list[str]: + """Return all downloaded (not necessarily loaded) models from LM Studio's beta API.""" + try: + resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0) + resp.raise_for_status() + return [m["id"] for m in resp.json().get("data", [])] + except Exception: + return [] + + +def _load_lmstudio_model(model_id: str) -> bool: + """Attempt to load a model via LM Studio's beta API. Returns True on success.""" + try: + resp = httpx.post( + "http://localhost:1234/api/v0/models/load", + json={"identifier": model_id}, + timeout=60.0, + ) + return resp.is_success + except Exception: + return False + + def _choose_model(provider: Provider) -> str: - model = questionary.text( - "Model name:", - default=provider.default_model, - ).ask() + if provider.group != "Local": + model = questionary.text("Model name:", default=provider.default_model).ask() + if model is None: + raise SystemExit(0) + return model.strip() + + _MANUAL = "__manual__" + loaded = _fetch_local_models(provider) + + if loaded: + choices = loaded + [questionary.Choice("── Enter manually ──", value=_MANUAL)] + selected = questionary.select("Select model:", choices=choices).ask() + if selected is None: + raise SystemExit(0) + if selected != _MANUAL: + return selected + + elif provider.id == "lmstudio": + console.print(" [yellow]No model currently loaded in LM Studio.[/yellow]") + available = _fetch_lmstudio_available_models() + if available: + choices = available + [questionary.Choice("── Enter manually ──", value=_MANUAL)] + selected = questionary.select( + "Select a downloaded model to load:", choices=choices + ).ask() + if selected is None: + raise SystemExit(0) + if selected != _MANUAL: + console.print(f" Loading [bold]{selected}[/bold]...", end=" ") + if _load_lmstudio_model(selected): + console.print("[green]✓ Loaded[/green]") + else: + console.print( + "[yellow]Could not load via API — " + "please load the model manually in LM Studio.[/yellow]" + ) + return selected + else: + console.print(Panel( + "No models are loaded or downloaded in LM Studio.\n" + "Open LM Studio → Local Server tab → load a model, then re-run setup.", + border_style="yellow", + )) + + else: + console.print(f" [yellow]No models found at {provider.base_url}.[/yellow]") + + model = questionary.text("Model name:", default=provider.default_model).ask() if model is None: raise SystemExit(0) return model.strip() diff --git a/tests/integration/test_lmstudio.py b/tests/integration/test_lmstudio.py index 8be1f90..30cbb58 100644 --- a/tests/integration/test_lmstudio.py +++ b/tests/integration/test_lmstudio.py @@ -1,31 +1,41 @@ """ Live integration test against LM Studio at localhost:1234. -Skipped automatically if LM Studio is not running. +Skipped automatically if LM Studio is not running or no model is loaded. """ +import httpx import pytest -LMSTUDIO_MODEL = "gemma-4-e4b-uncensored-hauhaucs-aggressive" -LMSTUDIO_BASE_URL = "http://localhost:1234/v1" +_LMSTUDIO_BASE_URL = "http://localhost:1234/v1" -@pytest.fixture(autouse=True) -def require_lmstudio(): - import httpx +def _get_loaded_model() -> str | None: + """Return the first currently loaded model ID from LM Studio, or None.""" try: - r = httpx.get(f"{LMSTUDIO_BASE_URL}/models", timeout=2.0) - r.raise_for_status() + resp = httpx.get(f"{_LMSTUDIO_BASE_URL}/models", timeout=2.0) + resp.raise_for_status() + models = resp.json().get("data", []) + return models[0]["id"] if models else None except Exception: - pytest.skip("LM Studio not reachable at localhost:1234") + return None -def test_basic_completion(): +@pytest.fixture() +def lmstudio_model() -> str: + """Resolve the first loaded model in LM Studio; skip if none available.""" + model = _get_loaded_model() + if model is None: + pytest.skip("LM Studio not reachable or no model currently loaded") + return model + + +def test_basic_completion(lmstudio_model): import litellm litellm.suppress_debug_info = True response = litellm.completion( - model=f"openai/{LMSTUDIO_MODEL}", + model=f"openai/{lmstudio_model}", messages=[{"role": "user", "content": "Reply with exactly the word: PONG"}], - api_base=LMSTUDIO_BASE_URL, + api_base=_LMSTUDIO_BASE_URL, api_key="lm-studio", max_tokens=20, stream=False, @@ -34,14 +44,14 @@ def test_basic_completion(): assert text and len(text) > 0 -def test_streaming_completion(): +def test_streaming_completion(lmstudio_model): import litellm litellm.suppress_debug_info = True stream = litellm.completion( - model=f"openai/{LMSTUDIO_MODEL}", + model=f"openai/{lmstudio_model}", messages=[{"role": "user", "content": "Count from 1 to 3."}], - api_base=LMSTUDIO_BASE_URL, + api_base=_LMSTUDIO_BASE_URL, api_key="lm-studio", max_tokens=50, stream=True, @@ -52,30 +62,29 @@ def test_streaming_completion(): assert len(full_text) > 0 -def test_injection_scan_on_live_response(tmp_pyra_home): +def test_injection_scan_on_live_response(tmp_pyra_home, lmstudio_model): """Verify injection scanner runs on real model output without false positives.""" import litellm from pyra.security.injection import scan_response litellm.suppress_debug_info = True response = litellm.completion( - model=f"openai/{LMSTUDIO_MODEL}", + model=f"openai/{lmstudio_model}", messages=[{"role": "user", "content": "Explain what a list is in Python."}], - api_base=LMSTUDIO_BASE_URL, + api_base=_LMSTUDIO_BASE_URL, api_key="lm-studio", max_tokens=200, stream=False, ) text = response.choices[0].message.content warnings = scan_response(text) - # Normal responses about Python lists should not trigger injection warnings - for w in warnings: - print(f"[warning] {w.pattern_label}: {w.matched_text!r}") # Not asserting zero warnings — some models may have quirky phrasing — # but at least the scanner must not crash on real output + for w in warnings: + print(f"[warning] {w.pattern_label}: {w.matched_text!r}") -def test_pyra_chat_session_with_lmstudio(tmp_pyra_home): +def test_pyra_chat_session_with_lmstudio(tmp_pyra_home, lmstudio_model): """Full stack: config → vault → history → litellm → injection scan.""" from pyra.config.schema import PyraConfig, ProviderConfig from pyra.config.manager import save_config @@ -87,8 +96,8 @@ def test_pyra_chat_session_with_lmstudio(tmp_pyra_home): cfg = PyraConfig( ai=ProviderConfig( provider_id="lmstudio", - model=LMSTUDIO_MODEL, - base_url=LMSTUDIO_BASE_URL, + model=lmstudio_model, + base_url=_LMSTUDIO_BASE_URL, ) ) save_config(cfg) @@ -98,9 +107,9 @@ def test_pyra_chat_session_with_lmstudio(tmp_pyra_home): messages = history.build_for_api() response = litellm.completion( - model=f"openai/{LMSTUDIO_MODEL}", + model=f"openai/{lmstudio_model}", messages=messages, - api_base=LMSTUDIO_BASE_URL, + api_base=_LMSTUDIO_BASE_URL, api_key="lm-studio", max_tokens=30, stream=False,