Files
Pyra/tests/unit/test_telegram_bot.py
T
curo1305 f59aa1a758 feat(plugin/telegram_bot): replace bare prompts with 5-step guided setup wizard
Step 1 — create bot via @BotFather (instructions + press-any-key pause)
Step 2 — find Telegram user ID via @userinfobot (instructions + pause)
Step 3 — set session passphrase with security explanation
Step 4 — save all three vault keys, print ✓ confirmations
Step 5 — configuration complete marker

Adds setup cancellation on empty token, updated tests: happy path, mismatch,
and cancel all covered; press_any_key_to_continue calls properly patched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:15:22 +02:00

282 lines
11 KiB
Python

"""Unit tests for the telegram_bot bundled plugin.
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
and tool argument parsing. Handler integration (live Telegram) is not tested.
"""
from __future__ import annotations
import importlib.util
import json
import sys
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# ── Import helpers straight from the bundled source ──────────────────────────
# Resolved relative to this test file: three directories up, then into src/.
_PLUGIN_PATH = Path(__file__).parent.parent.parent.joinpath(
    "src/pyra/bundled_plugins/telegram_bot/plugin.py"
)
def _load_plugin_module():
    """Import the bundled plugin source file as a standalone module.

    Each of ``bcrypt``, ``telegram``, ``telegram.ext`` and ``litellm`` that is
    not already registered in ``sys.modules`` is replaced by a ``MagicMock``
    stand-in first, so the plugin's top-level imports succeed in unit tests
    even when those packages are not installed.

    Returns:
        The executed module object loaded from ``_PLUGIN_PATH``.
    """
    telegram_ext_names = (
        "Application",
        "CallbackQueryHandler",
        "CommandHandler",
        "ContextTypes",
        "MessageHandler",
        "filters",
    )
    for dependency in ("bcrypt", "telegram", "telegram.ext", "litellm"):
        if dependency in sys.modules:
            # Something is already registered under this name; leave it alone.
            continue
        placeholder = MagicMock()
        sys.modules[dependency] = placeholder
        if dependency != "telegram.ext":
            continue
        # Pin the telegram.ext sub-attributes needed at import time.
        for attr_name in telegram_ext_names:
            setattr(placeholder, attr_name, MagicMock())

    plugin_spec = importlib.util.spec_from_file_location(
        "pyra_plugin_telegram_bot", _PLUGIN_PATH
    )
    assert plugin_spec and plugin_spec.loader
    plugin_module = importlib.util.module_from_spec(plugin_spec)
    plugin_spec.loader.exec_module(plugin_module)  # type: ignore[union-attr]
    return plugin_module
# Load the plugin once at import time; every test class shares this module object.
_mod = _load_plugin_module()
# Short aliases for the plugin names exercised by the tests.
TelegramBotPlugin = _mod.TelegramBotPlugin
_RateLimiter = _mod._RateLimiter
_load_history, _save_history = _mod._load_history, _mod._save_history
# ── Rate limiter ──────────────────────────────────────────────────────────────
class TestRateLimiter:
    """Checks for ``_RateLimiter.allow`` with various ``per_hour`` quotas."""

    def test_allows_up_to_limit(self):
        """Every call up to the configured quota is accepted."""
        limiter = _RateLimiter(per_hour=5)
        verdicts = [limiter.allow(user_id=1) for _ in range(5)]
        assert all(verdicts)

    def test_blocks_over_limit(self):
        """The first call beyond the quota is rejected."""
        quota = 3
        limiter = _RateLimiter(per_hour=quota)
        for _ in range(quota):
            limiter.allow(user_id=1)
        assert not limiter.allow(user_id=1)

    def test_independent_per_user(self):
        """Exhausting one user's quota leaves another user unaffected."""
        limiter = _RateLimiter(per_hour=2)
        for _ in range(2):
            limiter.allow(1)
        assert not limiter.allow(1)
        assert limiter.allow(2)

    def test_old_entries_expire(self):
        """An entry stamped 3601 seconds ago does not block a new call."""
        from collections import deque

        limiter = _RateLimiter(per_hour=1)
        # Seed the bucket by hand with a timestamp 3601 seconds in the past.
        stale_stamp = time.monotonic() - 3601
        limiter._buckets[99] = deque([stale_stamp])
        assert limiter.allow(99)  # old entry removed, new one fits

    def test_multiple_users_independent(self):
        """With a quota of one, each distinct user still gets a first call."""
        limiter = _RateLimiter(per_hour=1)
        limiter.allow(10)
        assert not limiter.allow(10)
        for other_user in (20, 30):
            assert limiter.allow(other_user)
# ── History DB ────────────────────────────────────────────────────────────────
class TestHistoryDB:
    """Round-trip checks for ``_save_history`` / ``_load_history``."""

    @staticmethod
    def _redirect_db(tmp_path, monkeypatch):
        """Point the plugin's ``_HISTORY_DB`` at a throwaway file under *tmp_path*."""
        monkeypatch.setattr(_mod, "_HISTORY_DB", tmp_path / "tg.db")

    def test_empty_history(self, tmp_path, monkeypatch):
        """Loading a chat that was never saved yields an empty list."""
        self._redirect_db(tmp_path, monkeypatch)
        assert _load_history(42) == []

    def test_save_and_load(self, tmp_path, monkeypatch):
        """Messages come back exactly as they were saved."""
        self._redirect_db(tmp_path, monkeypatch)
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi there"},
        ]
        _save_history(42, conversation)
        assert _load_history(42) == conversation

    def test_save_trims_to_max(self, tmp_path, monkeypatch):
        """Only ``_MAX_HISTORY`` messages survive, ending with the newest."""
        self._redirect_db(tmp_path, monkeypatch)
        monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
        conversation = []
        for index in range(10):
            conversation.append({"role": "user", "content": str(index)})
        _save_history(1, conversation)
        restored = _load_history(1)
        assert len(restored) == 3
        # The tail of the conversation must be what is kept.
        newest = restored[-1]
        assert newest["content"] == "9"

    def test_overwrite(self, tmp_path, monkeypatch):
        """A second save for the same chat replaces the first."""
        self._redirect_db(tmp_path, monkeypatch)
        for text in ("first", "second"):
            _save_history(1, [{"role": "user", "content": text}])
        restored = _load_history(1)
        assert restored[0]["content"] == "second"

    def test_separate_chats(self, tmp_path, monkeypatch):
        """Histories saved under different chat ids do not mix."""
        self._redirect_db(tmp_path, monkeypatch)
        _save_history(1, [{"role": "user", "content": "chat1"}])
        _save_history(2, [{"role": "user", "content": "chat2"}])
        first_chat = _load_history(1)
        second_chat = _load_history(2)
        assert first_chat[0]["content"] == "chat1"
        assert second_chat[0]["content"] == "chat2"
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
class TestPluginLifecycle:
    """Checks for plugin construction, ``on_load`` and the ``setup()`` wizard."""

    def test_on_load_stores_vault_reader(self):
        """``on_load`` keeps the reader it is given on ``_vault_reader``."""
        vault_reader = MagicMock(return_value=None)
        plugin = TelegramBotPlugin()
        plugin.on_load(vault_reader)
        assert plugin._vault_reader is vault_reader

    def test_daemon_tasks_returns_coroutine(self):
        """``daemon_tasks`` returns exactly one coroutine object."""
        import inspect

        plugin = TelegramBotPlugin()
        plugin.on_load(MagicMock(return_value=None))
        pending = plugin.daemon_tasks()
        assert len(pending) == 1
        only_task = pending[0]
        assert inspect.iscoroutine(only_task)
        only_task.close()  # prevent "coroutine never awaited" warning

    def test_get_plugin_factory(self):
        """The module-level ``get_plugin`` factory builds a ``telegram_bot`` plugin."""
        built = _mod.get_plugin()
        assert isinstance(built, TelegramBotPlugin)
        assert built.name == "telegram_bot"

    def test_config_fields(self):
        """``config_fields`` exposes a field keyed ``rate_limit``."""
        field_keys = [field.key for field in TelegramBotPlugin().config_fields()]
        assert "rate_limit" in field_keys

    def _patch_setup(self, token, allowed, pass1, pass2):
        """Build the three ``questionary`` patchers used to drive ``setup()``.

        Returns a ``(password, text, press_any_key_to_continue)`` tuple of
        not-yet-entered ``patch`` objects; the caller enters them in a
        ``with`` statement. Successive ``questionary.password(...).ask()``
        calls yield *token*, *pass1*, then *pass2*; ``questionary.text``
        always answers *allowed*.
        """
        secret_replies = iter((token, pass1, pass2))

        def fake_password(*_args, **_kwargs):
            return MagicMock(ask=lambda: next(secret_replies))

        password_patch = patch("questionary.password", side_effect=fake_password)
        text_patch = patch(
            "questionary.text",
            return_value=MagicMock(ask=lambda: allowed),
        )
        pause_patch = patch(
            "questionary.press_any_key_to_continue",
            return_value=MagicMock(ask=lambda: None),
        )
        return password_patch, text_patch, pause_patch

    def test_setup_mismatched_passphrase(self):
        """setup() writes nothing to the vault when passphrases don't match."""
        console, vault_writer = MagicMock(), MagicMock()
        plugin = TelegramBotPlugin()
        patchers = self._patch_setup("fake-token", "123456789", "pass1", "pass2")
        password_patch, text_patch, pause_patch = patchers
        with password_patch, text_patch, pause_patch:
            plugin.setup(console, vault_writer)
        vault_writer.assert_not_called()
        expected_message = "[red]Passphrases do not match. Run setup again to retry.[/red]"
        console.print.assert_called_with(expected_message)

    def test_setup_happy_path(self):
        """setup() writes all three vault keys when credentials are valid."""
        console, vault_writer = MagicMock(), MagicMock()
        plugin = TelegramBotPlugin()
        patchers = self._patch_setup("real-token", "111222333", "s3cr3t", "s3cr3t")
        password_patch, text_patch, pause_patch = patchers
        with password_patch, text_patch, pause_patch:
            plugin.setup(console, vault_writer)
        # Map each written vault key to the value passed alongside it.
        written = {}
        for recorded in vault_writer.call_args_list:
            positional = recorded[0]
            written[positional[0]] = positional[1]
        assert written.get("plugin:telegram_bot:token") == "real-token"
        assert written.get("plugin:telegram_bot:allowed_users") == "111222333"
        assert "plugin:telegram_bot:passphrase_hash" in written

    def test_setup_cancelled_on_empty_token(self):
        """setup() exits without writing if the token prompt is cancelled."""
        console, vault_writer = MagicMock(), MagicMock()
        plugin = TelegramBotPlugin()
        cancelled_prompt = patch(
            "questionary.password",
            return_value=MagicMock(ask=lambda: None),
        )
        pause_prompt = patch(
            "questionary.press_any_key_to_continue",
            return_value=MagicMock(ask=lambda: None),
        )
        with cancelled_prompt, pause_prompt:
            plugin.setup(console, vault_writer)
        vault_writer.assert_not_called()
# ── Auth session state ────────────────────────────────────────────────────────
class TestAuthState:
    """Checks on the per-chat entries kept in the plugin's ``_sessions`` mapping.

    These tests read and write ``_sessions`` directly; no handler code runs.
    """

    def test_session_defaults(self):
        """A session created via ``setdefault`` starts unauthenticated with zero attempts."""
        fresh_state = {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
        plugin = TelegramBotPlugin()
        session = plugin._sessions.setdefault(99, fresh_state)
        assert not session["authenticated"]
        assert not session["awaiting_passphrase"]
        assert session["attempts"] == 0

    def test_session_per_chat(self):
        """Entries stored under different chat ids are kept separately."""
        plugin = TelegramBotPlugin()
        sessions = plugin._sessions
        sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
        sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
        assert plugin._sessions[1]["authenticated"]
        assert not plugin._sessions[2]["authenticated"]
        assert plugin._sessions[2]["attempts"] == 1

    def test_session_auth_flag(self):
        """Setting the ``authenticated`` flag on a stored session is visible on re-read."""
        plugin = TelegramBotPlugin()
        plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
        # Simulate successful auth
        plugin._sessions[5].update(authenticated=True, awaiting_passphrase=False)
        assert plugin._sessions[5]["authenticated"]
# ── Tool argument parsing ─────────────────────────────────────────────────────
class TestToolArgParsing:
    """Checks for the JSON handling applied to tool-call arguments.

    NOTE(review): these tests evaluate ``json``/slicing expressions written
    inline here rather than calling into the plugin — presumably they mirror
    the plugin's own handling; confirm against plugin.py.
    """

    @staticmethod
    def _parse(args_raw):
        """Decode *args_raw* from JSON when it is a string; return it unchanged otherwise."""
        return json.loads(args_raw) if isinstance(args_raw, str) else args_raw

    def test_valid_json_string(self):
        """A JSON object string is decoded into a dict."""
        assert self._parse('{"query": "hello"}') == {"query": "hello"}

    def test_dict_passthrough(self):
        """A dict argument is returned as-is, without decoding."""
        assert self._parse({"query": "hello"}) == {"query": "hello"}

    def test_invalid_json_caught(self):
        """Malformed JSON raises ``json.JSONDecodeError``."""
        args_raw = "not json"
        try:
            json.loads(args_raw)
        except json.JSONDecodeError:
            pass
        else:
            # Raise explicitly instead of ``assert False``: assert statements
            # are removed under ``python -O``, which would let a missing
            # exception go unnoticed.
            raise AssertionError("Should have raised")

    def test_result_truncated_to_4000(self):
        """Slicing a 5000-character result to ``[:4000]`` leaves 4000 characters."""
        long_result = "x" * 5000
        assert len(long_result[:4000]) == 4000