"""Unit tests for the telegram_bot bundled plugin. Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow, and tool argument parsing. Handler integration (live Telegram) is not tested. """ from __future__ import annotations import importlib.util import json import sys import time from pathlib import Path from unittest.mock import MagicMock, patch import pytest # ── Import helpers straight from the bundled source ────────────────────────── _PLUGIN_PATH = ( Path(__file__).parent.parent.parent / "src/pyra/bundled_plugins/telegram_bot/plugin.py" ) def _load_plugin_module(): """Load plugin.py without requiring python-telegram-bot to be installed.""" # Provide stub modules so the top-level imports don't fail in unit tests for mod_name in [ "bcrypt", "telegram", "telegram.ext", "litellm", ]: if mod_name not in sys.modules: stub = MagicMock() sys.modules[mod_name] = stub # telegram.ext sub-attributes needed at import time if mod_name == "telegram.ext": stub.Application = MagicMock() stub.CallbackQueryHandler = MagicMock() stub.CommandHandler = MagicMock() stub.ContextTypes = MagicMock() stub.MessageHandler = MagicMock() stub.filters = MagicMock() spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH) assert spec and spec.loader mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) # type: ignore[union-attr] return mod _mod = _load_plugin_module() _RateLimiter = _mod._RateLimiter _load_history = _mod._load_history _save_history = _mod._save_history TelegramBotPlugin = _mod.TelegramBotPlugin # ── Rate limiter ────────────────────────────────────────────────────────────── class TestRateLimiter: def test_allows_up_to_limit(self): rl = _RateLimiter(per_hour=5) for _ in range(5): assert rl.allow(user_id=1) def test_blocks_over_limit(self): rl = _RateLimiter(per_hour=3) for _ in range(3): rl.allow(user_id=1) assert not rl.allow(user_id=1) def test_independent_per_user(self): rl = _RateLimiter(per_hour=2) rl.allow(1) rl.allow(1) assert not rl.allow(1) assert rl.allow(2) def test_old_entries_expire(self): rl = _RateLimiter(per_hour=1) # Manually populate bucket with an old timestamp from collections import deque rl._buckets[99] = deque([time.monotonic() - 3601]) assert rl.allow(99) # old entry removed, new one fits def test_multiple_users_independent(self): rl = _RateLimiter(per_hour=1) rl.allow(10) assert not rl.allow(10) assert rl.allow(20) assert rl.allow(30) # ── History DB ──────────────────────────────────────────────────────────────── class TestHistoryDB: def test_empty_history(self, tmp_path, monkeypatch): db_path = tmp_path / "tg.db" monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) assert _load_history(42) == [] def test_save_and_load(self, tmp_path, monkeypatch): db_path = tmp_path / "tg.db" monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) msgs = [ {"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi there"}, ] _save_history(42, msgs) assert _load_history(42) == msgs def test_save_trims_to_max(self, tmp_path, monkeypatch): db_path = tmp_path / "tg.db" monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) monkeypatch.setattr(_mod, "_MAX_HISTORY", 3) msgs = [{"role": "user", "content": str(i)} for i in range(10)] _save_history(1, msgs) loaded = _load_history(1) assert len(loaded) == 3 # Should keep the most recent messages assert loaded[-1]["content"] == "9" def test_overwrite(self, tmp_path, monkeypatch): db_path = tmp_path / "tg.db" monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) _save_history(1, [{"role": "user", "content": "first"}]) _save_history(1, [{"role": "user", "content": "second"}]) assert _load_history(1)[0]["content"] == "second" def test_separate_chats(self, tmp_path, monkeypatch): db_path = tmp_path / "tg.db" monkeypatch.setattr(_mod, "_HISTORY_DB", db_path) _save_history(1, [{"role": "user", "content": "chat1"}]) _save_history(2, [{"role": "user", "content": "chat2"}]) assert _load_history(1)[0]["content"] == "chat1" assert _load_history(2)[0]["content"] == "chat2" # ── Plugin: on_load and setup ───────────────────────────────────────────────── class TestPluginLifecycle: def test_on_load_stores_vault_reader(self): plugin = TelegramBotPlugin() reader = MagicMock(return_value=None) plugin.on_load(reader) assert plugin._vault_reader is reader def test_daemon_tasks_returns_coroutine(self): import inspect plugin = TelegramBotPlugin() plugin.on_load(MagicMock(return_value=None)) tasks = plugin.daemon_tasks() assert len(tasks) == 1 assert inspect.iscoroutine(tasks[0]) tasks[0].close() # prevent "coroutine never awaited" warning def test_get_plugin_factory(self): plugin = _mod.get_plugin() assert isinstance(plugin, TelegramBotPlugin) assert plugin.name == "telegram_bot" def test_config_fields(self): plugin = TelegramBotPlugin() fields = plugin.config_fields() assert any(f.key == "rate_limit" for f in fields) def test_setup_mismatched_passphrase(self, capsys): """setup() exits cleanly when passphrases don't match.""" plugin = TelegramBotPlugin() console = MagicMock() vault_writer = MagicMock() answers = iter(["fake-token", "user123", "pass1", "pass2"]) with patch("questionary.password", side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(answers))), \ patch("questionary.text", return_value=MagicMock(ask=lambda: next(answers))): plugin.setup(console, vault_writer) vault_writer.assert_not_called() console.print.assert_called_with("[red]Passphrases do not match.[/red]") # ── Auth session state ──────────────────────────────────────────────────────── class TestAuthState: def test_session_defaults(self): plugin = TelegramBotPlugin() session = plugin._sessions.setdefault( 99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0} ) assert not session["authenticated"] assert not session["awaiting_passphrase"] assert session["attempts"] == 0 def test_session_per_chat(self): plugin = TelegramBotPlugin() plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0} plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1} assert plugin._sessions[1]["authenticated"] assert not plugin._sessions[2]["authenticated"] assert plugin._sessions[2]["attempts"] == 1 def test_session_auth_flag(self): plugin = TelegramBotPlugin() plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0} # Simulate successful auth plugin._sessions[5]["authenticated"] = True plugin._sessions[5]["awaiting_passphrase"] = False assert plugin._sessions[5]["authenticated"] # ── Tool argument parsing ───────────────────────────────────────────────────── class TestToolArgParsing: def test_valid_json_string(self): args_raw = '{"query": "hello"}' args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw assert args == {"query": "hello"} def test_dict_passthrough(self): args_raw = {"query": "hello"} args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw assert args == {"query": "hello"} def test_invalid_json_caught(self): args_raw = "not json" try: json.loads(args_raw) assert False, "Should have raised" except json.JSONDecodeError: pass def test_result_truncated_to_4000(self): long_result = "x" * 5000 assert len(long_result[:4000]) == 4000