feat(plugin/telegram_bot): add Telegram bot plugin — remote Pyra chat over Telegram
Runs as a supervised daemon task (asyncio long-polling). Features: - Sender allowlist + bcrypt passphrase challenge (3 attempts before lockout) - Rate limiting: 20 messages/hour per user (in-memory, resets on restart) - Injection scanning on every incoming message (pyra.security.injection) - Full AI chat with litellm streaming → progressive Telegram message editing - Tool-use loop (up to 10 iterations) with inline-button approval (120s timeout) - Conversation history persisted per chat_id in ~/.pyra/telegram_history.db - Memory context loaded from ~/.pyra/memory/ as system prompt on first message Vault keys: plugin:telegram_bot:token, allowed_users, passphrase_hash Deps: python-telegram-bot>=21.0, bcrypt>=4.0.0 (added to telegram + all-plugins extras) 22 new unit tests covering rate limiter, history DB, plugin lifecycle, and auth state. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,236 @@
|
||||
"""Unit tests for the telegram_bot bundled plugin.
|
||||
|
||||
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
|
||||
and tool argument parsing. Handler integration (live Telegram) is not tested.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ── Import helpers straight from the bundled source ──────────────────────────
|
||||
|
||||
_PLUGIN_PATH = (
|
||||
Path(__file__).parent.parent.parent
|
||||
/ "src/pyra/bundled_plugins/telegram_bot/plugin.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_plugin_module():
|
||||
"""Load plugin.py without requiring python-telegram-bot to be installed."""
|
||||
# Provide stub modules so the top-level imports don't fail in unit tests
|
||||
for mod_name in [
|
||||
"bcrypt",
|
||||
"telegram",
|
||||
"telegram.ext",
|
||||
"litellm",
|
||||
]:
|
||||
if mod_name not in sys.modules:
|
||||
stub = MagicMock()
|
||||
sys.modules[mod_name] = stub
|
||||
# telegram.ext sub-attributes needed at import time
|
||||
if mod_name == "telegram.ext":
|
||||
stub.Application = MagicMock()
|
||||
stub.CallbackQueryHandler = MagicMock()
|
||||
stub.CommandHandler = MagicMock()
|
||||
stub.ContextTypes = MagicMock()
|
||||
stub.MessageHandler = MagicMock()
|
||||
stub.filters = MagicMock()
|
||||
|
||||
spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
|
||||
assert spec and spec.loader
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod) # type: ignore[union-attr]
|
||||
return mod
|
||||
|
||||
|
||||
_mod = _load_plugin_module()
|
||||
_RateLimiter = _mod._RateLimiter
|
||||
_load_history = _mod._load_history
|
||||
_save_history = _mod._save_history
|
||||
TelegramBotPlugin = _mod.TelegramBotPlugin
|
||||
|
||||
|
||||
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRateLimiter:
|
||||
def test_allows_up_to_limit(self):
|
||||
rl = _RateLimiter(per_hour=5)
|
||||
for _ in range(5):
|
||||
assert rl.allow(user_id=1)
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
rl = _RateLimiter(per_hour=3)
|
||||
for _ in range(3):
|
||||
rl.allow(user_id=1)
|
||||
assert not rl.allow(user_id=1)
|
||||
|
||||
def test_independent_per_user(self):
|
||||
rl = _RateLimiter(per_hour=2)
|
||||
rl.allow(1)
|
||||
rl.allow(1)
|
||||
assert not rl.allow(1)
|
||||
assert rl.allow(2)
|
||||
|
||||
def test_old_entries_expire(self):
|
||||
rl = _RateLimiter(per_hour=1)
|
||||
# Manually populate bucket with an old timestamp
|
||||
from collections import deque
|
||||
rl._buckets[99] = deque([time.monotonic() - 3601])
|
||||
assert rl.allow(99) # old entry removed, new one fits
|
||||
|
||||
def test_multiple_users_independent(self):
|
||||
rl = _RateLimiter(per_hour=1)
|
||||
rl.allow(10)
|
||||
assert not rl.allow(10)
|
||||
assert rl.allow(20)
|
||||
assert rl.allow(30)
|
||||
|
||||
|
||||
# ── History DB ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestHistoryDB:
|
||||
def test_empty_history(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
assert _load_history(42) == []
|
||||
|
||||
def test_save_and_load(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
msgs = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
_save_history(42, msgs)
|
||||
assert _load_history(42) == msgs
|
||||
|
||||
def test_save_trims_to_max(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
|
||||
msgs = [{"role": "user", "content": str(i)} for i in range(10)]
|
||||
_save_history(1, msgs)
|
||||
loaded = _load_history(1)
|
||||
assert len(loaded) == 3
|
||||
# Should keep the most recent messages
|
||||
assert loaded[-1]["content"] == "9"
|
||||
|
||||
def test_overwrite(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
_save_history(1, [{"role": "user", "content": "first"}])
|
||||
_save_history(1, [{"role": "user", "content": "second"}])
|
||||
assert _load_history(1)[0]["content"] == "second"
|
||||
|
||||
def test_separate_chats(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
_save_history(1, [{"role": "user", "content": "chat1"}])
|
||||
_save_history(2, [{"role": "user", "content": "chat2"}])
|
||||
assert _load_history(1)[0]["content"] == "chat1"
|
||||
assert _load_history(2)[0]["content"] == "chat2"
|
||||
|
||||
|
||||
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
|
||||
|
||||
class TestPluginLifecycle:
|
||||
def test_on_load_stores_vault_reader(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
reader = MagicMock(return_value=None)
|
||||
plugin.on_load(reader)
|
||||
assert plugin._vault_reader is reader
|
||||
|
||||
def test_daemon_tasks_returns_coroutine(self):
|
||||
import inspect
|
||||
plugin = TelegramBotPlugin()
|
||||
plugin.on_load(MagicMock(return_value=None))
|
||||
tasks = plugin.daemon_tasks()
|
||||
assert len(tasks) == 1
|
||||
assert inspect.iscoroutine(tasks[0])
|
||||
tasks[0].close() # prevent "coroutine never awaited" warning
|
||||
|
||||
def test_get_plugin_factory(self):
|
||||
plugin = _mod.get_plugin()
|
||||
assert isinstance(plugin, TelegramBotPlugin)
|
||||
assert plugin.name == "telegram_bot"
|
||||
|
||||
def test_config_fields(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
fields = plugin.config_fields()
|
||||
assert any(f.key == "rate_limit" for f in fields)
|
||||
|
||||
def test_setup_mismatched_passphrase(self, capsys):
|
||||
"""setup() exits cleanly when passphrases don't match."""
|
||||
plugin = TelegramBotPlugin()
|
||||
console = MagicMock()
|
||||
vault_writer = MagicMock()
|
||||
|
||||
answers = iter(["fake-token", "user123", "pass1", "pass2"])
|
||||
with patch("questionary.password", side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(answers))), \
|
||||
patch("questionary.text", return_value=MagicMock(ask=lambda: next(answers))):
|
||||
plugin.setup(console, vault_writer)
|
||||
|
||||
vault_writer.assert_not_called()
|
||||
console.print.assert_called_with("[red]Passphrases do not match.[/red]")
|
||||
|
||||
|
||||
# ── Auth session state ────────────────────────────────────────────────────────
|
||||
|
||||
class TestAuthState:
|
||||
def test_session_defaults(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
session = plugin._sessions.setdefault(
|
||||
99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
|
||||
)
|
||||
assert not session["authenticated"]
|
||||
assert not session["awaiting_passphrase"]
|
||||
assert session["attempts"] == 0
|
||||
|
||||
def test_session_per_chat(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
|
||||
plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
|
||||
assert plugin._sessions[1]["authenticated"]
|
||||
assert not plugin._sessions[2]["authenticated"]
|
||||
assert plugin._sessions[2]["attempts"] == 1
|
||||
|
||||
def test_session_auth_flag(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
|
||||
# Simulate successful auth
|
||||
plugin._sessions[5]["authenticated"] = True
|
||||
plugin._sessions[5]["awaiting_passphrase"] = False
|
||||
assert plugin._sessions[5]["authenticated"]
|
||||
|
||||
|
||||
# ── Tool argument parsing ─────────────────────────────────────────────────────
|
||||
|
||||
class TestToolArgParsing:
|
||||
def test_valid_json_string(self):
|
||||
args_raw = '{"query": "hello"}'
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
assert args == {"query": "hello"}
|
||||
|
||||
def test_dict_passthrough(self):
|
||||
args_raw = {"query": "hello"}
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
assert args == {"query": "hello"}
|
||||
|
||||
def test_invalid_json_caught(self):
|
||||
args_raw = "not json"
|
||||
try:
|
||||
json.loads(args_raw)
|
||||
assert False, "Should have raised"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
def test_result_truncated_to_4000(self):
|
||||
long_result = "x" * 5000
|
||||
assert len(long_result[:4000]) == 4000
|
||||
Reference in New Issue
Block a user