feat(tui): AI provider tab + expanded General settings
Add a dedicated AI tab with provider Select, model Input, base URL Input, and masked API key Input (write-only, stored in vault). Switching providers reactively updates the model placeholder, base URL default, and shows/hides the API key row for cloud vs. local providers. ctrl+s saves config and vault. Extend GENERAL_FIELDS with Memory, Security, Plugin, and Daemon sections using a new "section" header type and optional int cast for numeric fields. _CoreField gains cast: type | None for automatic value coercion on save. Add 5 new tests covering AI tab rendering, config save, vault key write, vault key skip-on-empty, and section header rendering. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+120
-7
@@ -7,27 +7,46 @@ from textual.binding import Binding
|
||||
from textual.containers import Horizontal, VerticalScroll
|
||||
from textual.coordinate import Coordinate
|
||||
from textual.widget import Widget
|
||||
from textual.widgets import DataTable, Footer, Input, Label, Static, Switch, TabbedContent, TabPane
|
||||
from textual.widgets import DataTable, Footer, Input, Label, Select, Static, Switch, TabbedContent, TabPane
|
||||
|
||||
from pyra.config.manager import load_config, save_config
|
||||
from pyra.plugins.base import BasePlugin, ConfigField
|
||||
from pyra.plugins.install import read_manifest
|
||||
from pyra.plugins.loader import load_plugin_by_name
|
||||
from pyra.setup.providers import PROVIDERS, PROVIDERS_BY_ID
|
||||
from pyra.utils.paths import pyra_home
|
||||
from pyra.vault.reader import get_key
|
||||
from pyra.vault.writer import set_key
|
||||
|
||||
|
||||
class _CoreField(NamedTuple):
|
||||
path: str # dotted path in PyraConfig, e.g. "general.user_name"
|
||||
path: str # dotted path in PyraConfig; empty string for section headers
|
||||
label: str
|
||||
type: str # "text" | "bool"
|
||||
type: str # "text" | "bool" | "section"
|
||||
default: Any
|
||||
cast: type | None = None # e.g. int — coerce text input value on save
|
||||
|
||||
|
||||
# ── Add new core settings here — one line each ────────────────────────────────
|
||||
# ── Add new core settings here — one entry each ───────────────────────────────
|
||||
GENERAL_FIELDS: list[_CoreField] = [
|
||||
_CoreField("general.user_name", "Your name", "text", "User"),
|
||||
_CoreField("general.assistant_name", "Assistant name", "text", "Pyra"),
|
||||
_CoreField("daemon.enabled", "Enable daemon", "bool", False),
|
||||
_CoreField("", "── General ─────────────────────────────────────────", "section", None),
|
||||
_CoreField("general.user_name", "Your name", "text", "User"),
|
||||
_CoreField("general.assistant_name", "Assistant name", "text", "Pyra"),
|
||||
|
||||
_CoreField("", "── Memory ──────────────────────────────────────────", "section", None),
|
||||
_CoreField("memory.max_tokens_in_context", "Context limit (tokens)", "text", 4000, int),
|
||||
_CoreField("memory.auto_load", "Auto-load memory", "bool", True),
|
||||
|
||||
_CoreField("", "── Security ────────────────────────────────────────", "section", None),
|
||||
_CoreField("security.injection_detection", "Injection detection", "bool", True),
|
||||
_CoreField("security.log_injections", "Log injection events", "bool", True),
|
||||
|
||||
_CoreField("", "── Plugins ─────────────────────────────────────────", "section", None),
|
||||
_CoreField("plugins.require_approval", "Require tool approval", "bool", True),
|
||||
_CoreField("plugins.log_executions", "Log tool executions", "bool", True),
|
||||
|
||||
_CoreField("", "── Daemon ──────────────────────────────────────────", "section", None),
|
||||
_CoreField("daemon.enabled", "Enable daemon", "bool", False),
|
||||
]
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -81,12 +100,90 @@ class _TitleBar(Static):
|
||||
|
||||
# ── Tab widgets ───────────────────────────────────────────────────────────────
|
||||
|
||||
class _AITab(VerticalScroll):
|
||||
BINDINGS = [Binding("ctrl+s", "save", "Save")]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
cfg = load_config()
|
||||
provider = PROVIDERS_BY_ID.get(cfg.ai.provider_id, PROVIDERS[0])
|
||||
with Horizontal(classes="row"):
|
||||
yield Label("Provider")
|
||||
yield Select(
|
||||
[(p.display_name, p.id) for p in PROVIDERS],
|
||||
value=cfg.ai.provider_id,
|
||||
id="ai-provider",
|
||||
)
|
||||
with Horizontal(classes="row"):
|
||||
yield Label("Model")
|
||||
yield Input(value=cfg.ai.model, placeholder=provider.default_model, id="ai-model")
|
||||
with Horizontal(classes="row"):
|
||||
yield Label("Base URL")
|
||||
yield Input(value=cfg.ai.base_url or "", placeholder="Optional custom endpoint", id="ai-base-url")
|
||||
yield Label("Leave blank to use provider default", classes="hint")
|
||||
try:
|
||||
has_key = get_key(cfg.ai.provider_id) is not None
|
||||
except Exception:
|
||||
has_key = False
|
||||
with Horizontal(classes="row", id="ai-key-row"):
|
||||
yield Label("API Key")
|
||||
yield Input(placeholder="set" if has_key else "not set", password=True, id="ai-key")
|
||||
yield Label("Leave blank to keep existing key", classes="hint", id="ai-key-hint")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
try:
|
||||
self._update_provider_ui(load_config().ai.provider_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def on_select_changed(self, event: Select.Changed) -> None:
|
||||
if event.select.id == "ai-provider":
|
||||
self._update_provider_ui(str(event.value))
|
||||
|
||||
def _update_provider_ui(self, provider_id: str) -> None:
|
||||
provider = PROVIDERS_BY_ID.get(provider_id)
|
||||
if not provider:
|
||||
return
|
||||
self.query_one("#ai-model", Input).placeholder = provider.default_model
|
||||
if not self.query_one("#ai-base-url", Input).value:
|
||||
self.query_one("#ai-base-url", Input).value = provider.base_url or ""
|
||||
show_key = provider.requires_key
|
||||
self.query_one("#ai-key-row").display = show_key
|
||||
self.query_one("#ai-key-hint").display = show_key
|
||||
|
||||
def action_save(self) -> None:
|
||||
self._do_save()
|
||||
|
||||
def _do_save(self) -> None:
|
||||
sel = self.query_one("#ai-provider", Select)
|
||||
if sel.value is Select.BLANK:
|
||||
return
|
||||
provider_id = str(sel.value)
|
||||
provider = PROVIDERS_BY_ID[provider_id]
|
||||
model = self.query_one("#ai-model", Input).value.strip() or provider.default_model
|
||||
base_url = self.query_one("#ai-base-url", Input).value.strip() or None
|
||||
api_key = self.query_one("#ai-key", Input).value.strip()
|
||||
|
||||
cfg = load_config()
|
||||
cfg.ai.provider_id = provider_id
|
||||
cfg.ai.model = model
|
||||
cfg.ai.base_url = base_url
|
||||
save_config(cfg)
|
||||
|
||||
if api_key:
|
||||
set_key(provider_id, api_key)
|
||||
|
||||
self.app.notify("AI settings saved.")
|
||||
|
||||
|
||||
class _GeneralTab(VerticalScroll):
|
||||
BINDINGS = [Binding("ctrl+s", "save", "Save")]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
cfg = load_config()
|
||||
for f in GENERAL_FIELDS:
|
||||
if f.type == "section":
|
||||
yield Label(f.label, classes="section-header")
|
||||
continue
|
||||
current = _get_nested(cfg, f.path)
|
||||
with Horizontal(classes="row"):
|
||||
yield Label(f.label)
|
||||
@@ -101,11 +198,18 @@ class _GeneralTab(VerticalScroll):
|
||||
def _do_save(self) -> None:
|
||||
cfg = load_config()
|
||||
for f in GENERAL_FIELDS:
|
||||
if f.type == "section":
|
||||
continue
|
||||
wid = _fid(f.path)
|
||||
if f.type == "bool":
|
||||
cfg_val: Any = self.query_one(f"#{wid}", Switch).value
|
||||
else:
|
||||
cfg_val = self.query_one(f"#{wid}", Input).value
|
||||
if f.cast:
|
||||
try:
|
||||
cfg_val = f.cast(cfg_val)
|
||||
except (ValueError, TypeError):
|
||||
cfg_val = f.default
|
||||
_set_nested(cfg, f.path, cfg_val)
|
||||
save_config(cfg)
|
||||
self.app.notify("General settings saved.")
|
||||
@@ -224,6 +328,12 @@ class ConfigApp(App):
|
||||
Tab.-active { color: #ffffff; text-style: bold; background: #1a1a1a; }
|
||||
Input { border: ascii #444444; background: #111111; color: #ffffff; }
|
||||
Input:focus { border: ascii #888888; }
|
||||
Select { border: ascii #444444; background: #111111; color: #c8c8c8; }
|
||||
Select:focus { border: ascii #888888; }
|
||||
Select > .select--arrow { color: #888888; }
|
||||
SelectOverlay { background: #111111; border: ascii #888888; }
|
||||
SelectOverlay > .option-list--option { color: #c8c8c8; }
|
||||
SelectOverlay > .option-list--option-highlighted { background: #2a2a2a; color: #ffffff; }
|
||||
Switch { background: #111111; }
|
||||
DataTable { border: ascii #444444; height: 1fr; background: #0d0d0d; }
|
||||
DataTable > .datatable--header { text-style: bold; color: #aaaaaa; background: #1a1a1a; }
|
||||
@@ -233,12 +343,15 @@ class ConfigApp(App):
|
||||
.row { height: 3; margin: 0 2; }
|
||||
.row Label { width: 26; content-align: left middle; color: #aaaaaa; }
|
||||
.hint { color: #555555; margin: 0 2 1 28; }
|
||||
.section-header { color: #555555; height: 2; padding: 1 2 0 2; }
|
||||
"""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield _TitleBar("PYRA CONFIGURATION")
|
||||
plugins = _installed_plugins()
|
||||
with TabbedContent():
|
||||
with TabPane("AI"):
|
||||
yield _AITab()
|
||||
with TabPane("General"):
|
||||
yield _GeneralTab()
|
||||
with TabPane("Plugins"):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from textual.widgets import DataTable, Input, Switch, TabbedContent
|
||||
from textual.widgets import DataTable, Input, Label, Select, Switch, TabbedContent
|
||||
|
||||
from pyra.config.schema import ProviderConfig, PyraConfig
|
||||
|
||||
@@ -68,7 +68,7 @@ def test_general_fields_non_empty():
|
||||
def test_general_fields_all_valid_types():
|
||||
from pyra.config.tui import GENERAL_FIELDS
|
||||
|
||||
valid_types = {"text", "bool", "select"}
|
||||
valid_types = {"text", "bool", "select", "section"}
|
||||
for f in GENERAL_FIELDS:
|
||||
assert f.type in valid_types, f"Field '{f.path}' has unexpected type '{f.type}'"
|
||||
|
||||
@@ -82,6 +82,8 @@ async def test_config_app_renders_all_general_fields(tmp_pyra_home):
|
||||
save_config(_make_cfg())
|
||||
async with ConfigApp().run_test() as pilot:
|
||||
for f in GENERAL_FIELDS:
|
||||
if f.type == "section":
|
||||
continue
|
||||
wid = _fid(f.path)
|
||||
if f.type == "bool":
|
||||
assert pilot.app.query_one(f"#{wid}", Switch)
|
||||
@@ -150,8 +152,8 @@ async def test_plugin_config_tab_appears_for_plugin_with_config_fields(tmp_pyra_
|
||||
async with ConfigApp().run_test() as pilot:
|
||||
content = pilot.app.query_one(TabbedContent)
|
||||
tab_count = len(list(content.query("TabPane")))
|
||||
# General + Plugins + fake plugin tab
|
||||
assert tab_count == 3
|
||||
# AI + General + Plugins + fake plugin tab
|
||||
assert tab_count == 4
|
||||
|
||||
|
||||
async def test_q_key_exits_app(tmp_pyra_home):
|
||||
@@ -162,3 +164,101 @@ async def test_q_key_exits_app(tmp_pyra_home):
|
||||
async with ConfigApp().run_test() as pilot:
|
||||
await pilot.press("q")
|
||||
# Reaching here means the app exited cleanly
|
||||
|
||||
|
||||
async def test_ai_tab_renders_provider_fields(tmp_pyra_home):
|
||||
from pyra.config.manager import save_config
|
||||
from pyra.config.tui import ConfigApp
|
||||
|
||||
save_config(_make_cfg())
|
||||
async with ConfigApp().run_test() as pilot:
|
||||
assert pilot.app.query_one("#ai-provider", Select)
|
||||
assert pilot.app.query_one("#ai-model", Input)
|
||||
assert pilot.app.query_one("#ai-base-url", Input)
|
||||
|
||||
|
||||
async def test_ai_tab_save_updates_config(tmp_pyra_home):
|
||||
from textual.app import App as TextualApp, ComposeResult as CR
|
||||
|
||||
from pyra.config.manager import load_config, save_config as initial_save
|
||||
from pyra.config.tui import _AITab
|
||||
|
||||
initial_save(_make_cfg())
|
||||
|
||||
class _TestApp(TextualApp):
|
||||
def compose(self) -> CR:
|
||||
yield _AITab()
|
||||
|
||||
saved = []
|
||||
with patch("pyra.config.tui.save_config", side_effect=lambda c: saved.append(c) or None):
|
||||
async with _TestApp().run_test() as pilot:
|
||||
pilot.app.query_one("#ai-model", Input).value = "llama3:70b"
|
||||
await pilot.pause()
|
||||
await pilot.press("ctrl+s")
|
||||
|
||||
assert saved, "save_config was not called"
|
||||
assert saved[-1].ai.model == "llama3:70b"
|
||||
|
||||
|
||||
async def test_ai_tab_save_calls_set_key_when_provided(tmp_pyra_home):
|
||||
from textual.app import App as TextualApp, ComposeResult as CR
|
||||
|
||||
from pyra.config.manager import save_config as initial_save
|
||||
from pyra.config.tui import _AITab
|
||||
|
||||
initial_save(_make_cfg())
|
||||
|
||||
class _TestApp(TextualApp):
|
||||
def compose(self) -> CR:
|
||||
yield _AITab()
|
||||
|
||||
calls = []
|
||||
with patch("pyra.config.tui.save_config"):
|
||||
with patch("pyra.config.tui.set_key", side_effect=lambda p, k: calls.append((p, k))):
|
||||
async with _TestApp().run_test() as pilot:
|
||||
pilot.app.query_one("#ai-key", Input).value = "sk-test"
|
||||
await pilot.pause()
|
||||
await pilot.press("ctrl+s")
|
||||
|
||||
assert calls, "set_key was not called"
|
||||
assert calls[-1][1] == "sk-test"
|
||||
|
||||
|
||||
async def test_ai_tab_save_skips_set_key_when_empty(tmp_pyra_home):
|
||||
from textual.app import App as TextualApp, ComposeResult as CR
|
||||
|
||||
from pyra.config.manager import save_config as initial_save
|
||||
from pyra.config.tui import _AITab
|
||||
|
||||
initial_save(_make_cfg())
|
||||
|
||||
class _TestApp(TextualApp):
|
||||
def compose(self) -> CR:
|
||||
yield _AITab()
|
||||
|
||||
calls = []
|
||||
with patch("pyra.config.tui.save_config"):
|
||||
with patch("pyra.config.tui.set_key", side_effect=lambda p, k: calls.append((p, k))):
|
||||
async with _TestApp().run_test() as pilot:
|
||||
# Leave api-key empty (default)
|
||||
await pilot.pause()
|
||||
await pilot.press("ctrl+s")
|
||||
|
||||
assert not calls, "set_key should not be called when key input is empty"
|
||||
|
||||
|
||||
async def test_general_tab_renders_section_headers(tmp_pyra_home):
|
||||
from textual.app import App as TextualApp, ComposeResult as CR
|
||||
|
||||
from pyra.config.manager import save_config as initial_save
|
||||
from pyra.config.tui import _GeneralTab
|
||||
|
||||
initial_save(_make_cfg())
|
||||
|
||||
class _TestApp(TextualApp):
|
||||
def compose(self) -> CR:
|
||||
yield _GeneralTab()
|
||||
|
||||
async with _TestApp().run_test() as pilot:
|
||||
headers = list(pilot.app.query(".section-header"))
|
||||
assert len(headers) >= 5, "Expected at least 5 section headers"
|
||||
|
||||
Reference in New Issue
Block a user