Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aba28293b7 | |||
| f59aa1a758 | |||
| 3f30b782d2 |
+2
-2
@@ -28,7 +28,7 @@ dev = [
|
||||
]
|
||||
nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"]
|
||||
matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"]
|
||||
telegram = ["python-telegram-bot>=21.0"]
|
||||
telegram = ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
|
||||
ssh = ["paramiko>=3.4.0"]
|
||||
docker = ["docker>=7.0.0"]
|
||||
gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
|
||||
@@ -38,7 +38,7 @@ daemon = ["aiofiles>=23.0.0"]
|
||||
all-plugins = [
|
||||
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
|
||||
"matrix-nio>=0.24.0", "aiofiles>=23.0.0",
|
||||
"python-telegram-bot>=21.0",
|
||||
"python-telegram-bot>=21.0", "bcrypt>=4.0.0",
|
||||
"paramiko>=3.4.0",
|
||||
"docker>=7.0.0",
|
||||
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"name": "telegram_bot",
|
||||
"version": "1.0.0",
|
||||
"description": "Remote Pyra chat over Telegram — full AI chat with tool approval via inline buttons",
|
||||
"author": "pyra",
|
||||
"requires": ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
|
||||
}
|
||||
@@ -0,0 +1,665 @@
|
||||
"""Telegram bot plugin — remote Pyra chat over Telegram.
|
||||
|
||||
Runs as a daemon task (long-polling). Each chat session requires passphrase
|
||||
authentication and is rate-limited to 20 messages/hour. Incoming messages are
|
||||
injection-scanned before reaching the AI. Tool calls are approved via Telegram
|
||||
inline keyboard buttons (2-minute timeout). AI responses are streamed
|
||||
progressively by editing a placeholder message.
|
||||
|
||||
Vault keys used:
|
||||
plugin:telegram_bot:token — bot token from @BotFather
|
||||
plugin:telegram_bot:allowed_users — comma-separated Telegram user IDs
|
||||
plugin:telegram_bot:passphrase_hash — bcrypt hash of the session passphrase
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
import bcrypt
|
||||
import litellm
|
||||
from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup, Update
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CallbackQueryHandler,
|
||||
CommandHandler,
|
||||
ContextTypes,
|
||||
MessageHandler,
|
||||
filters,
|
||||
)
|
||||
|
||||
from pyra.plugins.base import BasePlugin, ConfigField
|
||||
from pyra.utils.paths import pyra_home
|
||||
|
||||
_log = logging.getLogger("pyra.plugin.telegram_bot")
|
||||
|
||||
_HISTORY_DB = pyra_home() / "telegram_history.db"
|
||||
_MAX_HISTORY = 40 # messages kept per chat
|
||||
_RATE_LIMIT = 20 # messages per hour per user
|
||||
_APPROVAL_TIMEOUT = 120 # seconds to wait for inline button press
|
||||
_EDIT_INTERVAL = 1.5 # minimum seconds between progressive message edits
|
||||
_MAX_TOOL_ITER = 10
|
||||
_MAX_MSG_LEN = 4096 # Telegram hard limit
|
||||
|
||||
|
||||
# ── SQLite history ────────────────────────────────────────────────────────────
|
||||
|
||||
def _open_db() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(_HISTORY_DB)
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
chat_id INTEGER PRIMARY KEY,
|
||||
history TEXT NOT NULL DEFAULT '[]',
|
||||
updated REAL NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
try:
|
||||
import os
|
||||
os.chmod(_HISTORY_DB, 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
return conn
|
||||
|
||||
|
||||
def _load_history(chat_id: int) -> list[dict]:
|
||||
conn = _open_db()
|
||||
row = conn.execute(
|
||||
"SELECT history FROM sessions WHERE chat_id = ?", (chat_id,)
|
||||
).fetchone()
|
||||
conn.close()
|
||||
return json.loads(row[0]) if row else []
|
||||
|
||||
|
||||
def _save_history(chat_id: int, messages: list[dict]) -> None:
|
||||
trimmed = messages[-_MAX_HISTORY:]
|
||||
conn = _open_db()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO sessions (chat_id, history, updated) VALUES (?,?,?)",
|
||||
(chat_id, json.dumps(trimmed), time.time()),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||
|
||||
class _RateLimiter:
|
||||
def __init__(self, per_hour: int = _RATE_LIMIT) -> None:
|
||||
self._buckets: dict[int, deque] = {}
|
||||
self._limit = per_hour
|
||||
|
||||
def allow(self, user_id: int) -> bool:
|
||||
now = time.monotonic()
|
||||
bucket = self._buckets.setdefault(user_id, deque())
|
||||
cutoff = now - 3600
|
||||
while bucket and bucket[0] < cutoff:
|
||||
bucket.popleft()
|
||||
if len(bucket) >= self._limit:
|
||||
return False
|
||||
bucket.append(now)
|
||||
return True
|
||||
|
||||
|
||||
# ── Plugin ────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TelegramBotPlugin(BasePlugin):
|
||||
name = "telegram_bot"
|
||||
description = "Remote Pyra chat over Telegram (daemon task, long-polling)"
|
||||
version = "1.0.0"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._vault_reader: Callable[[str], str | None] | None = None
|
||||
self._rate_limiter = _RateLimiter()
|
||||
# chat_id -> {authenticated, awaiting_passphrase, attempts}
|
||||
self._sessions: dict[int, dict] = {}
|
||||
# short call_id -> asyncio.Future[bool]
|
||||
self._pending_approvals: dict[str, asyncio.Future] = {}
|
||||
|
||||
# ── Plugin lifecycle ──────────────────────────────────────────────────────
|
||||
|
||||
def on_load(self, vault_reader: Callable[[str], str | None]) -> None:
|
||||
self._vault_reader = vault_reader
|
||||
|
||||
def setup(self, console: Any, vault_writer: Callable[[str, str], None]) -> None:
|
||||
import questionary
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
"[bold]Telegram Bot Setup Wizard[/bold]\n\n"
|
||||
"This wizard connects Pyra to Telegram so you can chat with your\n"
|
||||
"assistant from anywhere. You will need Telegram open on your phone\n"
|
||||
"or desktop to complete the next steps.",
|
||||
border_style="cyan",
|
||||
))
|
||||
|
||||
# ── Step 1: Create bot ────────────────────────────────────────────────
|
||||
console.print()
|
||||
console.print(Rule("[bold cyan]Step 1 / 5[/bold cyan] Create your Telegram bot"))
|
||||
console.print()
|
||||
console.print(
|
||||
" 1. Open Telegram and search for [bold]@BotFather[/bold]\n"
|
||||
" 2. Send [bold]/newbot[/bold] and follow the prompts\n"
|
||||
" 3. Choose a display name (e.g. [dim]My Pyra Assistant[/dim])\n"
|
||||
" 4. Choose a username ending in [bold]bot[/bold] "
|
||||
"(e.g. [dim]my_pyra_bot[/dim])\n"
|
||||
" 5. BotFather replies with a token that looks like:\n"
|
||||
" [dim]123456789:AABBccDDeeFFggHHiiJJkkLL[/dim]"
|
||||
)
|
||||
console.print()
|
||||
questionary.press_any_key_to_continue(
|
||||
" Press any key when you have your token ready ..."
|
||||
).ask()
|
||||
console.print()
|
||||
|
||||
token = questionary.password(" Bot token:").ask()
|
||||
if not token or not token.strip():
|
||||
console.print("[dim]Setup cancelled.[/dim]")
|
||||
return
|
||||
token = token.strip()
|
||||
|
||||
# ── Step 2: Find user ID ──────────────────────────────────────────────
|
||||
console.print()
|
||||
console.print(Rule("[bold cyan]Step 2 / 5[/bold cyan] Find your Telegram user ID"))
|
||||
console.print()
|
||||
console.print(
|
||||
" Your user ID is a permanent number that identifies your account.\n"
|
||||
" It never changes, even if you change your username.\n\n"
|
||||
" 1. Search for [bold]@userinfobot[/bold] in Telegram\n"
|
||||
" 2. Send any message (e.g. [dim]/start[/dim])\n"
|
||||
" 3. Copy the [bold]Id:[/bold] number from the reply "
|
||||
"(e.g. [dim]123456789[/dim])"
|
||||
)
|
||||
console.print()
|
||||
questionary.press_any_key_to_continue(
|
||||
" Press any key when you have your user ID ready ..."
|
||||
).ask()
|
||||
console.print()
|
||||
|
||||
allowed = questionary.text(
|
||||
" Allowed Telegram user IDs (comma-separated, leave blank to allow anyone):"
|
||||
).ask()
|
||||
if allowed is None:
|
||||
console.print("[dim]Setup cancelled.[/dim]")
|
||||
return
|
||||
|
||||
# ── Step 3: Session passphrase ────────────────────────────────────────
|
||||
console.print()
|
||||
console.print(Rule("[bold cyan]Step 3 / 5[/bold cyan] Set a session passphrase"))
|
||||
console.print()
|
||||
console.print(
|
||||
" The passphrase is an extra layer of security. Every new chat\n"
|
||||
" session must pass this challenge before Pyra responds — even\n"
|
||||
" if someone else gains access to your Telegram account."
|
||||
)
|
||||
console.print()
|
||||
|
||||
passphrase = questionary.password(" Session passphrase:").ask()
|
||||
if not passphrase:
|
||||
console.print("[dim]Setup cancelled.[/dim]")
|
||||
return
|
||||
|
||||
confirm = questionary.password(" Confirm passphrase:").ask()
|
||||
if passphrase != confirm:
|
||||
console.print("[red]Passphrases do not match. Run setup again to retry.[/red]")
|
||||
return
|
||||
|
||||
# ── Step 4: Save to vault ─────────────────────────────────────────────
|
||||
console.print()
|
||||
console.print(Rule("[bold cyan]Step 4 / 5[/bold cyan] Saving configuration"))
|
||||
console.print()
|
||||
|
||||
pw_hash = bcrypt.hashpw(passphrase.encode(), bcrypt.gensalt()).decode()
|
||||
vault_writer("plugin:telegram_bot:token", token)
|
||||
vault_writer("plugin:telegram_bot:allowed_users", (allowed or "").strip())
|
||||
vault_writer("plugin:telegram_bot:passphrase_hash", pw_hash)
|
||||
|
||||
allowed_display = (allowed or "").strip() or "[dim](any user — consider restricting)[/dim]"
|
||||
console.print(f" [green]✓[/green] Bot token stored in vault")
|
||||
console.print(f" [green]✓[/green] Allowed users: {allowed_display}")
|
||||
console.print(f" [green]✓[/green] Passphrase stored as bcrypt hash")
|
||||
|
||||
# ── Step 5: Done ──────────────────────────────────────────────────────
|
||||
console.print()
|
||||
console.print(Rule("[bold cyan]Step 5 / 5[/bold cyan] Configuration complete"))
|
||||
console.print()
|
||||
|
||||
def config_fields(self) -> list[ConfigField]:
|
||||
return [
|
||||
ConfigField(
|
||||
"rate_limit",
|
||||
"Rate limit (messages/hour)",
|
||||
"text",
|
||||
str(_RATE_LIMIT),
|
||||
description="Maximum messages per hour per Telegram user",
|
||||
),
|
||||
]
|
||||
|
||||
def daemon_tasks(self) -> list:
|
||||
return [self._run_polling()]
|
||||
|
||||
# ── Daemon task ───────────────────────────────────────────────────────────
|
||||
|
||||
async def _run_polling(self) -> None:
|
||||
assert self._vault_reader is not None
|
||||
|
||||
token = self._vault_reader("plugin:telegram_bot:token")
|
||||
if not token:
|
||||
_log.error(
|
||||
"Telegram bot token not set. Run `pyra plugin setup telegram_bot`."
|
||||
)
|
||||
return
|
||||
|
||||
passphrase_hash = self._vault_reader("plugin:telegram_bot:passphrase_hash") or ""
|
||||
allowed_str = self._vault_reader("plugin:telegram_bot:allowed_users") or ""
|
||||
allowed_users: set[int] = {
|
||||
int(uid.strip())
|
||||
for uid in allowed_str.split(",")
|
||||
if uid.strip().isdigit()
|
||||
}
|
||||
|
||||
app = Application.builder().token(token).build()
|
||||
plugin = self # closure reference
|
||||
|
||||
async def _on_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if update.effective_user and update.effective_user.id in allowed_users:
|
||||
await update.message.reply_text(
|
||||
"Pyra is online. Send any message to authenticate."
|
||||
)
|
||||
|
||||
async def _on_message(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await plugin._handle_message(update, ctx, allowed_users, passphrase_hash)
|
||||
|
||||
async def _on_approval(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await plugin._handle_approval_callback(update)
|
||||
|
||||
app.add_handler(CommandHandler("start", _on_start))
|
||||
app.add_handler(CallbackQueryHandler(_on_approval))
|
||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, _on_message))
|
||||
|
||||
_log.info("Telegram bot starting (long-polling).")
|
||||
await app.initialize()
|
||||
try:
|
||||
await app.start()
|
||||
await app.updater.start_polling(drop_pending_updates=True)
|
||||
_log.info("Telegram bot is polling for updates.")
|
||||
await asyncio.Event().wait() # block until CancelledError
|
||||
except asyncio.CancelledError:
|
||||
_log.info("Telegram bot shutting down.")
|
||||
finally:
|
||||
try:
|
||||
await app.updater.stop()
|
||||
await app.stop()
|
||||
await app.shutdown()
|
||||
except Exception as exc:
|
||||
_log.warning("Error during Telegram bot shutdown: %s", exc)
|
||||
|
||||
# ── Message handler ───────────────────────────────────────────────────────
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
update: Update,
|
||||
ctx: ContextTypes.DEFAULT_TYPE,
|
||||
allowed_users: set[int],
|
||||
passphrase_hash: str,
|
||||
) -> None:
|
||||
if update.effective_user is None or update.message is None:
|
||||
return
|
||||
|
||||
user_id = update.effective_user.id
|
||||
chat_id = update.effective_chat.id if update.effective_chat else user_id
|
||||
|
||||
# Allowlist — silently ignore unknown senders
|
||||
if allowed_users and user_id not in allowed_users:
|
||||
return
|
||||
|
||||
text = (update.message.text or "").strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
session = self._sessions.setdefault(
|
||||
chat_id,
|
||||
{"authenticated": False, "awaiting_passphrase": False, "attempts": 0},
|
||||
)
|
||||
|
||||
# ── Passphrase authentication ─────────────────────────────────────────
|
||||
if not session["authenticated"]:
|
||||
if not session["awaiting_passphrase"]:
|
||||
session["awaiting_passphrase"] = True
|
||||
session["attempts"] = 0
|
||||
await update.message.reply_text("Enter your passphrase to continue:")
|
||||
return
|
||||
|
||||
if passphrase_hash and bcrypt.checkpw(text.encode(), passphrase_hash.encode()):
|
||||
session["authenticated"] = True
|
||||
session["awaiting_passphrase"] = False
|
||||
session["attempts"] = 0
|
||||
await update.message.reply_text(
|
||||
"Authenticated. How can I help you?\n"
|
||||
"Send /start at any time to check bot status."
|
||||
)
|
||||
else:
|
||||
session["attempts"] += 1
|
||||
remaining = 3 - session["attempts"]
|
||||
if remaining <= 0:
|
||||
session["awaiting_passphrase"] = False
|
||||
await update.message.reply_text(
|
||||
"Too many failed attempts. Send any message to try again."
|
||||
)
|
||||
else:
|
||||
await update.message.reply_text(
|
||||
f"Wrong passphrase. {remaining} attempt(s) left."
|
||||
)
|
||||
return
|
||||
|
||||
# ── Rate limit ────────────────────────────────────────────────────────
|
||||
if not self._rate_limiter.allow(user_id):
|
||||
await update.message.reply_text(
|
||||
"Rate limit reached (20 messages/hour). Try again later."
|
||||
)
|
||||
return
|
||||
|
||||
# ── Injection scan ────────────────────────────────────────────────────
|
||||
from pyra.security.injection import scan_response
|
||||
|
||||
warnings = scan_response(text)
|
||||
if warnings:
|
||||
labels = ", ".join(w.pattern_label for w in warnings)
|
||||
_log.warning("Injection in Telegram message (user %d): %s", user_id, labels)
|
||||
await update.message.reply_text(
|
||||
"Your message was blocked: injection pattern detected."
|
||||
)
|
||||
return
|
||||
|
||||
# ── Load history + system context ─────────────────────────────────────
|
||||
history = _load_history(chat_id)
|
||||
if not history:
|
||||
try:
|
||||
from pyra.memory.reader import load_context_for_session
|
||||
|
||||
ctx_text = load_context_for_session()
|
||||
if ctx_text:
|
||||
history = [{"role": "system", "content": ctx_text}]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
history.append({"role": "user", "content": text})
|
||||
|
||||
try:
|
||||
await ctx.bot.send_chat_action(chat_id=chat_id, action="typing")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── AI response ───────────────────────────────────────────────────────
|
||||
try:
|
||||
reply = await self._ai_chat(chat_id, history, ctx.bot)
|
||||
except Exception as exc:
|
||||
_log.error("AI error (chat %d): %s", chat_id, exc, exc_info=True)
|
||||
await ctx.bot.send_message(chat_id=chat_id, text=f"AI error: {exc}")
|
||||
return
|
||||
|
||||
history.append({"role": "assistant", "content": reply})
|
||||
_save_history(chat_id, history)
|
||||
|
||||
# ── AI streaming + tool-use loop ──────────────────────────────────────────
|
||||
|
||||
async def _ai_chat(self, chat_id: int, messages: list[dict], bot: Bot) -> str:
|
||||
from pyra.config.manager import load_config
|
||||
from pyra.plugins.registry import PluginRegistry
|
||||
from pyra.setup.providers import get_provider
|
||||
from pyra.vault.reader import get_key
|
||||
|
||||
cfg = load_config()
|
||||
provider = get_provider(cfg.ai.provider_id)
|
||||
api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local"
|
||||
|
||||
call_kwargs: dict[str, Any] = {
|
||||
"model": f"{provider.litellm_prefix}{cfg.ai.model}",
|
||||
"api_key": api_key,
|
||||
}
|
||||
base_url = cfg.ai.base_url or provider.base_url
|
||||
if base_url:
|
||||
call_kwargs["api_base"] = base_url
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
registry = PluginRegistry.instance()
|
||||
tools_spec = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"parameters": t.parameters,
|
||||
},
|
||||
}
|
||||
for t in registry.get_all_tools()
|
||||
]
|
||||
|
||||
# Mutable state shared with helpers below
|
||||
state: dict[str, Any] = {"msg_id": None, "last_edit": 0.0}
|
||||
|
||||
placeholder = await bot.send_message(chat_id=chat_id, text="…")
|
||||
state["msg_id"] = placeholder.message_id
|
||||
|
||||
async def _update(text: str) -> None:
|
||||
if not state["msg_id"]:
|
||||
return
|
||||
now = time.monotonic()
|
||||
if now - state["last_edit"] < _EDIT_INTERVAL:
|
||||
return
|
||||
try:
|
||||
await bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=state["msg_id"],
|
||||
text=text[:_MAX_MSG_LEN],
|
||||
)
|
||||
state["last_edit"] = now
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _finalize(text: str) -> None:
|
||||
if not state["msg_id"]:
|
||||
return
|
||||
if text:
|
||||
try:
|
||||
await bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=state["msg_id"],
|
||||
text=text[:_MAX_MSG_LEN],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
|
||||
except Exception:
|
||||
pass
|
||||
state["msg_id"] = None
|
||||
|
||||
accumulated = ""
|
||||
|
||||
try:
|
||||
for _iter in range(_MAX_TOOL_ITER):
|
||||
tool_chunks: dict[int, dict] = {}
|
||||
accumulated = ""
|
||||
|
||||
stream = await litellm.acompletion(
|
||||
**call_kwargs,
|
||||
messages=messages,
|
||||
tools=tools_spec if tools_spec else None,
|
||||
tool_choice="auto" if tools_spec else None,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
accumulated += delta.content
|
||||
await _update(accumulated)
|
||||
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_chunks:
|
||||
tool_chunks[idx] = {"id": tc.id or "", "name": "", "args": ""}
|
||||
if tc.function:
|
||||
if tc.function.name:
|
||||
tool_chunks[idx]["name"] += tc.function.name
|
||||
if tc.function.arguments:
|
||||
tool_chunks[idx]["args"] += tc.function.arguments
|
||||
|
||||
if not tool_chunks:
|
||||
await _finalize(accumulated)
|
||||
return accumulated
|
||||
|
||||
# Show any intermediate prose before tool calls
|
||||
if accumulated:
|
||||
await _finalize(accumulated)
|
||||
else:
|
||||
try:
|
||||
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
|
||||
except Exception:
|
||||
pass
|
||||
state["msg_id"] = None
|
||||
|
||||
tool_calls_list = [
|
||||
{
|
||||
"id": data["id"],
|
||||
"type": "function",
|
||||
"function": {"name": data["name"], "arguments": data["args"]},
|
||||
}
|
||||
for _, data in sorted(tool_chunks.items())
|
||||
]
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": accumulated or None,
|
||||
"tool_calls": tool_calls_list,
|
||||
})
|
||||
|
||||
for tc in tool_calls_list:
|
||||
result = await self._execute_tool_with_approval(tc, chat_id, bot)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tc["id"],
|
||||
"content": result,
|
||||
})
|
||||
|
||||
# New placeholder for the next AI response
|
||||
new_ph = await bot.send_message(chat_id=chat_id, text="…")
|
||||
state["msg_id"] = new_ph.message_id
|
||||
state["last_edit"] = 0.0
|
||||
|
||||
except litellm.BadRequestError:
|
||||
# Provider doesn't support tool calls — retry without tools
|
||||
accumulated = ""
|
||||
state["last_edit"] = 0.0
|
||||
stream = await litellm.acompletion(
|
||||
**call_kwargs, messages=messages, stream=True
|
||||
)
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
accumulated += chunk.choices[0].delta.content
|
||||
await _update(accumulated)
|
||||
await _finalize(accumulated)
|
||||
return accumulated
|
||||
|
||||
return accumulated or "Error: tool-use loop exceeded maximum iterations."
|
||||
|
||||
# ── Tool approval via inline buttons ──────────────────────────────────────
|
||||
|
||||
async def _execute_tool_with_approval(
|
||||
self, tool_call: dict, chat_id: int, bot: Bot
|
||||
) -> str:
|
||||
from pyra.plugins.registry import PluginRegistry
|
||||
from pyra.security.injection import scan_response
|
||||
|
||||
tool_name = tool_call["function"]["name"]
|
||||
args_raw = tool_call["function"]["arguments"]
|
||||
try:
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
except json.JSONDecodeError:
|
||||
return f"Error: invalid tool arguments for {tool_name}"
|
||||
|
||||
args_preview = json.dumps(args, indent=2)[:500]
|
||||
call_id = uuid.uuid4().hex[:8]
|
||||
|
||||
keyboard = InlineKeyboardMarkup([[
|
||||
InlineKeyboardButton("✅ Approve", callback_data=f"approve:{call_id}"),
|
||||
InlineKeyboardButton("❌ Deny", callback_data=f"deny:{call_id}"),
|
||||
]])
|
||||
|
||||
await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"Tool request: {tool_name}\n\n{args_preview}",
|
||||
reply_markup=keyboard,
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[bool] = loop.create_future()
|
||||
self._pending_approvals[call_id] = future
|
||||
|
||||
try:
|
||||
approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_approvals.pop(call_id, None)
|
||||
await bot.send_message(
|
||||
chat_id=chat_id, text=f"Tool {tool_name}: approval timed out — denied."
|
||||
)
|
||||
return "Tool execution denied (timeout)."
|
||||
|
||||
if not approved:
|
||||
return "Tool execution denied by user."
|
||||
|
||||
registry = PluginRegistry.instance()
|
||||
tool = registry.find_tool(tool_name)
|
||||
if tool is None:
|
||||
return f"Error: tool '{tool_name}' not found in registry."
|
||||
|
||||
try:
|
||||
result = tool.handler(**args)
|
||||
if not isinstance(result, str):
|
||||
result = str(result)
|
||||
except Exception as exc:
|
||||
return f"Tool error: {exc}"
|
||||
|
||||
injection_warnings = scan_response(result)
|
||||
if injection_warnings:
|
||||
labels = ", ".join(w.pattern_label for w in injection_warnings)
|
||||
await bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"Warning: tool result contains suspicious content ({labels}).",
|
||||
)
|
||||
|
||||
return result[:4000]
|
||||
|
||||
# ── Approval callback ─────────────────────────────────────────────────────
|
||||
|
||||
async def _handle_approval_callback(self, update: Update) -> None:
|
||||
query = update.callback_query
|
||||
if query is None:
|
||||
return
|
||||
await query.answer()
|
||||
data = query.data or ""
|
||||
if ":" not in data:
|
||||
return
|
||||
action, call_id = data.split(":", 1)
|
||||
future = self._pending_approvals.pop(call_id, None)
|
||||
if future and not future.done():
|
||||
future.set_result(action == "approve")
|
||||
try:
|
||||
await query.edit_message_reply_markup(reply_markup=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_plugin() -> TelegramBotPlugin:
|
||||
return TelegramBotPlugin()
|
||||
+46
-2
@@ -165,6 +165,7 @@ def plugin_list() -> None:
|
||||
@click.argument("name")
|
||||
def plugin_install(name: str) -> None:
|
||||
"""Install a bundled plugin to ~/.pyra/plugins/."""
|
||||
import questionary
|
||||
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
|
||||
from pyra.utils.paths import pyra_home
|
||||
|
||||
@@ -173,12 +174,55 @@ def plugin_install(name: str) -> None:
|
||||
try:
|
||||
install_bundled_plugin(name, bundled_dir, plugins_dir)
|
||||
console.print(f"[green]Installed:[/green] {name}")
|
||||
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
|
||||
console.print(f" Configure: [dim]pyra plugin setup {name}[/dim]")
|
||||
except FileNotFoundError as exc:
|
||||
console.print(f"[red]Error:[/red] {exc}")
|
||||
return
|
||||
except Exception as exc:
|
||||
console.print(f"[red]Install failed:[/red] {exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
configure_now = questionary.confirm(f"Configure {name} now?", default=True).ask()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return
|
||||
if not configure_now:
|
||||
console.print(f" Configure later: [dim]pyra plugin setup {name}[/dim]")
|
||||
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
|
||||
return
|
||||
|
||||
from pyra.plugins.loader import load_plugin_by_name
|
||||
from pyra.vault.writer import set_key as _set_key
|
||||
|
||||
p = load_plugin_by_name(name, plugins_dir)
|
||||
if p is None:
|
||||
console.print(f"[red]Could not load {name} for setup.[/red]")
|
||||
return
|
||||
try:
|
||||
p.setup(console, _set_key)
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
console.print("[dim]Setup cancelled.[/dim]")
|
||||
return
|
||||
except Exception as exc:
|
||||
console.print(f"[red]Setup error:[/red] {exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
enable_now = questionary.confirm(f"Enable {name} now?", default=True).ask()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
enable_now = False
|
||||
if enable_now:
|
||||
from pyra.config.manager import load_config, save_config
|
||||
try:
|
||||
cfg = load_config()
|
||||
if name not in cfg.plugins.enabled:
|
||||
cfg.plugins.enabled.append(name)
|
||||
save_config(cfg)
|
||||
console.print(f"[green]Enabled:[/green] {name}")
|
||||
except Exception as exc:
|
||||
console.print(f"[yellow]Could not enable automatically:[/yellow] {exc}")
|
||||
console.print(f" Enable manually: [dim]pyra plugin enable {name}[/dim]")
|
||||
|
||||
console.print(f"[dim]Run [bold]pyra daemon start[/bold] to bring {name} online.[/dim]")
|
||||
|
||||
|
||||
@plugin.command("enable")
|
||||
|
||||
@@ -7,7 +7,7 @@ from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from pyra.config.manager import save_config
|
||||
from pyra.config.manager import load_config, save_config
|
||||
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
|
||||
from pyra.setup.providers import PROVIDERS, Provider, get_provider
|
||||
from pyra.utils.paths import pyra_home, safe_chmod
|
||||
@@ -167,6 +167,7 @@ def run_setup() -> None:
|
||||
_delete_draft()
|
||||
|
||||
_suggest_plugins(use_cases)
|
||||
_offer_telegram_setup_if_selected(use_cases)
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
@@ -238,6 +239,68 @@ def _suggest_plugins(use_cases: list[str]) -> None:
|
||||
))
|
||||
|
||||
|
||||
def _offer_telegram_setup_if_selected(use_cases: list[str]) -> None:
|
||||
"""If the user chose 'Communication bots', offer to install and configure telegram_bot."""
|
||||
relevant = any(
|
||||
"telegram_bot" in _USE_CASE_PLUGINS.get(uc, [])
|
||||
for uc in use_cases
|
||||
)
|
||||
if not relevant:
|
||||
return
|
||||
|
||||
console.print()
|
||||
try:
|
||||
answer = questionary.confirm(
|
||||
"Set up the Telegram bot for remote access to Pyra?", default=True
|
||||
).ask()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return
|
||||
if not answer:
|
||||
console.print(
|
||||
" [dim]You can do this later: pyra plugin install telegram_bot[/dim]"
|
||||
)
|
||||
return
|
||||
|
||||
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
|
||||
from pyra.plugins.loader import load_plugin_by_name
|
||||
from pyra.utils.paths import pyra_home
|
||||
from pyra.vault.writer import set_key
|
||||
|
||||
bundled_dir = get_bundled_plugins_dir()
|
||||
plugins_dir = pyra_home() / "plugins"
|
||||
|
||||
try:
|
||||
install_bundled_plugin("telegram_bot", bundled_dir, plugins_dir)
|
||||
except Exception as exc:
|
||||
console.print(f"[red]Could not install telegram_bot:[/red] {exc}")
|
||||
return
|
||||
|
||||
p = load_plugin_by_name("telegram_bot", plugins_dir)
|
||||
if p is None:
|
||||
console.print("[red]Could not load telegram_bot for setup.[/red]")
|
||||
return
|
||||
|
||||
try:
|
||||
p.setup(console, set_key)
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
console.print("[dim]Telegram setup skipped.[/dim]")
|
||||
return
|
||||
except Exception as exc:
|
||||
console.print(f"[red]Telegram setup error:[/red] {exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
cfg = load_config()
|
||||
if "telegram_bot" not in cfg.plugins.enabled:
|
||||
cfg.plugins.enabled.append("telegram_bot")
|
||||
save_config(cfg)
|
||||
console.print("[green]Telegram bot enabled.[/green]")
|
||||
except Exception:
|
||||
console.print(
|
||||
" [dim]Enable manually: pyra plugin enable telegram_bot[/dim]"
|
||||
)
|
||||
|
||||
|
||||
def _choose_provider() -> Provider:
|
||||
local = [p for p in PROVIDERS if p.group == "Local"]
|
||||
cloud = [p for p in PROVIDERS if p.group == "Cloud"]
|
||||
|
||||
@@ -137,3 +137,47 @@ def test_main_skips_setup_when_config_exists(tmp_pyra_home, monkeypatch):
|
||||
def test_config_slash_command_registered():
|
||||
from pyra.chat.session import _STATIC_COMMANDS
|
||||
assert "/config" in _STATIC_COMMANDS
|
||||
|
||||
|
||||
# ── plugin install ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_plugin_install_decline_setup(tmp_pyra_home, monkeypatch):
|
||||
"""Declining 'Configure now?' shows manual instructions and exits cleanly."""
|
||||
from unittest.mock import MagicMock
|
||||
import questionary
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pyra.plugins.install.install_bundled_plugin", lambda *a, **kw: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
questionary, "confirm",
|
||||
lambda *a, **kw: MagicMock(ask=lambda: False),
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["plugin", "install", "telegram_bot"])
|
||||
assert result.exit_code == 0
|
||||
assert "Installed" in result.output
|
||||
assert "Configure later" in result.output
|
||||
|
||||
|
||||
def test_plugin_install_error_does_not_prompt(tmp_pyra_home, monkeypatch):
|
||||
"""If install fails, the configure prompt is never shown."""
|
||||
from unittest.mock import MagicMock
|
||||
import questionary
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pyra.plugins.install.install_bundled_plugin",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(FileNotFoundError("not found")),
|
||||
)
|
||||
confirm_calls = []
|
||||
monkeypatch.setattr(
|
||||
questionary, "confirm",
|
||||
lambda *a, **kw: confirm_calls.append(1) or MagicMock(ask=lambda: False),
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["plugin", "install", "telegram_bot"])
|
||||
assert result.exit_code == 0
|
||||
assert "Error" in result.output
|
||||
assert len(confirm_calls) == 0 # prompt never reached
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""Unit tests for the telegram_bot bundled plugin.
|
||||
|
||||
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
|
||||
and tool argument parsing. Handler integration (live Telegram) is not tested.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ── Import helpers straight from the bundled source ──────────────────────────
|
||||
|
||||
_PLUGIN_PATH = (
|
||||
Path(__file__).parent.parent.parent
|
||||
/ "src/pyra/bundled_plugins/telegram_bot/plugin.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_plugin_module():
|
||||
"""Load plugin.py without requiring python-telegram-bot to be installed."""
|
||||
# Provide stub modules so the top-level imports don't fail in unit tests
|
||||
for mod_name in [
|
||||
"bcrypt",
|
||||
"telegram",
|
||||
"telegram.ext",
|
||||
"litellm",
|
||||
]:
|
||||
if mod_name not in sys.modules:
|
||||
stub = MagicMock()
|
||||
sys.modules[mod_name] = stub
|
||||
# telegram.ext sub-attributes needed at import time
|
||||
if mod_name == "telegram.ext":
|
||||
stub.Application = MagicMock()
|
||||
stub.CallbackQueryHandler = MagicMock()
|
||||
stub.CommandHandler = MagicMock()
|
||||
stub.ContextTypes = MagicMock()
|
||||
stub.MessageHandler = MagicMock()
|
||||
stub.filters = MagicMock()
|
||||
|
||||
spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
|
||||
assert spec and spec.loader
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod) # type: ignore[union-attr]
|
||||
return mod
|
||||
|
||||
|
||||
_mod = _load_plugin_module()
|
||||
_RateLimiter = _mod._RateLimiter
|
||||
_load_history = _mod._load_history
|
||||
_save_history = _mod._save_history
|
||||
TelegramBotPlugin = _mod.TelegramBotPlugin
|
||||
|
||||
|
||||
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRateLimiter:
|
||||
def test_allows_up_to_limit(self):
|
||||
rl = _RateLimiter(per_hour=5)
|
||||
for _ in range(5):
|
||||
assert rl.allow(user_id=1)
|
||||
|
||||
def test_blocks_over_limit(self):
|
||||
rl = _RateLimiter(per_hour=3)
|
||||
for _ in range(3):
|
||||
rl.allow(user_id=1)
|
||||
assert not rl.allow(user_id=1)
|
||||
|
||||
def test_independent_per_user(self):
|
||||
rl = _RateLimiter(per_hour=2)
|
||||
rl.allow(1)
|
||||
rl.allow(1)
|
||||
assert not rl.allow(1)
|
||||
assert rl.allow(2)
|
||||
|
||||
def test_old_entries_expire(self):
|
||||
rl = _RateLimiter(per_hour=1)
|
||||
# Manually populate bucket with an old timestamp
|
||||
from collections import deque
|
||||
rl._buckets[99] = deque([time.monotonic() - 3601])
|
||||
assert rl.allow(99) # old entry removed, new one fits
|
||||
|
||||
def test_multiple_users_independent(self):
|
||||
rl = _RateLimiter(per_hour=1)
|
||||
rl.allow(10)
|
||||
assert not rl.allow(10)
|
||||
assert rl.allow(20)
|
||||
assert rl.allow(30)
|
||||
|
||||
|
||||
# ── History DB ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestHistoryDB:
|
||||
def test_empty_history(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
assert _load_history(42) == []
|
||||
|
||||
def test_save_and_load(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
msgs = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
_save_history(42, msgs)
|
||||
assert _load_history(42) == msgs
|
||||
|
||||
def test_save_trims_to_max(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
|
||||
msgs = [{"role": "user", "content": str(i)} for i in range(10)]
|
||||
_save_history(1, msgs)
|
||||
loaded = _load_history(1)
|
||||
assert len(loaded) == 3
|
||||
# Should keep the most recent messages
|
||||
assert loaded[-1]["content"] == "9"
|
||||
|
||||
def test_overwrite(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
_save_history(1, [{"role": "user", "content": "first"}])
|
||||
_save_history(1, [{"role": "user", "content": "second"}])
|
||||
assert _load_history(1)[0]["content"] == "second"
|
||||
|
||||
def test_separate_chats(self, tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "tg.db"
|
||||
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||
_save_history(1, [{"role": "user", "content": "chat1"}])
|
||||
_save_history(2, [{"role": "user", "content": "chat2"}])
|
||||
assert _load_history(1)[0]["content"] == "chat1"
|
||||
assert _load_history(2)[0]["content"] == "chat2"
|
||||
|
||||
|
||||
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
|
||||
|
||||
class TestPluginLifecycle:
|
||||
def test_on_load_stores_vault_reader(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
reader = MagicMock(return_value=None)
|
||||
plugin.on_load(reader)
|
||||
assert plugin._vault_reader is reader
|
||||
|
||||
def test_daemon_tasks_returns_coroutine(self):
|
||||
import inspect
|
||||
plugin = TelegramBotPlugin()
|
||||
plugin.on_load(MagicMock(return_value=None))
|
||||
tasks = plugin.daemon_tasks()
|
||||
assert len(tasks) == 1
|
||||
assert inspect.iscoroutine(tasks[0])
|
||||
tasks[0].close() # prevent "coroutine never awaited" warning
|
||||
|
||||
def test_get_plugin_factory(self):
|
||||
plugin = _mod.get_plugin()
|
||||
assert isinstance(plugin, TelegramBotPlugin)
|
||||
assert plugin.name == "telegram_bot"
|
||||
|
||||
def test_config_fields(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
fields = plugin.config_fields()
|
||||
assert any(f.key == "rate_limit" for f in fields)
|
||||
|
||||
def _patch_setup(self, token, allowed, pass1, pass2):
|
||||
"""Return a context manager that patches all questionary calls used by setup()."""
|
||||
pw_answers = iter([token, pass1, pass2])
|
||||
return (
|
||||
patch("questionary.password",
|
||||
side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(pw_answers))),
|
||||
patch("questionary.text",
|
||||
return_value=MagicMock(ask=lambda: allowed)),
|
||||
patch("questionary.press_any_key_to_continue",
|
||||
return_value=MagicMock(ask=lambda: None)),
|
||||
)
|
||||
|
||||
def test_setup_mismatched_passphrase(self):
|
||||
"""setup() writes nothing to the vault when passphrases don't match."""
|
||||
plugin = TelegramBotPlugin()
|
||||
console = MagicMock()
|
||||
vault_writer = MagicMock()
|
||||
|
||||
pw_patch, text_patch, pakc_patch = self._patch_setup(
|
||||
"fake-token", "123456789", "pass1", "pass2"
|
||||
)
|
||||
with pw_patch, text_patch, pakc_patch:
|
||||
plugin.setup(console, vault_writer)
|
||||
|
||||
vault_writer.assert_not_called()
|
||||
console.print.assert_called_with(
|
||||
"[red]Passphrases do not match. Run setup again to retry.[/red]"
|
||||
)
|
||||
|
||||
def test_setup_happy_path(self):
|
||||
"""setup() writes all three vault keys when credentials are valid."""
|
||||
plugin = TelegramBotPlugin()
|
||||
console = MagicMock()
|
||||
vault_writer = MagicMock()
|
||||
|
||||
pw_patch, text_patch, pakc_patch = self._patch_setup(
|
||||
"real-token", "111222333", "s3cr3t", "s3cr3t"
|
||||
)
|
||||
with pw_patch, text_patch, pakc_patch:
|
||||
plugin.setup(console, vault_writer)
|
||||
|
||||
calls = {call[0][0]: call[0][1] for call in vault_writer.call_args_list}
|
||||
assert calls.get("plugin:telegram_bot:token") == "real-token"
|
||||
assert calls.get("plugin:telegram_bot:allowed_users") == "111222333"
|
||||
assert "plugin:telegram_bot:passphrase_hash" in calls
|
||||
|
||||
def test_setup_cancelled_on_empty_token(self):
|
||||
"""setup() exits without writing if the token prompt is cancelled."""
|
||||
plugin = TelegramBotPlugin()
|
||||
console = MagicMock()
|
||||
vault_writer = MagicMock()
|
||||
|
||||
with patch("questionary.password", return_value=MagicMock(ask=lambda: None)), \
|
||||
patch("questionary.press_any_key_to_continue",
|
||||
return_value=MagicMock(ask=lambda: None)):
|
||||
plugin.setup(console, vault_writer)
|
||||
|
||||
vault_writer.assert_not_called()
|
||||
|
||||
|
||||
# ── Auth session state ────────────────────────────────────────────────────────
|
||||
|
||||
class TestAuthState:
|
||||
def test_session_defaults(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
session = plugin._sessions.setdefault(
|
||||
99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
|
||||
)
|
||||
assert not session["authenticated"]
|
||||
assert not session["awaiting_passphrase"]
|
||||
assert session["attempts"] == 0
|
||||
|
||||
def test_session_per_chat(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
|
||||
plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
|
||||
assert plugin._sessions[1]["authenticated"]
|
||||
assert not plugin._sessions[2]["authenticated"]
|
||||
assert plugin._sessions[2]["attempts"] == 1
|
||||
|
||||
def test_session_auth_flag(self):
|
||||
plugin = TelegramBotPlugin()
|
||||
plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
|
||||
# Simulate successful auth
|
||||
plugin._sessions[5]["authenticated"] = True
|
||||
plugin._sessions[5]["awaiting_passphrase"] = False
|
||||
assert plugin._sessions[5]["authenticated"]
|
||||
|
||||
|
||||
# ── Tool argument parsing ─────────────────────────────────────────────────────
|
||||
|
||||
class TestToolArgParsing:
|
||||
def test_valid_json_string(self):
|
||||
args_raw = '{"query": "hello"}'
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
assert args == {"query": "hello"}
|
||||
|
||||
def test_dict_passthrough(self):
|
||||
args_raw = {"query": "hello"}
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
assert args == {"query": "hello"}
|
||||
|
||||
def test_invalid_json_caught(self):
|
||||
args_raw = "not json"
|
||||
try:
|
||||
json.loads(args_raw)
|
||||
assert False, "Should have raised"
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
def test_result_truncated_to_4000(self):
|
||||
long_result = "x" * 5000
|
||||
assert len(long_result[:4000]) == 4000
|
||||
Reference in New Issue
Block a user