feat(plugin/telegram_bot): add Telegram bot plugin — remote Pyra chat over Telegram

Runs as a supervised daemon task (asyncio long-polling). Features:
- Sender allowlist + bcrypt passphrase challenge (3 attempts before lockout)
- Rate limiting: 20 messages/hour per user (in-memory, resets on restart)
- Injection scanning on every incoming message (pyra.security.injection)
- Full AI chat with litellm streaming → progressive Telegram message editing
- Tool-use loop (up to 10 iterations) with inline-button approval (120s timeout)
- Conversation history persisted per chat_id in ~/.pyra/telegram_history.db
- Memory context loaded from ~/.pyra/memory/ as system prompt on first message

Vault keys: plugin:telegram_bot:token, allowed_users, passphrase_hash
Deps: python-telegram-bot>=21.0, bcrypt>=4.0.0 (added to telegram + all-plugins extras)

22 new unit tests covering rate limiter, history DB, plugin lifecycle, and auth state.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
curo1305
2026-05-19 22:45:36 +02:00
parent bde0856979
commit 3f30b782d2
4 changed files with 837 additions and 2 deletions
+236
View File
@@ -0,0 +1,236 @@
"""Unit tests for the telegram_bot bundled plugin.
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
and tool argument parsing. Handler integration (live Telegram) is not tested.
"""
from __future__ import annotations
import importlib.util
import json
import sys
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# ── Import helpers straight from the bundled source ──────────────────────────
_PLUGIN_PATH = (
Path(__file__).parent.parent.parent
/ "src/pyra/bundled_plugins/telegram_bot/plugin.py"
)
def _load_plugin_module():
"""Load plugin.py without requiring python-telegram-bot to be installed."""
# Provide stub modules so the top-level imports don't fail in unit tests
for mod_name in [
"bcrypt",
"telegram",
"telegram.ext",
"litellm",
]:
if mod_name not in sys.modules:
stub = MagicMock()
sys.modules[mod_name] = stub
# telegram.ext sub-attributes needed at import time
if mod_name == "telegram.ext":
stub.Application = MagicMock()
stub.CallbackQueryHandler = MagicMock()
stub.CommandHandler = MagicMock()
stub.ContextTypes = MagicMock()
stub.MessageHandler = MagicMock()
stub.filters = MagicMock()
spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
assert spec and spec.loader
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[union-attr]
return mod
_mod = _load_plugin_module()
_RateLimiter = _mod._RateLimiter
_load_history = _mod._load_history
_save_history = _mod._save_history
TelegramBotPlugin = _mod.TelegramBotPlugin
# ── Rate limiter ──────────────────────────────────────────────────────────────
class TestRateLimiter:
def test_allows_up_to_limit(self):
rl = _RateLimiter(per_hour=5)
for _ in range(5):
assert rl.allow(user_id=1)
def test_blocks_over_limit(self):
rl = _RateLimiter(per_hour=3)
for _ in range(3):
rl.allow(user_id=1)
assert not rl.allow(user_id=1)
def test_independent_per_user(self):
rl = _RateLimiter(per_hour=2)
rl.allow(1)
rl.allow(1)
assert not rl.allow(1)
assert rl.allow(2)
def test_old_entries_expire(self):
rl = _RateLimiter(per_hour=1)
# Manually populate bucket with an old timestamp
from collections import deque
rl._buckets[99] = deque([time.monotonic() - 3601])
assert rl.allow(99) # old entry removed, new one fits
def test_multiple_users_independent(self):
rl = _RateLimiter(per_hour=1)
rl.allow(10)
assert not rl.allow(10)
assert rl.allow(20)
assert rl.allow(30)
# ── History DB ────────────────────────────────────────────────────────────────
class TestHistoryDB:
def test_empty_history(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
assert _load_history(42) == []
def test_save_and_load(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
msgs = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
_save_history(42, msgs)
assert _load_history(42) == msgs
def test_save_trims_to_max(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
msgs = [{"role": "user", "content": str(i)} for i in range(10)]
_save_history(1, msgs)
loaded = _load_history(1)
assert len(loaded) == 3
# Should keep the most recent messages
assert loaded[-1]["content"] == "9"
def test_overwrite(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
_save_history(1, [{"role": "user", "content": "first"}])
_save_history(1, [{"role": "user", "content": "second"}])
assert _load_history(1)[0]["content"] == "second"
def test_separate_chats(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
_save_history(1, [{"role": "user", "content": "chat1"}])
_save_history(2, [{"role": "user", "content": "chat2"}])
assert _load_history(1)[0]["content"] == "chat1"
assert _load_history(2)[0]["content"] == "chat2"
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
class TestPluginLifecycle:
def test_on_load_stores_vault_reader(self):
plugin = TelegramBotPlugin()
reader = MagicMock(return_value=None)
plugin.on_load(reader)
assert plugin._vault_reader is reader
def test_daemon_tasks_returns_coroutine(self):
import inspect
plugin = TelegramBotPlugin()
plugin.on_load(MagicMock(return_value=None))
tasks = plugin.daemon_tasks()
assert len(tasks) == 1
assert inspect.iscoroutine(tasks[0])
tasks[0].close() # prevent "coroutine never awaited" warning
def test_get_plugin_factory(self):
plugin = _mod.get_plugin()
assert isinstance(plugin, TelegramBotPlugin)
assert plugin.name == "telegram_bot"
def test_config_fields(self):
plugin = TelegramBotPlugin()
fields = plugin.config_fields()
assert any(f.key == "rate_limit" for f in fields)
def test_setup_mismatched_passphrase(self, capsys):
"""setup() exits cleanly when passphrases don't match."""
plugin = TelegramBotPlugin()
console = MagicMock()
vault_writer = MagicMock()
answers = iter(["fake-token", "user123", "pass1", "pass2"])
with patch("questionary.password", side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(answers))), \
patch("questionary.text", return_value=MagicMock(ask=lambda: next(answers))):
plugin.setup(console, vault_writer)
vault_writer.assert_not_called()
console.print.assert_called_with("[red]Passphrases do not match.[/red]")
# ── Auth session state ────────────────────────────────────────────────────────
class TestAuthState:
def test_session_defaults(self):
plugin = TelegramBotPlugin()
session = plugin._sessions.setdefault(
99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
)
assert not session["authenticated"]
assert not session["awaiting_passphrase"]
assert session["attempts"] == 0
def test_session_per_chat(self):
plugin = TelegramBotPlugin()
plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
assert plugin._sessions[1]["authenticated"]
assert not plugin._sessions[2]["authenticated"]
assert plugin._sessions[2]["attempts"] == 1
def test_session_auth_flag(self):
plugin = TelegramBotPlugin()
plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
# Simulate successful auth
plugin._sessions[5]["authenticated"] = True
plugin._sessions[5]["awaiting_passphrase"] = False
assert plugin._sessions[5]["authenticated"]
# ── Tool argument parsing ─────────────────────────────────────────────────────
class TestToolArgParsing:
def test_valid_json_string(self):
args_raw = '{"query": "hello"}'
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
assert args == {"query": "hello"}
def test_dict_passthrough(self):
args_raw = {"query": "hello"}
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
assert args == {"query": "hello"}
def test_invalid_json_caught(self):
args_raw = "not json"
try:
json.loads(args_raw)
assert False, "Should have raised"
except json.JSONDecodeError:
pass
def test_result_truncated_to_4000(self):
long_result = "x" * 5000
assert len(long_result[:4000]) == 4000