diff --git a/pyproject.toml b/pyproject.toml index 955211f..a5784aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dev = [ ] nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"] matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"] -telegram = ["python-telegram-bot>=21.0"] +telegram = ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"] ssh = ["paramiko>=3.4.0"] docker = ["docker>=7.0.0"] gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"] @@ -38,7 +38,7 @@ daemon = ["aiofiles>=23.0.0"] all-plugins = [ "caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6", "matrix-nio>=0.24.0", "aiofiles>=23.0.0", - "python-telegram-bot>=21.0", + "python-telegram-bot>=21.0", "bcrypt>=4.0.0", "paramiko>=3.4.0", "docker>=7.0.0", "google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0", diff --git a/src/pyra/bundled_plugins/telegram_bot/manifest.json b/src/pyra/bundled_plugins/telegram_bot/manifest.json new file mode 100644 index 0000000..d321683 --- /dev/null +++ b/src/pyra/bundled_plugins/telegram_bot/manifest.json @@ -0,0 +1,7 @@ +{ + "name": "telegram_bot", + "version": "1.0.0", + "description": "Remote Pyra chat over Telegram — full AI chat with tool approval via inline buttons", + "author": "pyra", + "requires": ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"] +} diff --git a/src/pyra/bundled_plugins/telegram_bot/plugin.py b/src/pyra/bundled_plugins/telegram_bot/plugin.py new file mode 100644 index 0000000..fd93573 --- /dev/null +++ b/src/pyra/bundled_plugins/telegram_bot/plugin.py @@ -0,0 +1,592 @@ +"""Telegram bot plugin — remote Pyra chat over Telegram. + +Runs as a daemon task (long-polling). Each chat session requires passphrase +authentication and is rate-limited to 20 messages/hour. Incoming messages are +injection-scanned before reaching the AI. Tool calls are approved via Telegram +inline keyboard buttons (2-minute timeout). AI responses are streamed +progressively by editing a placeholder message. + +Vault keys used: + plugin:telegram_bot:token — bot token from @BotFather + plugin:telegram_bot:allowed_users — comma-separated Telegram user IDs + plugin:telegram_bot:passphrase_hash — bcrypt hash of the session passphrase +""" +from __future__ import annotations + +import asyncio +import json +import logging +import sqlite3 +import time +import uuid +from collections import deque +from typing import Any, Callable + +import bcrypt +import litellm +from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup, Update +from telegram.ext import ( + Application, + CallbackQueryHandler, + CommandHandler, + ContextTypes, + MessageHandler, + filters, +) + +from pyra.plugins.base import BasePlugin, ConfigField +from pyra.utils.paths import pyra_home + +_log = logging.getLogger("pyra.plugin.telegram_bot") + +_HISTORY_DB = pyra_home() / "telegram_history.db" +_MAX_HISTORY = 40 # messages kept per chat +_RATE_LIMIT = 20 # messages per hour per user +_APPROVAL_TIMEOUT = 120 # seconds to wait for inline button press +_EDIT_INTERVAL = 1.5 # minimum seconds between progressive message edits +_MAX_TOOL_ITER = 10 +_MAX_MSG_LEN = 4096 # Telegram hard limit + + +# ── SQLite history ──────────────────────────────────────────────────────────── + +def _open_db() -> sqlite3.Connection: + conn = sqlite3.connect(_HISTORY_DB) + conn.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + chat_id INTEGER PRIMARY KEY, + history TEXT NOT NULL DEFAULT '[]', + updated REAL NOT NULL + ) + """) + conn.commit() + try: + import os + os.chmod(_HISTORY_DB, 0o600) + except Exception: + pass + return conn + + +def _load_history(chat_id: int) -> list[dict]: + conn = _open_db() + row = conn.execute( + "SELECT history FROM sessions WHERE chat_id = ?", (chat_id,) + ).fetchone() + conn.close() + return json.loads(row[0]) if row else [] + + +def _save_history(chat_id: int, messages: list[dict]) -> None: + trimmed = messages[-_MAX_HISTORY:] + conn = _open_db() + conn.execute( + "INSERT OR REPLACE INTO sessions (chat_id, history, updated) VALUES (?,?,?)", + (chat_id, json.dumps(trimmed), time.time()), + ) + conn.commit() + conn.close() + + +# ── Rate limiter ────────────────────────────────────────────────────────────── + +class _RateLimiter: + def __init__(self, per_hour: int = _RATE_LIMIT) -> None: + self._buckets: dict[int, deque] = {} + self._limit = per_hour + + def allow(self, user_id: int) -> bool: + now = time.monotonic() + bucket = self._buckets.setdefault(user_id, deque()) + cutoff = now - 3600 + while bucket and bucket[0] < cutoff: + bucket.popleft() + if len(bucket) >= self._limit: + return False + bucket.append(now) + return True + + +# ── Plugin ──────────────────────────────────────────────────────────────────── + +class TelegramBotPlugin(BasePlugin): + name = "telegram_bot" + description = "Remote Pyra chat over Telegram (daemon task, long-polling)" + version = "1.0.0" + + def __init__(self) -> None: + self._vault_reader: Callable[[str], str | None] | None = None + self._rate_limiter = _RateLimiter() + # chat_id -> {authenticated, awaiting_passphrase, attempts} + self._sessions: dict[int, dict] = {} + # short call_id -> asyncio.Future[bool] + self._pending_approvals: dict[str, asyncio.Future] = {} + + # ── Plugin lifecycle ────────────────────────────────────────────────────── + + def on_load(self, vault_reader: Callable[[str], str | None]) -> None: + self._vault_reader = vault_reader + + def setup(self, console: Any, vault_writer: Callable[[str, str], None]) -> None: + import questionary + + console.print("[bold]Telegram Bot Setup[/bold]") + console.print( + "1. Create a bot via @BotFather on Telegram to get your bot token.\n" + "2. Find your Telegram user ID by messaging @userinfobot.\n" + ) + + token = questionary.password("Bot token (from @BotFather):").ask() + if not token: + return + + allowed = questionary.text( + "Allowed Telegram user IDs (comma-separated, e.g. 123456789):" + ).ask() + + passphrase = questionary.password("Session passphrase:").ask() + if not passphrase: + return + + confirm = questionary.password("Confirm passphrase:").ask() + if passphrase != confirm: + console.print("[red]Passphrases do not match.[/red]") + return + + pw_hash = bcrypt.hashpw(passphrase.encode(), bcrypt.gensalt()).decode() + vault_writer("plugin:telegram_bot:token", token.strip()) + vault_writer("plugin:telegram_bot:allowed_users", (allowed or "").strip()) + vault_writer("plugin:telegram_bot:passphrase_hash", pw_hash) + console.print("[green]Telegram bot configured. Enable it and start the daemon.[/green]") + + def config_fields(self) -> list[ConfigField]: + return [ + ConfigField( + "rate_limit", + "Rate limit (messages/hour)", + "text", + str(_RATE_LIMIT), + description="Maximum messages per hour per Telegram user", + ), + ] + + def daemon_tasks(self) -> list: + return [self._run_polling()] + + # ── Daemon task ─────────────────────────────────────────────────────────── + + async def _run_polling(self) -> None: + assert self._vault_reader is not None + + token = self._vault_reader("plugin:telegram_bot:token") + if not token: + _log.error( + "Telegram bot token not set. Run `pyra plugin setup telegram_bot`." + ) + return + + passphrase_hash = self._vault_reader("plugin:telegram_bot:passphrase_hash") or "" + allowed_str = self._vault_reader("plugin:telegram_bot:allowed_users") or "" + allowed_users: set[int] = { + int(uid.strip()) + for uid in allowed_str.split(",") + if uid.strip().isdigit() + } + + app = Application.builder().token(token).build() + plugin = self # closure reference + + async def _on_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None: + if update.effective_user and update.effective_user.id in allowed_users: + await update.message.reply_text( + "Pyra is online. Send any message to authenticate." + ) + + async def _on_message(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None: + await plugin._handle_message(update, ctx, allowed_users, passphrase_hash) + + async def _on_approval(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None: + await plugin._handle_approval_callback(update) + + app.add_handler(CommandHandler("start", _on_start)) + app.add_handler(CallbackQueryHandler(_on_approval)) + app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, _on_message)) + + _log.info("Telegram bot starting (long-polling).") + await app.initialize() + try: + await app.start() + await app.updater.start_polling(drop_pending_updates=True) + _log.info("Telegram bot is polling for updates.") + await asyncio.Event().wait() # block until CancelledError + except asyncio.CancelledError: + _log.info("Telegram bot shutting down.") + finally: + try: + await app.updater.stop() + await app.stop() + await app.shutdown() + except Exception as exc: + _log.warning("Error during Telegram bot shutdown: %s", exc) + + # ── Message handler ─────────────────────────────────────────────────────── + + async def _handle_message( + self, + update: Update, + ctx: ContextTypes.DEFAULT_TYPE, + allowed_users: set[int], + passphrase_hash: str, + ) -> None: + if update.effective_user is None or update.message is None: + return + + user_id = update.effective_user.id + chat_id = update.effective_chat.id if update.effective_chat else user_id + + # Allowlist — silently ignore unknown senders + if allowed_users and user_id not in allowed_users: + return + + text = (update.message.text or "").strip() + if not text: + return + + session = self._sessions.setdefault( + chat_id, + {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}, + ) + + # ── Passphrase authentication ───────────────────────────────────────── + if not session["authenticated"]: + if not session["awaiting_passphrase"]: + session["awaiting_passphrase"] = True + session["attempts"] = 0 + await update.message.reply_text("Enter your passphrase to continue:") + return + + if passphrase_hash and bcrypt.checkpw(text.encode(), passphrase_hash.encode()): + session["authenticated"] = True + session["awaiting_passphrase"] = False + session["attempts"] = 0 + await update.message.reply_text( + "Authenticated. How can I help you?\n" + "Send /start at any time to check bot status." + ) + else: + session["attempts"] += 1 + remaining = 3 - session["attempts"] + if remaining <= 0: + session["awaiting_passphrase"] = False + await update.message.reply_text( + "Too many failed attempts. Send any message to try again." + ) + else: + await update.message.reply_text( + f"Wrong passphrase. {remaining} attempt(s) left." + ) + return + + # ── Rate limit ──────────────────────────────────────────────────────── + if not self._rate_limiter.allow(user_id): + await update.message.reply_text( + "Rate limit reached (20 messages/hour). Try again later." + ) + return + + # ── Injection scan ──────────────────────────────────────────────────── + from pyra.security.injection import scan_response + + warnings = scan_response(text) + if warnings: + labels = ", ".join(w.pattern_label for w in warnings) + _log.warning("Injection in Telegram message (user %d): %s", user_id, labels) + await update.message.reply_text( + "Your message was blocked: injection pattern detected." + ) + return + + # ── Load history + system context ───────────────────────────────────── + history = _load_history(chat_id) + if not history: + try: + from pyra.memory.reader import load_context_for_session + + ctx_text = load_context_for_session() + if ctx_text: + history = [{"role": "system", "content": ctx_text}] + except Exception: + pass + + history.append({"role": "user", "content": text}) + + try: + await ctx.bot.send_chat_action(chat_id=chat_id, action="typing") + except Exception: + pass + + # ── AI response ─────────────────────────────────────────────────────── + try: + reply = await self._ai_chat(chat_id, history, ctx.bot) + except Exception as exc: + _log.error("AI error (chat %d): %s", chat_id, exc, exc_info=True) + await ctx.bot.send_message(chat_id=chat_id, text=f"AI error: {exc}") + return + + history.append({"role": "assistant", "content": reply}) + _save_history(chat_id, history) + + # ── AI streaming + tool-use loop ────────────────────────────────────────── + + async def _ai_chat(self, chat_id: int, messages: list[dict], bot: Bot) -> str: + from pyra.config.manager import load_config + from pyra.plugins.registry import PluginRegistry + from pyra.setup.providers import get_provider + from pyra.vault.reader import get_key + + cfg = load_config() + provider = get_provider(cfg.ai.provider_id) + api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local" + + call_kwargs: dict[str, Any] = { + "model": f"{provider.litellm_prefix}{cfg.ai.model}", + "api_key": api_key, + } + base_url = cfg.ai.base_url or provider.base_url + if base_url: + call_kwargs["api_base"] = base_url + + litellm.suppress_debug_info = True + + registry = PluginRegistry.instance() + tools_spec = [ + { + "type": "function", + "function": { + "name": t.name, + "description": t.description, + "parameters": t.parameters, + }, + } + for t in registry.get_all_tools() + ] + + # Mutable state shared with helpers below + state: dict[str, Any] = {"msg_id": None, "last_edit": 0.0} + + placeholder = await bot.send_message(chat_id=chat_id, text="…") + state["msg_id"] = placeholder.message_id + + async def _update(text: str) -> None: + if not state["msg_id"]: + return + now = time.monotonic() + if now - state["last_edit"] < _EDIT_INTERVAL: + return + try: + await bot.edit_message_text( + chat_id=chat_id, + message_id=state["msg_id"], + text=text[:_MAX_MSG_LEN], + ) + state["last_edit"] = now + except Exception: + pass + + async def _finalize(text: str) -> None: + if not state["msg_id"]: + return + if text: + try: + await bot.edit_message_text( + chat_id=chat_id, + message_id=state["msg_id"], + text=text[:_MAX_MSG_LEN], + ) + except Exception: + pass + else: + try: + await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"]) + except Exception: + pass + state["msg_id"] = None + + accumulated = "" + + try: + for _iter in range(_MAX_TOOL_ITER): + tool_chunks: dict[int, dict] = {} + accumulated = "" + + stream = await litellm.acompletion( + **call_kwargs, + messages=messages, + tools=tools_spec if tools_spec else None, + tool_choice="auto" if tools_spec else None, + stream=True, + ) + + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.content: + accumulated += delta.content + await _update(accumulated) + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_chunks: + tool_chunks[idx] = {"id": tc.id or "", "name": "", "args": ""} + if tc.function: + if tc.function.name: + tool_chunks[idx]["name"] += tc.function.name + if tc.function.arguments: + tool_chunks[idx]["args"] += tc.function.arguments + + if not tool_chunks: + await _finalize(accumulated) + return accumulated + + # Show any intermediate prose before tool calls + if accumulated: + await _finalize(accumulated) + else: + try: + await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"]) + except Exception: + pass + state["msg_id"] = None + + tool_calls_list = [ + { + "id": data["id"], + "type": "function", + "function": {"name": data["name"], "arguments": data["args"]}, + } + for _, data in sorted(tool_chunks.items()) + ] + messages.append({ + "role": "assistant", + "content": accumulated or None, + "tool_calls": tool_calls_list, + }) + + for tc in tool_calls_list: + result = await self._execute_tool_with_approval(tc, chat_id, bot) + messages.append({ + "role": "tool", + "tool_call_id": tc["id"], + "content": result, + }) + + # New placeholder for the next AI response + new_ph = await bot.send_message(chat_id=chat_id, text="…") + state["msg_id"] = new_ph.message_id + state["last_edit"] = 0.0 + + except litellm.BadRequestError: + # Provider doesn't support tool calls — retry without tools + accumulated = "" + state["last_edit"] = 0.0 + stream = await litellm.acompletion( + **call_kwargs, messages=messages, stream=True + ) + async for chunk in stream: + if chunk.choices[0].delta.content: + accumulated += chunk.choices[0].delta.content + await _update(accumulated) + await _finalize(accumulated) + return accumulated + + return accumulated or "Error: tool-use loop exceeded maximum iterations." + + # ── Tool approval via inline buttons ────────────────────────────────────── + + async def _execute_tool_with_approval( + self, tool_call: dict, chat_id: int, bot: Bot + ) -> str: + from pyra.plugins.registry import PluginRegistry + from pyra.security.injection import scan_response + + tool_name = tool_call["function"]["name"] + args_raw = tool_call["function"]["arguments"] + try: + args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw + except json.JSONDecodeError: + return f"Error: invalid tool arguments for {tool_name}" + + args_preview = json.dumps(args, indent=2)[:500] + call_id = uuid.uuid4().hex[:8] + + keyboard = InlineKeyboardMarkup([[ + InlineKeyboardButton("✅ Approve", callback_data=f"approve:{call_id}"), + InlineKeyboardButton("❌ Deny", callback_data=f"deny:{call_id}"), + ]]) + + await bot.send_message( + chat_id=chat_id, + text=f"Tool request: {tool_name}\n\n{args_preview}", + reply_markup=keyboard, + ) + + loop = asyncio.get_running_loop() + future: asyncio.Future[bool] = loop.create_future() + self._pending_approvals[call_id] = future + + try: + approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT) + except asyncio.TimeoutError: + self._pending_approvals.pop(call_id, None) + await bot.send_message( + chat_id=chat_id, text=f"Tool {tool_name}: approval timed out — denied." + ) + return "Tool execution denied (timeout)." + + if not approved: + return "Tool execution denied by user." + + registry = PluginRegistry.instance() + tool = registry.find_tool(tool_name) + if tool is None: + return f"Error: tool '{tool_name}' not found in registry." + + try: + result = tool.handler(**args) + if not isinstance(result, str): + result = str(result) + except Exception as exc: + return f"Tool error: {exc}" + + injection_warnings = scan_response(result) + if injection_warnings: + labels = ", ".join(w.pattern_label for w in injection_warnings) + await bot.send_message( + chat_id=chat_id, + text=f"Warning: tool result contains suspicious content ({labels}).", + ) + + return result[:4000] + + # ── Approval callback ───────────────────────────────────────────────────── + + async def _handle_approval_callback(self, update: Update) -> None: + query = update.callback_query + if query is None: + return + await query.answer() + data = query.data or "" + if ":" not in data: + return + action, call_id = data.split(":", 1) + future = self._pending_approvals.pop(call_id, None) + if future and not future.done(): + future.set_result(action == "approve") + try: + await query.edit_message_reply_markup(reply_markup=None) + except Exception: + pass + + +def get_plugin() -> TelegramBotPlugin: + return TelegramBotPlugin() diff --git a/tests/unit/test_telegram_bot.py b/tests/unit/test_telegram_bot.py new file mode 100644 index 0000000..247e010 --- /dev/null +++ b/tests/unit/test_telegram_bot.py @@ -0,0 +1,236 @@ +"""Unit tests for the telegram_bot bundled plugin. + +Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow, +and tool argument parsing. Handler integration (live Telegram) is not tested. +""" +from __future__ import annotations + +import importlib.util +import json +import sys +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +# ── Import helpers straight from the bundled source ────────────────────────── + +_PLUGIN_PATH = ( + Path(__file__).parent.parent.parent + / "src/pyra/bundled_plugins/telegram_bot/plugin.py" +) + + +def _load_plugin_module(): + """Load plugin.py without requiring python-telegram-bot to be installed.""" + # Provide stub modules so the top-level imports don't fail in unit tests + for mod_name in [ + "bcrypt", + "telegram", + "telegram.ext", + "litellm", + ]: + if mod_name not in sys.modules: + stub = MagicMock() + sys.modules[mod_name] = stub + # telegram.ext sub-attributes needed at import time + if mod_name == "telegram.ext": + stub.Application = MagicMock() + stub.CallbackQueryHandler = MagicMock() + stub.CommandHandler = MagicMock() + stub.ContextTypes = MagicMock() + stub.MessageHandler = MagicMock() + stub.filters = MagicMock() + + spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH) + assert spec and spec.loader + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) # type: ignore[union-attr] + return mod + + +_mod = _load_plugin_module() +_RateLimiter = _mod._RateLimiter +_load_history = _mod._load_history +_save_history = _mod._save_history +TelegramBotPlugin = _mod.TelegramBotPlugin + + +# ── Rate limiter ────────────────────────────────────────────────────────────── + +class TestRateLimiter: + def test_allows_up_to_limit(self): + rl = _RateLimiter(per_hour=5) + for _ in range(5): + assert rl.allow(user_id=1) + + def test_blocks_over_limit(self): + rl = _RateLimiter(per_hour=3) + for _ in range(3): + rl.allow(user_id=1) + assert not rl.allow(user_id=1) + + def test_independent_per_user(self): + rl = _RateLimiter(per_hour=2) + rl.allow(1) + rl.allow(1) + assert not rl.allow(1) + assert rl.allow(2) + + def test_old_entries_expire(self): + rl = _RateLimiter(per_hour=1) + # Manually populate bucket with an old timestamp + from collections import deque + rl._buckets[99] = deque([time.monotonic() - 3601]) + assert rl.allow(99) # old entry removed, new one fits + + def test_multiple_users_independent(self): + rl = _RateLimiter(per_hour=1) + rl.allow(10) + assert not rl.allow(10) + assert rl.allow(20) + assert rl.allow(30) + + +# ── History DB ──────────────────────────────────────────────────────────────── + +class TestHistoryDB: + def test_empty_history(self, tmp_path, monkeypatch): + db_path = tmp_path / "tg.db" + monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) + assert _load_history(42) == [] + + def test_save_and_load(self, tmp_path, monkeypatch): + db_path = tmp_path / "tg.db" + monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) + msgs = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi there"}, + ] + _save_history(42, msgs) + assert _load_history(42) == msgs + + def test_save_trims_to_max(self, tmp_path, monkeypatch): + db_path = tmp_path / "tg.db" + monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) + monkeypatch.setattr(_mod, "_MAX_HISTORY", 3) + msgs = [{"role": "user", "content": str(i)} for i in range(10)] + _save_history(1, msgs) + loaded = _load_history(1) + assert len(loaded) == 3 + # Should keep the most recent messages + assert loaded[-1]["content"] == "9" + + def test_overwrite(self, tmp_path, monkeypatch): + db_path = tmp_path / "tg.db" + monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) + _save_history(1, [{"role": "user", "content": "first"}]) + _save_history(1, [{"role": "user", "content": "second"}]) + assert _load_history(1)[0]["content"] == "second" + + def test_separate_chats(self, tmp_path, monkeypatch): + db_path = tmp_path / "tg.db" + monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) + _save_history(1, [{"role": "user", "content": "chat1"}]) + _save_history(2, [{"role": "user", "content": "chat2"}]) + assert _load_history(1)[0]["content"] == "chat1" + assert _load_history(2)[0]["content"] == "chat2" + + +# ── Plugin: on_load and setup ───────────────────────────────────────────────── + +class TestPluginLifecycle: + def test_on_load_stores_vault_reader(self): + plugin = TelegramBotPlugin() + reader = MagicMock(return_value=None) + plugin.on_load(reader) + assert plugin._vault_reader is reader + + def test_daemon_tasks_returns_coroutine(self): + import inspect + plugin = TelegramBotPlugin() + plugin.on_load(MagicMock(return_value=None)) + tasks = plugin.daemon_tasks() + assert len(tasks) == 1 + assert inspect.iscoroutine(tasks[0]) + tasks[0].close() # prevent "coroutine never awaited" warning + + def test_get_plugin_factory(self): + plugin = _mod.get_plugin() + assert isinstance(plugin, TelegramBotPlugin) + assert plugin.name == "telegram_bot" + + def test_config_fields(self): + plugin = TelegramBotPlugin() + fields = plugin.config_fields() + assert any(f.key == "rate_limit" for f in fields) + + def test_setup_mismatched_passphrase(self, capsys): + """setup() exits cleanly when passphrases don't match.""" + plugin = TelegramBotPlugin() + console = MagicMock() + vault_writer = MagicMock() + + answers = iter(["fake-token", "user123", "pass1", "pass2"]) + with patch("questionary.password", side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(answers))), \ + patch("questionary.text", return_value=MagicMock(ask=lambda: next(answers))): + plugin.setup(console, vault_writer) + + vault_writer.assert_not_called() + console.print.assert_called_with("[red]Passphrases do not match.[/red]") + + +# ── Auth session state ──────────────────────────────────────────────────────── + +class TestAuthState: + def test_session_defaults(self): + plugin = TelegramBotPlugin() + session = plugin._sessions.setdefault( + 99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0} + ) + assert not session["authenticated"] + assert not session["awaiting_passphrase"] + assert session["attempts"] == 0 + + def test_session_per_chat(self): + plugin = TelegramBotPlugin() + plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0} + plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1} + assert plugin._sessions[1]["authenticated"] + assert not plugin._sessions[2]["authenticated"] + assert plugin._sessions[2]["attempts"] == 1 + + def test_session_auth_flag(self): + plugin = TelegramBotPlugin() + plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0} + # Simulate successful auth + plugin._sessions[5]["authenticated"] = True + plugin._sessions[5]["awaiting_passphrase"] = False + assert plugin._sessions[5]["authenticated"] + + +# ── Tool argument parsing ───────────────────────────────────────────────────── + +class TestToolArgParsing: + def test_valid_json_string(self): + args_raw = '{"query": "hello"}' + args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw + assert args == {"query": "hello"} + + def test_dict_passthrough(self): + args_raw = {"query": "hello"} + args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw + assert args == {"query": "hello"} + + def test_invalid_json_caught(self): + args_raw = "not json" + try: + json.loads(args_raw) + assert False, "Should have raised" + except json.JSONDecodeError: + pass + + def test_result_truncated_to_4000(self): + long_result = "x" * 5000 + assert len(long_result[:4000]) == 4000