diff --git a/src/pyra/config/tui.py b/src/pyra/config/tui.py index 2500118..c9d0556 100644 --- a/src/pyra/config/tui.py +++ b/src/pyra/config/tui.py @@ -7,27 +7,46 @@ from textual.binding import Binding from textual.containers import Horizontal, VerticalScroll from textual.coordinate import Coordinate from textual.widget import Widget -from textual.widgets import DataTable, Footer, Input, Label, Static, Switch, TabbedContent, TabPane +from textual.widgets import DataTable, Footer, Input, Label, Select, Static, Switch, TabbedContent, TabPane from pyra.config.manager import load_config, save_config from pyra.plugins.base import BasePlugin, ConfigField from pyra.plugins.install import read_manifest from pyra.plugins.loader import load_plugin_by_name +from pyra.setup.providers import PROVIDERS, PROVIDERS_BY_ID from pyra.utils.paths import pyra_home +from pyra.vault.reader import get_key +from pyra.vault.writer import set_key class _CoreField(NamedTuple): - path: str # dotted path in PyraConfig, e.g. "general.user_name" + path: str # dotted path in PyraConfig; empty string for section headers label: str - type: str # "text" | "bool" + type: str # "text" | "bool" | "section" default: Any + cast: type | None = None # e.g. int — coerce text input value on save -# ── Add new core settings here — one line each ──────────────────────────────── +# ── Add new core settings here — one entry each ─────────────────────────────── GENERAL_FIELDS: list[_CoreField] = [ - _CoreField("general.user_name", "Your name", "text", "User"), - _CoreField("general.assistant_name", "Assistant name", "text", "Pyra"), - _CoreField("daemon.enabled", "Enable daemon", "bool", False), + _CoreField("", "── General ─────────────────────────────────────────", "section", None), + _CoreField("general.user_name", "Your name", "text", "User"), + _CoreField("general.assistant_name", "Assistant name", "text", "Pyra"), + + _CoreField("", "── Memory ──────────────────────────────────────────", "section", None), + _CoreField("memory.max_tokens_in_context", "Context limit (tokens)", "text", 4000, int), + _CoreField("memory.auto_load", "Auto-load memory", "bool", True), + + _CoreField("", "── Security ────────────────────────────────────────", "section", None), + _CoreField("security.injection_detection", "Injection detection", "bool", True), + _CoreField("security.log_injections", "Log injection events", "bool", True), + + _CoreField("", "── Plugins ─────────────────────────────────────────", "section", None), + _CoreField("plugins.require_approval", "Require tool approval", "bool", True), + _CoreField("plugins.log_executions", "Log tool executions", "bool", True), + + _CoreField("", "── Daemon ──────────────────────────────────────────", "section", None), + _CoreField("daemon.enabled", "Enable daemon", "bool", False), ] # ───────────────────────────────────────────────────────────────────────────── @@ -81,12 +100,90 @@ class _TitleBar(Static): # ── Tab widgets ─────────────────────────────────────────────────────────────── +class _AITab(VerticalScroll): + BINDINGS = [Binding("ctrl+s", "save", "Save")] + + def compose(self) -> ComposeResult: + cfg = load_config() + provider = PROVIDERS_BY_ID.get(cfg.ai.provider_id, PROVIDERS[0]) + with Horizontal(classes="row"): + yield Label("Provider") + yield Select( + [(p.display_name, p.id) for p in PROVIDERS], + value=cfg.ai.provider_id, + id="ai-provider", + ) + with Horizontal(classes="row"): + yield Label("Model") + yield Input(value=cfg.ai.model, placeholder=provider.default_model, id="ai-model") + with Horizontal(classes="row"): + yield Label("Base URL") + yield Input(value=cfg.ai.base_url or "", placeholder="Optional custom endpoint", id="ai-base-url") + yield Label("Leave blank to use provider default", classes="hint") + try: + has_key = get_key(cfg.ai.provider_id) is not None + except Exception: + has_key = False + with Horizontal(classes="row", id="ai-key-row"): + yield Label("API Key") + yield Input(placeholder="set" if has_key else "not set", password=True, id="ai-key") + yield Label("Leave blank to keep existing key", classes="hint", id="ai-key-hint") + + def on_mount(self) -> None: + try: + self._update_provider_ui(load_config().ai.provider_id) + except Exception: + pass + + def on_select_changed(self, event: Select.Changed) -> None: + if event.select.id == "ai-provider": + self._update_provider_ui(str(event.value)) + + def _update_provider_ui(self, provider_id: str) -> None: + provider = PROVIDERS_BY_ID.get(provider_id) + if not provider: + return + self.query_one("#ai-model", Input).placeholder = provider.default_model + if not self.query_one("#ai-base-url", Input).value: + self.query_one("#ai-base-url", Input).value = provider.base_url or "" + show_key = provider.requires_key + self.query_one("#ai-key-row").display = show_key + self.query_one("#ai-key-hint").display = show_key + + def action_save(self) -> None: + self._do_save() + + def _do_save(self) -> None: + sel = self.query_one("#ai-provider", Select) + if sel.value is Select.BLANK: + return + provider_id = str(sel.value) + provider = PROVIDERS_BY_ID[provider_id] + model = self.query_one("#ai-model", Input).value.strip() or provider.default_model + base_url = self.query_one("#ai-base-url", Input).value.strip() or None + api_key = self.query_one("#ai-key", Input).value.strip() + + cfg = load_config() + cfg.ai.provider_id = provider_id + cfg.ai.model = model + cfg.ai.base_url = base_url + save_config(cfg) + + if api_key: + set_key(provider_id, api_key) + + self.app.notify("AI settings saved.") + + class _GeneralTab(VerticalScroll): BINDINGS = [Binding("ctrl+s", "save", "Save")] def compose(self) -> ComposeResult: cfg = load_config() for f in GENERAL_FIELDS: + if f.type == "section": + yield Label(f.label, classes="section-header") + continue current = _get_nested(cfg, f.path) with Horizontal(classes="row"): yield Label(f.label) @@ -101,11 +198,18 @@ class _GeneralTab(VerticalScroll): def _do_save(self) -> None: cfg = load_config() for f in GENERAL_FIELDS: + if f.type == "section": + continue wid = _fid(f.path) if f.type == "bool": cfg_val: Any = self.query_one(f"#{wid}", Switch).value else: cfg_val = self.query_one(f"#{wid}", Input).value + if f.cast: + try: + cfg_val = f.cast(cfg_val) + except (ValueError, TypeError): + cfg_val = f.default _set_nested(cfg, f.path, cfg_val) save_config(cfg) self.app.notify("General settings saved.") @@ -224,6 +328,12 @@ class ConfigApp(App): Tab.-active { color: #ffffff; text-style: bold; background: #1a1a1a; } Input { border: ascii #444444; background: #111111; color: #ffffff; } Input:focus { border: ascii #888888; } + Select { border: ascii #444444; background: #111111; color: #c8c8c8; } + Select:focus { border: ascii #888888; } + Select > .select--arrow { color: #888888; } + SelectOverlay { background: #111111; border: ascii #888888; } + SelectOverlay > .option-list--option { color: #c8c8c8; } + SelectOverlay > .option-list--option-highlighted { background: #2a2a2a; color: #ffffff; } Switch { background: #111111; } DataTable { border: ascii #444444; height: 1fr; background: #0d0d0d; } DataTable > .datatable--header { text-style: bold; color: #aaaaaa; background: #1a1a1a; } @@ -233,12 +343,15 @@ class ConfigApp(App): .row { height: 3; margin: 0 2; } .row Label { width: 26; content-align: left middle; color: #aaaaaa; } .hint { color: #555555; margin: 0 2 1 28; } + .section-header { color: #555555; height: 2; padding: 1 2 0 2; } """ def compose(self) -> ComposeResult: yield _TitleBar("PYRA CONFIGURATION") plugins = _installed_plugins() with TabbedContent(): + with TabPane("AI"): + yield _AITab() with TabPane("General"): yield _GeneralTab() with TabPane("Plugins"): diff --git a/tests/unit/test_config_tui.py b/tests/unit/test_config_tui.py index 8b38201..9f2e2d6 100644 --- a/tests/unit/test_config_tui.py +++ b/tests/unit/test_config_tui.py @@ -1,7 +1,7 @@ import json from unittest.mock import patch -from textual.widgets import DataTable, Input, Switch, TabbedContent +from textual.widgets import DataTable, Input, Label, Select, Switch, TabbedContent from pyra.config.schema import ProviderConfig, PyraConfig @@ -68,7 +68,7 @@ def test_general_fields_non_empty(): def test_general_fields_all_valid_types(): from pyra.config.tui import GENERAL_FIELDS - valid_types = {"text", "bool", "select"} + valid_types = {"text", "bool", "select", "section"} for f in GENERAL_FIELDS: assert f.type in valid_types, f"Field '{f.path}' has unexpected type '{f.type}'" @@ -82,6 +82,8 @@ async def test_config_app_renders_all_general_fields(tmp_pyra_home): save_config(_make_cfg()) async with ConfigApp().run_test() as pilot: for f in GENERAL_FIELDS: + if f.type == "section": + continue wid = _fid(f.path) if f.type == "bool": assert pilot.app.query_one(f"#{wid}", Switch) @@ -150,8 +152,8 @@ async def test_plugin_config_tab_appears_for_plugin_with_config_fields(tmp_pyra_ async with ConfigApp().run_test() as pilot: content = pilot.app.query_one(TabbedContent) tab_count = len(list(content.query("TabPane"))) - # General + Plugins + fake plugin tab - assert tab_count == 3 + # AI + General + Plugins + fake plugin tab + assert tab_count == 4 async def test_q_key_exits_app(tmp_pyra_home): @@ -162,3 +164,101 @@ async def test_q_key_exits_app(tmp_pyra_home): async with ConfigApp().run_test() as pilot: await pilot.press("q") # Reaching here means the app exited cleanly + + +async def test_ai_tab_renders_provider_fields(tmp_pyra_home): + from pyra.config.manager import save_config + from pyra.config.tui import ConfigApp + + save_config(_make_cfg()) + async with ConfigApp().run_test() as pilot: + assert pilot.app.query_one("#ai-provider", Select) + assert pilot.app.query_one("#ai-model", Input) + assert pilot.app.query_one("#ai-base-url", Input) + + +async def test_ai_tab_save_updates_config(tmp_pyra_home): + from textual.app import App as TextualApp, ComposeResult as CR + + from pyra.config.manager import load_config, save_config as initial_save + from pyra.config.tui import _AITab + + initial_save(_make_cfg()) + + class _TestApp(TextualApp): + def compose(self) -> CR: + yield _AITab() + + saved = [] + with patch("pyra.config.tui.save_config", side_effect=lambda c: saved.append(c) or None): + async with _TestApp().run_test() as pilot: + pilot.app.query_one("#ai-model", Input).value = "llama3:70b" + await pilot.pause() + await pilot.press("ctrl+s") + + assert saved, "save_config was not called" + assert saved[-1].ai.model == "llama3:70b" + + +async def test_ai_tab_save_calls_set_key_when_provided(tmp_pyra_home): + from textual.app import App as TextualApp, ComposeResult as CR + + from pyra.config.manager import save_config as initial_save + from pyra.config.tui import _AITab + + initial_save(_make_cfg()) + + class _TestApp(TextualApp): + def compose(self) -> CR: + yield _AITab() + + calls = [] + with patch("pyra.config.tui.save_config"): + with patch("pyra.config.tui.set_key", side_effect=lambda p, k: calls.append((p, k))): + async with _TestApp().run_test() as pilot: + pilot.app.query_one("#ai-key", Input).value = "sk-test" + await pilot.pause() + await pilot.press("ctrl+s") + + assert calls, "set_key was not called" + assert calls[-1][1] == "sk-test" + + +async def test_ai_tab_save_skips_set_key_when_empty(tmp_pyra_home): + from textual.app import App as TextualApp, ComposeResult as CR + + from pyra.config.manager import save_config as initial_save + from pyra.config.tui import _AITab + + initial_save(_make_cfg()) + + class _TestApp(TextualApp): + def compose(self) -> CR: + yield _AITab() + + calls = [] + with patch("pyra.config.tui.save_config"): + with patch("pyra.config.tui.set_key", side_effect=lambda p, k: calls.append((p, k))): + async with _TestApp().run_test() as pilot: + # Leave api-key empty (default) + await pilot.pause() + await pilot.press("ctrl+s") + + assert not calls, "set_key should not be called when key input is empty" + + +async def test_general_tab_renders_section_headers(tmp_pyra_home): + from textual.app import App as TextualApp, ComposeResult as CR + + from pyra.config.manager import save_config as initial_save + from pyra.config.tui import _GeneralTab + + initial_save(_make_cfg()) + + class _TestApp(TextualApp): + def compose(self) -> CR: + yield _GeneralTab() + + async with _TestApp().run_test() as pilot: + headers = list(pilot.app.query(".section-header")) + assert len(headers) >= 5, "Expected at least 5 section headers"