feat(plugins): Stage 2.1 — plugin framework and AI tool-use
Introduces a standalone plugin system where every integration lives as an independent Python script in ~/.pyra/plugins/, not hardcoded in core. Plugin framework (src/pyra/plugins/): - base.py: Tool dataclass, PyraPlugin Protocol, BasePlugin helper - loader.py: importlib-based discovery; one bad plugin never crashes pyra - registry.py: singleton aggregating tools, slash commands, system prompts - executor.py: approval gate — scans args, prompts y/N, scans result, logs - install.py: copies bundled_plugins/ to ~/.pyra/plugins/ on install Chat integration: - AI tool-use loop (litellm function calling, up to 10 iterations) - Plugin system prompt additions injected per session - Plugin slash commands merged with static commands CLI additions: - pyra plugin list/install/enable/disable/setup - pyra daemon start/stop/status/restart/install/uninstall (stubs for 2.4) Config: PluginConfig + DaemonConfig added to PyraConfig (backwards-compatible) Bootstrap: ~/.pyra/plugins/ and ~/.pyra/logs/ created on startup Security: tool args and results always injection-scanned; plugin dirs validated with assert_safe_path() before loading (symlink protection) Tests: 37 new tests (loader, registry, executor, plugin isolation security) 161 total, all passing. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict[str, Any] # JSON Schema object
|
||||
handler: Callable[..., str]
|
||||
requires_approval: bool = True
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PyraPlugin(Protocol):
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
|
||||
def on_load(self, vault_reader: Callable[[str], str | None]) -> None: ...
|
||||
def tools(self) -> list[Tool]: ...
|
||||
def slash_commands(self) -> dict[str, Callable[[], None]]: ...
|
||||
def system_prompt_addition(self) -> str: ...
|
||||
def setup(self, console: Console, vault_writer: Callable[[str, str], None]) -> None: ...
|
||||
def daemon_tasks(self) -> list[Coroutine]: ... # type: ignore[type-arg]
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""Concrete base class with no-op defaults. Plugins can inherit from this."""
|
||||
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
version: str = "0.1.0"
|
||||
|
||||
def on_load(self, vault_reader: Callable[[str], str | None]) -> None:
|
||||
pass
|
||||
|
||||
def tools(self) -> list[Tool]:
|
||||
return []
|
||||
|
||||
def slash_commands(self) -> dict[str, Callable[[], None]]:
|
||||
return {}
|
||||
|
||||
def system_prompt_addition(self) -> str:
|
||||
return ""
|
||||
|
||||
def setup(self, console: Any, vault_writer: Callable[[str, str], None]) -> None:
|
||||
pass
|
||||
|
||||
def daemon_tasks(self) -> list[Coroutine]: # type: ignore[type-arg]
|
||||
return []
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.markup import escape
|
||||
from rich.panel import Panel
|
||||
|
||||
from pyra.plugins.registry import PluginRegistry
|
||||
from pyra.security.injection import redact_api_keys, scan_response
|
||||
from pyra.utils.paths import pyra_home, safe_chmod
|
||||
|
||||
_LOG_FILE = pyra_home() / "logs" / "tool_executions.log"
|
||||
_MAX_RESULT_CHARS = 4000
|
||||
_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(self, registry: PluginRegistry, console: Console) -> None:
|
||||
self._registry = registry
|
||||
self._console = console
|
||||
|
||||
def execute(self, tool_name: str, arguments: dict[str, Any]) -> str:
|
||||
tool = self._registry.find_tool(tool_name)
|
||||
if tool is None:
|
||||
return f"Error: unknown tool '{escape(tool_name)}'"
|
||||
|
||||
# Injection-scan arguments before any execution
|
||||
args_str = json.dumps(arguments)
|
||||
arg_warnings = scan_response(args_str)
|
||||
if arg_warnings:
|
||||
labels = ", ".join(w.pattern_label for w in arg_warnings)
|
||||
self._log(tool_name, arguments, approved=False, result=f"BLOCKED:{labels}")
|
||||
return f"Tool execution blocked: suspicious content in arguments ({labels})."
|
||||
|
||||
approved = True
|
||||
if tool.requires_approval:
|
||||
approved = self._ask_approval(tool_name, arguments)
|
||||
|
||||
if not approved:
|
||||
self._log(tool_name, arguments, approved=False, result="declined")
|
||||
return "Tool execution declined by user."
|
||||
|
||||
try:
|
||||
result = str(tool.handler(**arguments))
|
||||
except Exception as exc:
|
||||
result = f"Tool error: {exc}"
|
||||
|
||||
# Injection-scan result before returning to AI context
|
||||
result_warnings = scan_response(result)
|
||||
if result_warnings:
|
||||
labels = ", ".join(w.pattern_label for w in result_warnings)
|
||||
result = f"[Warning: suspicious content in tool result ({labels})] {result}"
|
||||
|
||||
if len(result) > _MAX_RESULT_CHARS:
|
||||
result = result[:_MAX_RESULT_CHARS] + f"\n[...truncated at {_MAX_RESULT_CHARS} chars]"
|
||||
|
||||
self._log(tool_name, arguments, approved=True, result=result[:200])
|
||||
return result
|
||||
|
||||
def execute_tool_call_batch(
|
||||
self, tool_calls: list[Any]
|
||||
) -> list[dict[str, str]]:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
args = {}
|
||||
result = self.execute(tc.function.name, args)
|
||||
results.append({"tool_call_id": tc.id, "result": result})
|
||||
return results
|
||||
|
||||
def _ask_approval(self, tool_name: str, arguments: dict[str, Any]) -> bool:
|
||||
lines = [f"[bold yellow]Tool:[/bold yellow] {escape(tool_name)}"]
|
||||
if arguments:
|
||||
lines.append("[bold yellow]Arguments:[/bold yellow]")
|
||||
for k, v in arguments.items():
|
||||
lines.append(f" {escape(str(k))}: {escape(str(v))}")
|
||||
self._console.print(Panel(
|
||||
"\n".join(lines),
|
||||
title="[bold]Pyra wants to run a tool[/bold]",
|
||||
border_style="yellow",
|
||||
))
|
||||
try:
|
||||
answer = self._console.input("[bold]Approve?[/bold] [dim][y/N][/dim] ").strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
return answer == "y"
|
||||
|
||||
def _log(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
approved: bool,
|
||||
result: str,
|
||||
) -> None:
|
||||
try:
|
||||
_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
if _LOG_FILE.exists() and _LOG_FILE.stat().st_size > _LOG_MAX_BYTES:
|
||||
_rotate_log()
|
||||
ts = datetime.datetime.now().isoformat()
|
||||
args_safe = redact_api_keys(json.dumps(arguments))
|
||||
status = "APPROVED" if approved else "DECLINED"
|
||||
with _LOG_FILE.open("a") as fh:
|
||||
fh.write(
|
||||
f"[{ts}] {status} tool={tool_name!r} "
|
||||
f"args={args_safe!r} result_preview={result[:100]!r}\n"
|
||||
)
|
||||
safe_chmod(_LOG_FILE, 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _rotate_log() -> None:
|
||||
rotated = _LOG_FILE.with_suffix(".log.1")
|
||||
_LOG_FILE.rename(rotated)
|
||||
try:
|
||||
safe_chmod(rotated, 0o000)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from pyra.utils.paths import ensure_dir
|
||||
|
||||
|
||||
def get_bundled_plugins_dir() -> Path:
|
||||
"""Return the path to bundled_plugins/ packaged alongside pyra."""
|
||||
# src/pyra/plugins/install.py → src/pyra/ → src/pyra/bundled_plugins/
|
||||
return Path(__file__).parent.parent / "bundled_plugins"
|
||||
|
||||
|
||||
def install_bundled_plugin(name: str, bundled_dir: Path, plugins_dir: Path) -> None:
|
||||
"""Copy bundled_plugins/<name>/ to ~/.pyra/plugins/<name>/."""
|
||||
src = bundled_dir / name
|
||||
if not src.is_dir():
|
||||
raise FileNotFoundError(f"Bundled plugin '{name}' not found in {bundled_dir}")
|
||||
if not (src / "manifest.json").exists():
|
||||
raise FileNotFoundError(f"Bundled plugin '{name}' is missing manifest.json")
|
||||
|
||||
dest = plugins_dir / name
|
||||
if dest.exists():
|
||||
shutil.rmtree(dest)
|
||||
shutil.copytree(src, dest)
|
||||
|
||||
ensure_dir(dest, 0o700)
|
||||
for f in dest.rglob("*"):
|
||||
if f.is_file():
|
||||
f.chmod(0o600)
|
||||
|
||||
|
||||
def list_bundled_plugins(bundled_dir: Path) -> list[str]:
|
||||
"""Return names of all available bundled plugins."""
|
||||
if not bundled_dir.is_dir():
|
||||
return []
|
||||
return sorted(
|
||||
e.name
|
||||
for e in bundled_dir.iterdir()
|
||||
if e.is_dir() and (e / "manifest.json").exists()
|
||||
)
|
||||
|
||||
|
||||
def read_manifest(plugin_dir: Path) -> dict:
|
||||
manifest_path = plugin_dir / "manifest.json"
|
||||
if not manifest_path.exists():
|
||||
return {}
|
||||
with manifest_path.open() as fh:
|
||||
return json.load(fh)
|
||||
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pyra.plugins.base import PyraPlugin
|
||||
from pyra.security.boundaries import assert_safe_path
|
||||
from pyra.utils.paths import pyra_home
|
||||
|
||||
_LOG_FILE = pyra_home() / "logs" / "plugin_errors.log"
|
||||
|
||||
|
||||
def load_plugins(plugins_dir: Path) -> list[PyraPlugin]:
|
||||
"""Discover and load all valid plugin directories found in plugins_dir."""
|
||||
plugins: list[PyraPlugin] = []
|
||||
if not plugins_dir.is_dir():
|
||||
return plugins
|
||||
|
||||
for entry in sorted(plugins_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
plugin = load_plugin_by_name(entry.name, plugins_dir)
|
||||
if plugin is not None:
|
||||
plugins.append(plugin)
|
||||
return plugins
|
||||
|
||||
|
||||
def load_plugin_by_name(name: str, plugins_dir: Path) -> PyraPlugin | None:
|
||||
"""Load a single plugin by directory name. Returns None on any failure."""
|
||||
plugin_dir = plugins_dir / name
|
||||
try:
|
||||
assert_safe_path(plugin_dir)
|
||||
return _load_from_dir(name, plugin_dir)
|
||||
except Exception as exc:
|
||||
_log_error(name, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _load_from_dir(name: str, plugin_dir: Path) -> PyraPlugin:
|
||||
manifest_path = plugin_dir / "manifest.json"
|
||||
plugin_path = plugin_dir / "plugin.py"
|
||||
|
||||
if not manifest_path.exists():
|
||||
raise FileNotFoundError(f"Missing manifest.json in {plugin_dir}")
|
||||
if not plugin_path.exists():
|
||||
raise FileNotFoundError(f"Missing plugin.py in {plugin_dir}")
|
||||
|
||||
with manifest_path.open() as fh:
|
||||
manifest = json.load(fh)
|
||||
|
||||
if "name" not in manifest or "version" not in manifest:
|
||||
raise ValueError(f"manifest.json missing required 'name'/'version' in {plugin_dir}")
|
||||
|
||||
module_name = f"pyra_plugin_{name}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot create module spec for {plugin_path}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||
|
||||
if not hasattr(module, "get_plugin"):
|
||||
raise AttributeError(f"plugin.py must export get_plugin() in {plugin_dir}")
|
||||
|
||||
plugin = module.get_plugin()
|
||||
|
||||
for attr in ("name", "description", "version"):
|
||||
if not hasattr(plugin, attr):
|
||||
raise AttributeError(f"Plugin missing required attribute '{attr}'")
|
||||
|
||||
return plugin # type: ignore[return-value]
|
||||
|
||||
|
||||
def _log_error(name: str, exc: Exception) -> None:
|
||||
try:
|
||||
_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.datetime.now().isoformat()
|
||||
with _LOG_FILE.open("a") as fh:
|
||||
fh.write(f"[{ts}] Failed to load plugin '{name}': {type(exc).__name__}: {exc}\n")
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Coroutine
|
||||
|
||||
from pyra.plugins.base import PyraPlugin, Tool
|
||||
from pyra.plugins.loader import _log_error, load_plugins
|
||||
from pyra.vault.reader import get_key
|
||||
|
||||
|
||||
class PluginRegistry:
|
||||
_instance: PluginRegistry | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._plugins: dict[str, PyraPlugin] = {}
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> PluginRegistry:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Reset singleton — for tests only."""
|
||||
cls._instance = None
|
||||
|
||||
def load_all(self, plugins_dir: Path, enabled_names: list[str]) -> None:
|
||||
all_plugins = load_plugins(plugins_dir)
|
||||
self._plugins = {}
|
||||
for plugin in all_plugins:
|
||||
if plugin.name in enabled_names:
|
||||
try:
|
||||
plugin.on_load(get_key)
|
||||
self._plugins[plugin.name] = plugin
|
||||
except Exception as exc:
|
||||
_log_error(plugin.name, exc)
|
||||
|
||||
def get_active_plugins(self) -> list[PyraPlugin]:
|
||||
return list(self._plugins.values())
|
||||
|
||||
def get_all_tools(self) -> list[Tool]:
|
||||
tools: list[Tool] = []
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
tools.extend(plugin.tools())
|
||||
except Exception:
|
||||
pass
|
||||
return tools
|
||||
|
||||
def get_slash_commands(self) -> dict[str, Callable[[], None]]:
|
||||
cmds: dict[str, Callable[[], None]] = {}
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
cmds.update(plugin.slash_commands())
|
||||
except Exception:
|
||||
pass
|
||||
return cmds
|
||||
|
||||
def get_system_prompt_additions(self) -> str:
|
||||
parts: list[str] = []
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
addition = plugin.system_prompt_addition()
|
||||
if addition:
|
||||
parts.append(addition.strip())
|
||||
except Exception:
|
||||
pass
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def get_daemon_tasks(self) -> list[Coroutine]: # type: ignore[type-arg]
|
||||
tasks: list[Coroutine] = [] # type: ignore[type-arg]
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
tasks.extend(plugin.daemon_tasks())
|
||||
except Exception:
|
||||
pass
|
||||
return tasks
|
||||
|
||||
def find_tool(self, name: str) -> Tool | None:
|
||||
for tool in self.get_all_tools():
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
Reference in New Issue
Block a user