Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bafdafea02 | |||
| 5eb81404c2 |
@@ -153,11 +153,95 @@ def _check_local_server(provider: Provider) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_local_models(provider: Provider) -> list[str]:
    """List the model identifiers a local provider's HTTP API reports.

    For the provider whose id is ``"ollama"`` this queries
    ``{base_url}/api/tags`` and reads ``name`` from each ``models`` entry.
    For every other provider it queries ``{base_url}/models`` and reads
    ``id`` from each ``data`` entry.

    Returns an empty list when the provider has no ``base_url`` or when
    anything goes wrong (connection error, HTTP error status, bad payload).
    """
    base_url = provider.base_url
    if not base_url:
        return []

    try:
        # Pick the endpoint and payload layout for this provider flavour.
        if provider.id == "ollama":
            path, list_key, name_key = "/api/tags", "models", "name"
        else:
            path, list_key, name_key = "/models", "data", "id"

        response = httpx.get(f"{base_url}{path}", timeout=3.0)
        response.raise_for_status()
        entries = response.json().get(list_key, [])
        return [entry[name_key] for entry in entries]
    except Exception:
        # Best-effort discovery: the caller treats "no models" and
        # "could not ask" the same way.
        return []
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_lmstudio_available_models(
    *, base_url: str = "http://localhost:1234"
) -> list[str]:
    """Return all downloaded (not necessarily loaded) models from LM Studio's beta API.

    Args:
        base_url: Root URL of the LM Studio server (scheme, host and port,
            without the ``/api/v0`` suffix). Defaults to the previously
            hard-coded ``http://localhost:1234`` so existing callers are
            unaffected.

    Returns:
        The ``id`` of every entry under ``data`` in the response of
        ``GET {base_url}/api/v0/models``, or an empty list if the request
        fails for any reason (connection error, HTTP error, bad payload).
    """
    # Tolerate a trailing slash so the joined URL never contains "//".
    root = base_url.rstrip("/")
    try:
        resp = httpx.get(f"{root}/api/v0/models", timeout=3.0)
        resp.raise_for_status()
        return [m["id"] for m in resp.json().get("data", [])]
    except Exception:
        # Best-effort lookup: callers fall back to manual entry on [].
        return []
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lmstudio_model(
    model_id: str, *, base_url: str = "http://localhost:1234"
) -> bool:
    """Attempt to load a model via LM Studio's beta API. Returns True on success.

    Args:
        model_id: Identifier of the model to load; sent as the
            ``identifier`` field of the JSON request body.
        base_url: Root URL of the LM Studio server (scheme, host and port,
            without the ``/api/v0`` suffix). Defaults to the previously
            hard-coded ``http://localhost:1234`` so existing callers are
            unaffected.

    Returns:
        ``True`` when the server answers with a success status, ``False``
        on a non-success status or on any exception (timeout, refused
        connection, ...).
    """
    # NOTE(review): confirm the /api/v0/models/load endpoint exists in the
    # LM Studio version being targeted.
    root = base_url.rstrip("/")
    try:
        resp = httpx.post(
            f"{root}/api/v0/models/load",
            json={"identifier": model_id},
            # The generous timeout is kept from the original code
            # (presumably because loading a model can be slow).
            timeout=60.0,
        )
        return resp.is_success
    except Exception:
        # Best-effort: the caller tells the user to load the model manually.
        return False
|
||||||
|
|
||||||
|
|
||||||
def _choose_model(provider: Provider) -> str:
    """Interactively determine which model name to use for *provider*.

    Providers outside the ``"Local"`` group get a plain text prompt.
    Local providers are first asked for their model list via
    ``_fetch_local_models``:

    * if models are reported, the user picks one from a menu (or chooses
      manual entry);
    * otherwise, for ``lmstudio``, the downloaded models are offered and
      the chosen one is loaded through ``_load_lmstudio_model``; the
      selection is returned even when loading fails, after telling the
      user to load it manually;
    * otherwise a "no models found" notice is printed.

    Every path that does not return a menu selection ends in the text
    prompt. Cancelling any prompt exits with ``SystemExit(0)``.
    """
    manual_entry = "__manual__"

    def prompt_for_name() -> str:
        # Free-text entry; a cancelled prompt (None) aborts the wizard.
        answer = questionary.text("Model name:", default=provider.default_model).ask()
        if answer is None:
            raise SystemExit(0)
        return answer.strip()

    def pick_from(prompt: str, model_ids: list[str]) -> str:
        # Menu of known models plus a sentinel entry for manual input.
        menu = model_ids + [questionary.Choice("── Enter manually ──", value=manual_entry)]
        picked = questionary.select(prompt, choices=menu).ask()
        if picked is None:
            raise SystemExit(0)
        return picked

    if provider.group != "Local":
        return prompt_for_name()

    running = _fetch_local_models(provider)

    if running:
        picked = pick_from("Select model:", running)
        if picked != manual_entry:
            return picked

    elif provider.id != "lmstudio":
        console.print(f" [yellow]No models found at {provider.base_url}.[/yellow]")

    else:
        console.print(" [yellow]No model currently loaded in LM Studio.[/yellow]")
        downloaded = _fetch_lmstudio_available_models()
        if not downloaded:
            console.print(Panel(
                "No models are loaded or downloaded in LM Studio.\n"
                "Open LM Studio → Local Server tab → load a model, then re-run setup.",
                border_style="yellow",
            ))
        else:
            picked = pick_from("Select a downloaded model to load:", downloaded)
            if picked != manual_entry:
                console.print(f" Loading [bold]{picked}[/bold]...", end=" ")
                if _load_lmstudio_model(picked):
                    console.print("[green]✓ Loaded[/green]")
                else:
                    console.print(
                        "[yellow]Could not load via API — "
                        "please load the model manually in LM Studio.[/yellow]"
                    )
                # The selection is used regardless of whether loading worked.
                return picked

    return prompt_for_name()
|
||||||
|
|||||||
@@ -1,31 +1,41 @@
|
|||||||
"""
|
"""
|
||||||
Live integration test against LM Studio at localhost:1234.
|
Live integration test against LM Studio at localhost:1234.
|
||||||
Skipped automatically if LM Studio is not running.
|
Skipped automatically if LM Studio is not running or no model is loaded.
|
||||||
"""
|
"""
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
LMSTUDIO_MODEL = "gemma-4-e4b-uncensored-hauhaucs-aggressive"
|
_LMSTUDIO_BASE_URL = "http://localhost:1234/v1"
|
||||||
LMSTUDIO_BASE_URL = "http://localhost:1234/v1"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
def _get_loaded_model() -> str | None:
    """Return the ID of the first model listed by LM Studio, or None.

    Queries ``{_LMSTUDIO_BASE_URL}/models``. ``None`` is returned when the
    ``data`` list is empty or when the request fails in any way.
    """
    try:
        reply = httpx.get(f"{_LMSTUDIO_BASE_URL}/models", timeout=2.0)
        reply.raise_for_status()
        listed = reply.json().get("data", [])
        if not listed:
            return None
        return listed[0]["id"]
    except Exception:
        # Unreachable server and malformed payload are both "no model".
        return None
|
||||||
|
|
||||||
|
|
||||||
def test_basic_completion():
|
@pytest.fixture()
def lmstudio_model() -> str:
    """Provide the ID of the first model LM Studio lists, skipping the test otherwise."""
    model_id = _get_loaded_model()
    if model_id is not None:
        return model_id
    # pytest.skip raises, so nothing is returned past this point.
    pytest.skip("LM Studio not reachable or no model currently loaded")
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_completion(lmstudio_model):
|
||||||
import litellm
|
import litellm
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
|
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model=f"openai/{LMSTUDIO_MODEL}",
|
model=f"openai/{lmstudio_model}",
|
||||||
messages=[{"role": "user", "content": "Reply with exactly the word: PONG"}],
|
messages=[{"role": "user", "content": "Reply with exactly the word: PONG"}],
|
||||||
api_base=LMSTUDIO_BASE_URL,
|
api_base=_LMSTUDIO_BASE_URL,
|
||||||
api_key="lm-studio",
|
api_key="lm-studio",
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
stream=False,
|
stream=False,
|
||||||
@@ -34,14 +44,14 @@ def test_basic_completion():
|
|||||||
assert text and len(text) > 0
|
assert text and len(text) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_completion():
|
def test_streaming_completion(lmstudio_model):
|
||||||
import litellm
|
import litellm
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
|
|
||||||
stream = litellm.completion(
|
stream = litellm.completion(
|
||||||
model=f"openai/{LMSTUDIO_MODEL}",
|
model=f"openai/{lmstudio_model}",
|
||||||
messages=[{"role": "user", "content": "Count from 1 to 3."}],
|
messages=[{"role": "user", "content": "Count from 1 to 3."}],
|
||||||
api_base=LMSTUDIO_BASE_URL,
|
api_base=_LMSTUDIO_BASE_URL,
|
||||||
api_key="lm-studio",
|
api_key="lm-studio",
|
||||||
max_tokens=50,
|
max_tokens=50,
|
||||||
stream=True,
|
stream=True,
|
||||||
@@ -52,30 +62,29 @@ def test_streaming_completion():
|
|||||||
assert len(full_text) > 0
|
assert len(full_text) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_injection_scan_on_live_response(tmp_pyra_home):
|
def test_injection_scan_on_live_response(tmp_pyra_home, lmstudio_model):
    """Run the injection scanner over a real LM Studio completion.

    Smoke test: the scanner must process live model output without
    raising. Any warnings it produces are printed, not asserted on.
    """
    import litellm
    from pyra.security.injection import scan_response

    litellm.suppress_debug_info = True

    prompt = [{"role": "user", "content": "Explain what a list is in Python."}]
    completion = litellm.completion(
        model=f"openai/{lmstudio_model}",
        messages=prompt,
        api_base=_LMSTUDIO_BASE_URL,
        api_key="lm-studio",
        max_tokens=200,
        stream=False,
    )
    reply_text = completion.choices[0].message.content

    # Not asserting zero warnings — some models may have quirky phrasing —
    # but at least the scanner must not crash on real output
    for warning in scan_response(reply_text):
        print(f"[warning] {warning.pattern_label}: {warning.matched_text!r}")
|
||||||
|
|
||||||
|
|
||||||
def test_pyra_chat_session_with_lmstudio(tmp_pyra_home):
|
def test_pyra_chat_session_with_lmstudio(tmp_pyra_home, lmstudio_model):
|
||||||
"""Full stack: config → vault → history → litellm → injection scan."""
|
"""Full stack: config → vault → history → litellm → injection scan."""
|
||||||
from pyra.config.schema import PyraConfig, ProviderConfig
|
from pyra.config.schema import PyraConfig, ProviderConfig
|
||||||
from pyra.config.manager import save_config
|
from pyra.config.manager import save_config
|
||||||
@@ -87,8 +96,8 @@ def test_pyra_chat_session_with_lmstudio(tmp_pyra_home):
|
|||||||
cfg = PyraConfig(
|
cfg = PyraConfig(
|
||||||
ai=ProviderConfig(
|
ai=ProviderConfig(
|
||||||
provider_id="lmstudio",
|
provider_id="lmstudio",
|
||||||
model=LMSTUDIO_MODEL,
|
model=lmstudio_model,
|
||||||
base_url=LMSTUDIO_BASE_URL,
|
base_url=_LMSTUDIO_BASE_URL,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
@@ -98,9 +107,9 @@ def test_pyra_chat_session_with_lmstudio(tmp_pyra_home):
|
|||||||
messages = history.build_for_api()
|
messages = history.build_for_api()
|
||||||
|
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model=f"openai/{LMSTUDIO_MODEL}",
|
model=f"openai/{lmstudio_model}",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_base=LMSTUDIO_BASE_URL,
|
api_base=_LMSTUDIO_BASE_URL,
|
||||||
api_key="lm-studio",
|
api_key="lm-studio",
|
||||||
max_tokens=30,
|
max_tokens=30,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
"""Tests for setup wizard personalization helpers."""
|
"""Tests for setup wizard personalization and model-discovery helpers."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -90,3 +92,92 @@ def test_suggest_plugins_multiple_categories(monkeypatch):
|
|||||||
combined = " ".join(str(p) for p in panels)
|
combined = " ".join(str(p) for p in panels)
|
||||||
assert "email" in combined
|
assert "email" in combined
|
||||||
assert "ssh_tool" in combined
|
assert "ssh_tool" in combined
|
||||||
|
|
||||||
|
|
||||||
|
# ── _fetch_local_models ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_fetch_local_models_lmstudio_returns_model_ids(monkeypatch):
    """The lmstudio provider yields the ``id`` of each ``data`` entry."""
    import pyra.setup.wizard as wiz

    payload = {"data": [{"id": "gemma-4b"}, {"id": "llama3"}]}
    fake_response = MagicMock()
    fake_response.json.return_value = payload
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: fake_response)

    from pyra.setup.providers import get_provider

    found = wiz._fetch_local_models(get_provider("lmstudio"))
    assert found == ["gemma-4b", "llama3"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
    """The ollama provider yields the ``name`` of each ``models`` entry."""
    import pyra.setup.wizard as wiz

    payload = {"models": [{"name": "llama3:latest"}, {"name": "mistral"}]}
    fake_response = MagicMock()
    fake_response.json.return_value = payload
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: fake_response)

    from pyra.setup.providers import get_provider

    found = wiz._fetch_local_models(get_provider("ollama"))
    assert found == ["llama3:latest", "mistral"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_local_models_returns_empty_on_connection_error(monkeypatch):
    """A failing HTTP call is swallowed and reported as an empty list."""
    import pyra.setup.wizard as wiz

    failing_get = MagicMock(side_effect=Exception("conn refused"))
    monkeypatch.setattr(wiz.httpx, "get", failing_get)

    from pyra.setup.providers import get_provider

    found = wiz._fetch_local_models(get_provider("lmstudio"))
    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_local_models_returns_empty_when_no_base_url():
    """A provider built without a ``base_url`` yields an empty list."""
    import pyra.setup.wizard as wiz
    from pyra.setup.providers import Provider

    fields = {
        "id": "test",
        "display_name": "Test",
        "requires_key": False,
        "default_model": "x",
        "litellm_prefix": "openai/",
        "group": "Local",
    }
    urlless_provider = Provider(**fields)

    assert wiz._fetch_local_models(urlless_provider) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _fetch_lmstudio_available_models ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_fetch_lmstudio_available_models_returns_ids(monkeypatch):
    """Each ``data`` entry's ``id`` is returned, in response order."""
    import pyra.setup.wizard as wiz

    payload = {"data": [{"id": "model-a"}, {"id": "model-b"}]}
    fake_response = MagicMock()
    fake_response.json.return_value = payload
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: fake_response)

    found = wiz._fetch_lmstudio_available_models()
    assert found == ["model-a", "model-b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_lmstudio_available_models_returns_empty_on_error(monkeypatch):
    """An exception from the HTTP call results in an empty list."""
    import pyra.setup.wizard as wiz

    failing_get = MagicMock(side_effect=Exception("not found"))
    monkeypatch.setattr(wiz.httpx, "get", failing_get)

    found = wiz._fetch_lmstudio_available_models()
    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_lmstudio_available_models_empty_data(monkeypatch):
    """A successful response with an empty ``data`` list yields no models."""
    import pyra.setup.wizard as wiz

    fake_response = MagicMock()
    fake_response.json.return_value = {"data": []}
    fake_response.raise_for_status = lambda: None
    monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: fake_response)

    found = wiz._fetch_lmstudio_available_models()
    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _load_lmstudio_model ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_load_lmstudio_model_returns_true_on_success(monkeypatch):
    """A response flagged as successful makes the loader return True."""
    import pyra.setup.wizard as wiz

    ok_response = MagicMock(is_success=True)
    monkeypatch.setattr(wiz.httpx, "post", lambda *a, **kw: ok_response)

    outcome = wiz._load_lmstudio_model("gemma-4b")
    assert outcome is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_lmstudio_model_returns_false_on_api_failure(monkeypatch):
    """A response flagged as unsuccessful makes the loader return False."""
    import pyra.setup.wizard as wiz

    failed_response = MagicMock(is_success=False)
    monkeypatch.setattr(wiz.httpx, "post", lambda *a, **kw: failed_response)

    outcome = wiz._load_lmstudio_model("gemma-4b")
    assert outcome is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
    """An exception raised by the HTTP call makes the loader return False."""
    import pyra.setup.wizard as wiz

    exploding_post = MagicMock(side_effect=Exception("timeout"))
    monkeypatch.setattr(wiz.httpx, "post", exploding_post)

    outcome = wiz._load_lmstudio_model("gemma-4b")
    assert outcome is False
|
||||||
|
|||||||
Reference in New Issue
Block a user