diff --git a/src/pyra/plugins/registry.py b/src/pyra/plugins/registry.py index 9ff6a74..bd73ec5 100644 --- a/src/pyra/plugins/registry.py +++ b/src/pyra/plugins/registry.py @@ -13,6 +13,7 @@ class PluginRegistry: def __init__(self) -> None: self._plugins: dict[str, PyraPlugin] = {} + self._tools: dict[str, Tool] = {} @classmethod def instance(cls) -> PluginRegistry: @@ -28,11 +29,14 @@ class PluginRegistry: def load_all(self, plugins_dir: Path, enabled_names: list[str]) -> None: all_plugins = load_plugins(plugins_dir) self._plugins = {} + self._tools = {} for plugin in all_plugins: if plugin.name in enabled_names: try: plugin.on_load(get_key) self._plugins[plugin.name] = plugin + for tool in plugin.tools(): + self._tools[tool.name] = tool except Exception as exc: _log_error(plugin.name, exc) @@ -40,13 +44,7 @@ class PluginRegistry: return list(self._plugins.values()) def get_all_tools(self) -> list[Tool]: - tools: list[Tool] = [] - for plugin in self._plugins.values(): - try: - tools.extend(plugin.tools()) - except Exception: - pass - return tools + return list(self._tools.values()) def get_slash_commands(self) -> dict[str, Callable[[], None]]: cmds: dict[str, Callable[[], None]] = {} @@ -78,7 +76,4 @@ class PluginRegistry: return tasks def find_tool(self, name: str) -> Tool | None: - for tool in self.get_all_tools(): - if tool.name == name: - return tool - return None + return self._tools.get(name) diff --git a/tests/unit/test_tool_executor.py b/tests/unit/test_tool_executor.py index ebddca0..fbe1ffd 100644 --- a/tests/unit/test_tool_executor.py +++ b/tests/unit/test_tool_executor.py @@ -20,6 +20,7 @@ def _make_registry_with_tools(*tools: Tool) -> PluginRegistry: fake_plugin.system_prompt_addition.return_value = "" fake_plugin.daemon_tasks.return_value = [] registry._plugins = {"mock_plugin": fake_plugin} + registry._tools = {tool.name: tool for tool in tools} return registry