f59aa1a758
Step 1 — create bot via @BotFather (instructions + press-any-key pause); Step 2 — find Telegram user ID via @userinfobot (instructions + pause); Step 3 — set session passphrase with security explanation; Step 4 — save all three vault keys and print ✓ confirmations; Step 5 — configuration-complete marker. Adds setup cancellation on an empty token. Updated tests cover the happy path, passphrase mismatch, and cancellation; press_any_key_to_continue calls are properly patched. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
282 lines
11 KiB
Python
282 lines
11 KiB
Python
"""Unit tests for the telegram_bot bundled plugin.

Tests cover the pure-logic helpers (rate limiter, SQLite chat history, auth
session state, tool-argument parsing) and the plugin lifecycle, including the
interactive setup wizard with all questionary prompts patched. Handler
integration against live Telegram is not tested.
"""
|
|
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import json
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
# ── Import helpers straight from the bundled source ──────────────────────────
|
|
|
|
# Location of the plugin under test: climb three directory levels from this
# test file to the repository root, then descend into the bundled package.
_PLUGIN_PATH = Path(__file__).parent.parent.parent.joinpath(
    "src", "pyra", "bundled_plugins", "telegram_bot", "plugin.py"
)
|
|
|
|
|
|
def _load_plugin_module(
    *,
    stub_modules: tuple[str, ...] = ("bcrypt", "telegram", "telegram.ext", "litellm"),
):
    """Load plugin.py without requiring its optional dependencies to be installed.

    Each name in *stub_modules* whose top-level package cannot be found on this
    interpreter is replaced in ``sys.modules`` by a ``MagicMock``, so that
    plugin.py's top-level imports succeed. Packages that *are* installed are
    left untouched. The real implementation is then used, and test modules
    collected later in the same session are not handed a stub in place of an
    installed package.

    Args:
        stub_modules: Dotted module names that may be stubbed. A submodule
            (e.g. ``"telegram.ext"``) is stubbed only when its top-level
            package is missing.

    Returns:
        The executed plugin module, loaded under the name
        ``pyra_plugin_telegram_bot`` (it is not registered in ``sys.modules``).
    """
    # Decide which root packages are missing *before* inserting any stub: once
    # a MagicMock sits in sys.modules, importlib.util.find_spec() can no longer
    # answer the question for that package.
    roots = {name.partition(".")[0] for name in stub_modules}
    missing_roots = {
        root
        for root in roots
        if root not in sys.modules and importlib.util.find_spec(root) is None
    }

    for mod_name in stub_modules:
        if mod_name in sys.modules or mod_name.partition(".")[0] not in missing_roots:
            continue
        stub = MagicMock()
        sys.modules[mod_name] = stub
        # telegram.ext sub-attributes needed at import time
        if mod_name == "telegram.ext":
            stub.Application = MagicMock()
            stub.CallbackQueryHandler = MagicMock()
            stub.CommandHandler = MagicMock()
            stub.ContextTypes = MagicMock()
            stub.MessageHandler = MagicMock()
            stub.filters = MagicMock()

    spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
    assert spec and spec.loader
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # type: ignore[union-attr]
    return mod
|
|
|
|
|
|
# Load the plugin once at import time and expose the pieces under test.
_mod = _load_plugin_module()
_RateLimiter, _load_history, _save_history, TelegramBotPlugin = (
    _mod._RateLimiter,
    _mod._load_history,
    _mod._save_history,
    _mod.TelegramBotPlugin,
)
|
|
|
|
|
|
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
|
|
|
class TestRateLimiter:
    """Behaviour of the per-user hourly rate limiter."""

    def test_allows_up_to_limit(self):
        rl = _RateLimiter(per_hour=5)
        for _ in range(5):
            assert rl.allow(user_id=1)

    def test_blocks_over_limit(self):
        rl = _RateLimiter(per_hour=3)
        # Every call within the budget must actually succeed, not merely be made.
        for _ in range(3):
            assert rl.allow(user_id=1)
        assert not rl.allow(user_id=1)

    def test_independent_per_user(self):
        rl = _RateLimiter(per_hour=2)
        rl.allow(1)
        rl.allow(1)
        assert not rl.allow(1)
        assert rl.allow(2)

    def test_old_entries_expire(self):
        from collections import deque

        rl = _RateLimiter(per_hour=1)
        # Manually populate the bucket with a timestamp just over an hour old.
        rl._buckets[99] = deque([time.monotonic() - 3601])
        assert rl.allow(99)  # old entry removed, new one fits
        # The fresh call must itself have been recorded, exhausting the budget.
        assert not rl.allow(99)

    def test_multiple_users_independent(self):
        rl = _RateLimiter(per_hour=1)
        rl.allow(10)
        assert not rl.allow(10)
        assert rl.allow(20)
        assert rl.allow(30)
|
|
|
|
|
|
# ── History DB ────────────────────────────────────────────────────────────────
|
|
|
|
class TestHistoryDB:
    """Round-trip behaviour of the SQLite-backed per-chat history store."""

    @staticmethod
    def _use_tmp_db(tmp_path, monkeypatch):
        """Point the plugin's history DB at a fresh file under *tmp_path*."""
        monkeypatch.setattr(_mod, "_HISTORY_DB", tmp_path / "tg.db")

    def test_empty_history(self, tmp_path, monkeypatch):
        self._use_tmp_db(tmp_path, monkeypatch)
        assert _load_history(42) == []

    def test_save_and_load(self, tmp_path, monkeypatch):
        self._use_tmp_db(tmp_path, monkeypatch)
        msgs = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi there"},
        ]
        _save_history(42, msgs)
        assert _load_history(42) == msgs

    def test_save_trims_to_max(self, tmp_path, monkeypatch):
        self._use_tmp_db(tmp_path, monkeypatch)
        monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
        msgs = [{"role": "user", "content": str(i)} for i in range(10)]
        # Capture the expectation before saving in case the list is trimmed in place.
        expected = msgs[-3:]
        _save_history(1, msgs)
        # Exactly the three most recent messages survive, in their original order.
        assert _load_history(1) == expected

    def test_overwrite(self, tmp_path, monkeypatch):
        self._use_tmp_db(tmp_path, monkeypatch)
        _save_history(1, [{"role": "user", "content": "first"}])
        _save_history(1, [{"role": "user", "content": "second"}])
        # A second save replaces the stored history rather than appending to it.
        assert _load_history(1) == [{"role": "user", "content": "second"}]

    def test_separate_chats(self, tmp_path, monkeypatch):
        self._use_tmp_db(tmp_path, monkeypatch)
        _save_history(1, [{"role": "user", "content": "chat1"}])
        _save_history(2, [{"role": "user", "content": "chat2"}])
        assert _load_history(1) == [{"role": "user", "content": "chat1"}]
        assert _load_history(2) == [{"role": "user", "content": "chat2"}]
|
|
|
|
|
|
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
|
|
|
|
class TestPluginLifecycle:
    """Plugin construction, daemon-task wiring, and the interactive setup wizard."""

    def test_on_load_stores_vault_reader(self):
        plugin = TelegramBotPlugin()
        reader = MagicMock(return_value=None)
        plugin.on_load(reader)
        assert plugin._vault_reader is reader

    def test_daemon_tasks_returns_coroutine(self):
        import inspect

        plugin = TelegramBotPlugin()
        plugin.on_load(MagicMock(return_value=None))
        tasks = plugin.daemon_tasks()
        assert len(tasks) == 1
        assert inspect.iscoroutine(tasks[0])
        tasks[0].close()  # prevent "coroutine never awaited" warning

    def test_get_plugin_factory(self):
        plugin = _mod.get_plugin()
        assert isinstance(plugin, TelegramBotPlugin)
        assert plugin.name == "telegram_bot"

    def test_config_fields(self):
        plugin = TelegramBotPlugin()
        fields = plugin.config_fields()
        assert any(f.key == "rate_limit" for f in fields)

    def _patch_setup(self, token, allowed, pass1, pass2):
        """Return patchers for every questionary prompt used by setup().

        ``questionary.password`` answers, in order, *token*, *pass1* and *pass2*.
        ``questionary.text`` answers *allowed*. ``press_any_key_to_continue``
        returns immediately. The result is a 3-tuple of patchers to be entered
        together in one ``with`` statement.
        """
        pw_answers = iter([token, pass1, pass2])
        return (
            patch(
                "questionary.password",
                side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(pw_answers)),
            ),
            patch("questionary.text", return_value=MagicMock(ask=lambda: allowed)),
            patch(
                "questionary.press_any_key_to_continue",
                return_value=MagicMock(ask=lambda: None),
            ),
        )

    def test_setup_mismatched_passphrase(self):
        """setup() writes nothing to the vault when passphrases don't match."""
        plugin = TelegramBotPlugin()
        console = MagicMock()
        vault_writer = MagicMock()

        pw_patch, text_patch, pause_patch = self._patch_setup(
            "fake-token", "123456789", "pass1", "pass2"
        )
        with pw_patch, text_patch, pause_patch:
            plugin.setup(console, vault_writer)

        vault_writer.assert_not_called()
        console.print.assert_called_with(
            "[red]Passphrases do not match. Run setup again to retry.[/red]"
        )

    def test_setup_happy_path(self):
        """setup() writes all three vault keys when credentials are valid."""
        plugin = TelegramBotPlugin()
        console = MagicMock()
        vault_writer = MagicMock()

        pw_patch, text_patch, pause_patch = self._patch_setup(
            "real-token", "111222333", "s3cr3t", "s3cr3t"
        )
        with pw_patch, text_patch, pause_patch:
            plugin.setup(console, vault_writer)

        # Map each vault key to the value written for it (positional args: key, value).
        written = {c.args[0]: c.args[1] for c in vault_writer.call_args_list}
        assert written.get("plugin:telegram_bot:token") == "real-token"
        assert written.get("plugin:telegram_bot:allowed_users") == "111222333"
        assert "plugin:telegram_bot:passphrase_hash" in written
        # The vault must hold a hash, never the plaintext passphrase.
        assert written["plugin:telegram_bot:passphrase_hash"] != "s3cr3t"

    def test_setup_cancelled_on_empty_token(self):
        """setup() exits without writing if the token prompt is cancelled."""
        plugin = TelegramBotPlugin()
        console = MagicMock()
        vault_writer = MagicMock()

        # questionary.text is patched as well, so a regression in the
        # cancellation path ends in the assertion below rather than in a real
        # terminal prompt.
        with patch("questionary.password", return_value=MagicMock(ask=lambda: None)), \
                patch("questionary.text", return_value=MagicMock(ask=lambda: None)), \
                patch("questionary.press_any_key_to_continue",
                      return_value=MagicMock(ask=lambda: None)):
            plugin.setup(console, vault_writer)

        vault_writer.assert_not_called()
|
|
|
|
|
|
# ── Auth session state ────────────────────────────────────────────────────────
|
|
|
|
class TestAuthState:
    """Per-chat authentication session bookkeeping held on the plugin instance."""

    def test_session_defaults(self):
        plugin = TelegramBotPlugin()
        # A fresh plugin must not know this chat yet; otherwise setdefault()
        # below would hand back pre-existing state instead of these defaults.
        assert 99 not in plugin._sessions
        session = plugin._sessions.setdefault(
            99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
        )
        assert not session["authenticated"]
        assert not session["awaiting_passphrase"]
        assert session["attempts"] == 0

    def test_session_per_chat(self):
        plugin = TelegramBotPlugin()
        plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
        plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
        assert plugin._sessions[1]["authenticated"]
        assert not plugin._sessions[2]["authenticated"]
        assert plugin._sessions[2]["attempts"] == 1
        # Sessions belong to one plugin instance: a second instance must not
        # see another instance's logins (guards against a shared class-level dict).
        assert 1 not in TelegramBotPlugin()._sessions

    def test_session_auth_flag(self):
        plugin = TelegramBotPlugin()
        plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
        # Simulate successful auth
        plugin._sessions[5]["authenticated"] = True
        plugin._sessions[5]["awaiting_passphrase"] = False
        assert plugin._sessions[5]["authenticated"]
        assert not plugin._sessions[5]["awaiting_passphrase"]
|
|
|
|
|
|
# ── Tool argument parsing ─────────────────────────────────────────────────────
|
|
|
|
class TestToolArgParsing:
    """Tool-call argument handling.

    NOTE(review): these tests re-create the parsing expression inline instead
    of calling the plugin's own code path. They document the expected contract
    but will not catch a regression inside plugin.py itself.
    """

    @staticmethod
    def _parse_args(args_raw):
        """Decode *args_raw* as JSON when it is a str; pass anything else through."""
        return json.loads(args_raw) if isinstance(args_raw, str) else args_raw

    def test_valid_json_string(self):
        assert self._parse_args('{"query": "hello"}') == {"query": "hello"}

    def test_dict_passthrough(self):
        args_raw = {"query": "hello"}
        # Non-string arguments are returned as-is, not copied.
        assert self._parse_args(args_raw) is args_raw

    def test_invalid_json_caught(self):
        # try/except/else keeps the try body to the one call that may raise,
        # and does not rely on an `assert False` that `python -O` would strip.
        try:
            self._parse_args("not json")
        except json.JSONDecodeError:
            pass
        else:
            raise AssertionError("Should have raised")

    def test_result_truncated_to_4000(self):
        long_result = "x" * 5000
        assert len(long_result[:4000]) == 4000
        # Results already under the cap pass through unchanged.
        assert "short"[:4000] == "short"
|