Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3f30b782d2 |
+2
-10
@@ -28,30 +28,22 @@ dev = [
|
|||||||
]
|
]
|
||||||
nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"]
|
nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"]
|
||||||
matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"]
|
matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"]
|
||||||
telegram = ["python-telegram-bot>=21.0"]
|
telegram = ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
|
||||||
ssh = ["paramiko>=3.4.0"]
|
ssh = ["paramiko>=3.4.0"]
|
||||||
docker = ["docker>=7.0.0"]
|
docker = ["docker>=7.0.0"]
|
||||||
gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
|
gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
|
||||||
onedrive = ["msal>=1.28.0"]
|
onedrive = ["msal>=1.28.0"]
|
||||||
dropbox = ["dropbox>=12.0.0"]
|
dropbox = ["dropbox>=12.0.0"]
|
||||||
daemon = ["aiofiles>=23.0.0"]
|
daemon = ["aiofiles>=23.0.0"]
|
||||||
email = [
|
|
||||||
"imap-tools>=1.7.0",
|
|
||||||
"google-api-python-client>=2.120.0",
|
|
||||||
"google-auth-oauthlib>=1.2.0",
|
|
||||||
"O365>=2.0.36",
|
|
||||||
]
|
|
||||||
all-plugins = [
|
all-plugins = [
|
||||||
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
|
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
|
||||||
"matrix-nio>=0.24.0", "aiofiles>=23.0.0",
|
"matrix-nio>=0.24.0", "aiofiles>=23.0.0",
|
||||||
"python-telegram-bot>=21.0",
|
"python-telegram-bot>=21.0", "bcrypt>=4.0.0",
|
||||||
"paramiko>=3.4.0",
|
"paramiko>=3.4.0",
|
||||||
"docker>=7.0.0",
|
"docker>=7.0.0",
|
||||||
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
|
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
|
||||||
"msal>=1.28.0",
|
"msal>=1.28.0",
|
||||||
"dropbox>=12.0.0",
|
"dropbox>=12.0.0",
|
||||||
"imap-tools>=1.7.0",
|
|
||||||
"O365>=2.0.36",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "email",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "Full email management — read, send, search, sort, and create filter rules. Supports Gmail, Microsoft 365, ProtonMail (Bridge), and any IMAP provider. Background monitoring pushes new-email summaries to your configured messaging bot.",
|
|
||||||
"author": "pyra",
|
|
||||||
"requires": [
|
|
||||||
"imap-tools>=1.7.0",
|
|
||||||
"google-api-python-client>=2.120.0",
|
|
||||||
"google-auth-oauthlib>=1.2.0",
|
|
||||||
"O365>=2.0.36"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"name": "telegram_bot",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "Remote Pyra chat over Telegram — full AI chat with tool approval via inline buttons",
|
||||||
|
"author": "pyra",
|
||||||
|
"requires": ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,592 @@
|
|||||||
|
"""Telegram bot plugin — remote Pyra chat over Telegram.
|
||||||
|
|
||||||
|
Runs as a daemon task (long-polling). Each chat session requires passphrase
|
||||||
|
authentication and is rate-limited to 20 messages/hour. Incoming messages are
|
||||||
|
injection-scanned before reaching the AI. Tool calls are approved via Telegram
|
||||||
|
inline keyboard buttons (2-minute timeout). AI responses are streamed
|
||||||
|
progressively by editing a placeholder message.
|
||||||
|
|
||||||
|
Vault keys used:
|
||||||
|
plugin:telegram_bot:token — bot token from @BotFather
|
||||||
|
plugin:telegram_bot:allowed_users — comma-separated Telegram user IDs
|
||||||
|
plugin:telegram_bot:passphrase_hash — bcrypt hash of the session passphrase
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
import litellm
|
||||||
|
from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup, Update
|
||||||
|
from telegram.ext import (
|
||||||
|
Application,
|
||||||
|
CallbackQueryHandler,
|
||||||
|
CommandHandler,
|
||||||
|
ContextTypes,
|
||||||
|
MessageHandler,
|
||||||
|
filters,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pyra.plugins.base import BasePlugin, ConfigField
|
||||||
|
from pyra.utils.paths import pyra_home
|
||||||
|
|
||||||
|
_log = logging.getLogger("pyra.plugin.telegram_bot")
|
||||||
|
|
||||||
|
_HISTORY_DB = pyra_home() / "telegram_history.db"
|
||||||
|
_MAX_HISTORY = 40 # messages kept per chat
|
||||||
|
_RATE_LIMIT = 20 # messages per hour per user
|
||||||
|
_APPROVAL_TIMEOUT = 120 # seconds to wait for inline button press
|
||||||
|
_EDIT_INTERVAL = 1.5 # minimum seconds between progressive message edits
|
||||||
|
_MAX_TOOL_ITER = 10
|
||||||
|
_MAX_MSG_LEN = 4096 # Telegram hard limit
|
||||||
|
|
||||||
|
|
||||||
|
# ── SQLite history ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _open_db() -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(_HISTORY_DB)
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
chat_id INTEGER PRIMARY KEY,
|
||||||
|
history TEXT NOT NULL DEFAULT '[]',
|
||||||
|
updated REAL NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
os.chmod(_HISTORY_DB, 0o600)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def _load_history(chat_id: int) -> list[dict]:
|
||||||
|
conn = _open_db()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT history FROM sessions WHERE chat_id = ?", (chat_id,)
|
||||||
|
).fetchone()
|
||||||
|
conn.close()
|
||||||
|
return json.loads(row[0]) if row else []
|
||||||
|
|
||||||
|
|
||||||
|
def _save_history(chat_id: int, messages: list[dict]) -> None:
|
||||||
|
trimmed = messages[-_MAX_HISTORY:]
|
||||||
|
conn = _open_db()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO sessions (chat_id, history, updated) VALUES (?,?,?)",
|
||||||
|
(chat_id, json.dumps(trimmed), time.time()),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _RateLimiter:
|
||||||
|
def __init__(self, per_hour: int = _RATE_LIMIT) -> None:
|
||||||
|
self._buckets: dict[int, deque] = {}
|
||||||
|
self._limit = per_hour
|
||||||
|
|
||||||
|
def allow(self, user_id: int) -> bool:
|
||||||
|
now = time.monotonic()
|
||||||
|
bucket = self._buckets.setdefault(user_id, deque())
|
||||||
|
cutoff = now - 3600
|
||||||
|
while bucket and bucket[0] < cutoff:
|
||||||
|
bucket.popleft()
|
||||||
|
if len(bucket) >= self._limit:
|
||||||
|
return False
|
||||||
|
bucket.append(now)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TelegramBotPlugin(BasePlugin):
|
||||||
|
name = "telegram_bot"
|
||||||
|
description = "Remote Pyra chat over Telegram (daemon task, long-polling)"
|
||||||
|
version = "1.0.0"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._vault_reader: Callable[[str], str | None] | None = None
|
||||||
|
self._rate_limiter = _RateLimiter()
|
||||||
|
# chat_id -> {authenticated, awaiting_passphrase, attempts}
|
||||||
|
self._sessions: dict[int, dict] = {}
|
||||||
|
# short call_id -> asyncio.Future[bool]
|
||||||
|
self._pending_approvals: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
# ── Plugin lifecycle ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def on_load(self, vault_reader: Callable[[str], str | None]) -> None:
|
||||||
|
self._vault_reader = vault_reader
|
||||||
|
|
||||||
|
def setup(self, console: Any, vault_writer: Callable[[str, str], None]) -> None:
|
||||||
|
import questionary
|
||||||
|
|
||||||
|
console.print("[bold]Telegram Bot Setup[/bold]")
|
||||||
|
console.print(
|
||||||
|
"1. Create a bot via @BotFather on Telegram to get your bot token.\n"
|
||||||
|
"2. Find your Telegram user ID by messaging @userinfobot.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
token = questionary.password("Bot token (from @BotFather):").ask()
|
||||||
|
if not token:
|
||||||
|
return
|
||||||
|
|
||||||
|
allowed = questionary.text(
|
||||||
|
"Allowed Telegram user IDs (comma-separated, e.g. 123456789):"
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
passphrase = questionary.password("Session passphrase:").ask()
|
||||||
|
if not passphrase:
|
||||||
|
return
|
||||||
|
|
||||||
|
confirm = questionary.password("Confirm passphrase:").ask()
|
||||||
|
if passphrase != confirm:
|
||||||
|
console.print("[red]Passphrases do not match.[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
pw_hash = bcrypt.hashpw(passphrase.encode(), bcrypt.gensalt()).decode()
|
||||||
|
vault_writer("plugin:telegram_bot:token", token.strip())
|
||||||
|
vault_writer("plugin:telegram_bot:allowed_users", (allowed or "").strip())
|
||||||
|
vault_writer("plugin:telegram_bot:passphrase_hash", pw_hash)
|
||||||
|
console.print("[green]Telegram bot configured. Enable it and start the daemon.[/green]")
|
||||||
|
|
||||||
|
def config_fields(self) -> list[ConfigField]:
|
||||||
|
return [
|
||||||
|
ConfigField(
|
||||||
|
"rate_limit",
|
||||||
|
"Rate limit (messages/hour)",
|
||||||
|
"text",
|
||||||
|
str(_RATE_LIMIT),
|
||||||
|
description="Maximum messages per hour per Telegram user",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def daemon_tasks(self) -> list:
|
||||||
|
return [self._run_polling()]
|
||||||
|
|
||||||
|
# ── Daemon task ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_polling(self) -> None:
|
||||||
|
assert self._vault_reader is not None
|
||||||
|
|
||||||
|
token = self._vault_reader("plugin:telegram_bot:token")
|
||||||
|
if not token:
|
||||||
|
_log.error(
|
||||||
|
"Telegram bot token not set. Run `pyra plugin setup telegram_bot`."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
passphrase_hash = self._vault_reader("plugin:telegram_bot:passphrase_hash") or ""
|
||||||
|
allowed_str = self._vault_reader("plugin:telegram_bot:allowed_users") or ""
|
||||||
|
allowed_users: set[int] = {
|
||||||
|
int(uid.strip())
|
||||||
|
for uid in allowed_str.split(",")
|
||||||
|
if uid.strip().isdigit()
|
||||||
|
}
|
||||||
|
|
||||||
|
app = Application.builder().token(token).build()
|
||||||
|
plugin = self # closure reference
|
||||||
|
|
||||||
|
async def _on_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
if update.effective_user and update.effective_user.id in allowed_users:
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Pyra is online. Send any message to authenticate."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _on_message(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
await plugin._handle_message(update, ctx, allowed_users, passphrase_hash)
|
||||||
|
|
||||||
|
async def _on_approval(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
await plugin._handle_approval_callback(update)
|
||||||
|
|
||||||
|
app.add_handler(CommandHandler("start", _on_start))
|
||||||
|
app.add_handler(CallbackQueryHandler(_on_approval))
|
||||||
|
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, _on_message))
|
||||||
|
|
||||||
|
_log.info("Telegram bot starting (long-polling).")
|
||||||
|
await app.initialize()
|
||||||
|
try:
|
||||||
|
await app.start()
|
||||||
|
await app.updater.start_polling(drop_pending_updates=True)
|
||||||
|
_log.info("Telegram bot is polling for updates.")
|
||||||
|
await asyncio.Event().wait() # block until CancelledError
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_log.info("Telegram bot shutting down.")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await app.updater.stop()
|
||||||
|
await app.stop()
|
||||||
|
await app.shutdown()
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("Error during Telegram bot shutdown: %s", exc)
|
||||||
|
|
||||||
|
# ── Message handler ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle_message(
|
||||||
|
self,
|
||||||
|
update: Update,
|
||||||
|
ctx: ContextTypes.DEFAULT_TYPE,
|
||||||
|
allowed_users: set[int],
|
||||||
|
passphrase_hash: str,
|
||||||
|
) -> None:
|
||||||
|
if update.effective_user is None or update.message is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
user_id = update.effective_user.id
|
||||||
|
chat_id = update.effective_chat.id if update.effective_chat else user_id
|
||||||
|
|
||||||
|
# Allowlist — silently ignore unknown senders
|
||||||
|
if allowed_users and user_id not in allowed_users:
|
||||||
|
return
|
||||||
|
|
||||||
|
text = (update.message.text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
|
||||||
|
session = self._sessions.setdefault(
|
||||||
|
chat_id,
|
||||||
|
{"authenticated": False, "awaiting_passphrase": False, "attempts": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Passphrase authentication ─────────────────────────────────────────
|
||||||
|
if not session["authenticated"]:
|
||||||
|
if not session["awaiting_passphrase"]:
|
||||||
|
session["awaiting_passphrase"] = True
|
||||||
|
session["attempts"] = 0
|
||||||
|
await update.message.reply_text("Enter your passphrase to continue:")
|
||||||
|
return
|
||||||
|
|
||||||
|
if passphrase_hash and bcrypt.checkpw(text.encode(), passphrase_hash.encode()):
|
||||||
|
session["authenticated"] = True
|
||||||
|
session["awaiting_passphrase"] = False
|
||||||
|
session["attempts"] = 0
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Authenticated. How can I help you?\n"
|
||||||
|
"Send /start at any time to check bot status."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
session["attempts"] += 1
|
||||||
|
remaining = 3 - session["attempts"]
|
||||||
|
if remaining <= 0:
|
||||||
|
session["awaiting_passphrase"] = False
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Too many failed attempts. Send any message to try again."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await update.message.reply_text(
|
||||||
|
f"Wrong passphrase. {remaining} attempt(s) left."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Rate limit ────────────────────────────────────────────────────────
|
||||||
|
if not self._rate_limiter.allow(user_id):
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Rate limit reached (20 messages/hour). Try again later."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Injection scan ────────────────────────────────────────────────────
|
||||||
|
from pyra.security.injection import scan_response
|
||||||
|
|
||||||
|
warnings = scan_response(text)
|
||||||
|
if warnings:
|
||||||
|
labels = ", ".join(w.pattern_label for w in warnings)
|
||||||
|
_log.warning("Injection in Telegram message (user %d): %s", user_id, labels)
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Your message was blocked: injection pattern detected."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Load history + system context ─────────────────────────────────────
|
||||||
|
history = _load_history(chat_id)
|
||||||
|
if not history:
|
||||||
|
try:
|
||||||
|
from pyra.memory.reader import load_context_for_session
|
||||||
|
|
||||||
|
ctx_text = load_context_for_session()
|
||||||
|
if ctx_text:
|
||||||
|
history = [{"role": "system", "content": ctx_text}]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
history.append({"role": "user", "content": text})
|
||||||
|
|
||||||
|
try:
|
||||||
|
await ctx.bot.send_chat_action(chat_id=chat_id, action="typing")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── AI response ───────────────────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
reply = await self._ai_chat(chat_id, history, ctx.bot)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.error("AI error (chat %d): %s", chat_id, exc, exc_info=True)
|
||||||
|
await ctx.bot.send_message(chat_id=chat_id, text=f"AI error: {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
history.append({"role": "assistant", "content": reply})
|
||||||
|
_save_history(chat_id, history)
|
||||||
|
|
||||||
|
# ── AI streaming + tool-use loop ──────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _ai_chat(self, chat_id: int, messages: list[dict], bot: Bot) -> str:
|
||||||
|
from pyra.config.manager import load_config
|
||||||
|
from pyra.plugins.registry import PluginRegistry
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
from pyra.vault.reader import get_key
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
provider = get_provider(cfg.ai.provider_id)
|
||||||
|
api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local"
|
||||||
|
|
||||||
|
call_kwargs: dict[str, Any] = {
|
||||||
|
"model": f"{provider.litellm_prefix}{cfg.ai.model}",
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
base_url = cfg.ai.base_url or provider.base_url
|
||||||
|
if base_url:
|
||||||
|
call_kwargs["api_base"] = base_url
|
||||||
|
|
||||||
|
litellm.suppress_debug_info = True
|
||||||
|
|
||||||
|
registry = PluginRegistry.instance()
|
||||||
|
tools_spec = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": t.name,
|
||||||
|
"description": t.description,
|
||||||
|
"parameters": t.parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for t in registry.get_all_tools()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mutable state shared with helpers below
|
||||||
|
state: dict[str, Any] = {"msg_id": None, "last_edit": 0.0}
|
||||||
|
|
||||||
|
placeholder = await bot.send_message(chat_id=chat_id, text="…")
|
||||||
|
state["msg_id"] = placeholder.message_id
|
||||||
|
|
||||||
|
async def _update(text: str) -> None:
|
||||||
|
if not state["msg_id"]:
|
||||||
|
return
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - state["last_edit"] < _EDIT_INTERVAL:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=state["msg_id"],
|
||||||
|
text=text[:_MAX_MSG_LEN],
|
||||||
|
)
|
||||||
|
state["last_edit"] = now
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _finalize(text: str) -> None:
|
||||||
|
if not state["msg_id"]:
|
||||||
|
return
|
||||||
|
if text:
|
||||||
|
try:
|
||||||
|
await bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=state["msg_id"],
|
||||||
|
text=text[:_MAX_MSG_LEN],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state["msg_id"] = None
|
||||||
|
|
||||||
|
accumulated = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _iter in range(_MAX_TOOL_ITER):
|
||||||
|
tool_chunks: dict[int, dict] = {}
|
||||||
|
accumulated = ""
|
||||||
|
|
||||||
|
stream = await litellm.acompletion(
|
||||||
|
**call_kwargs,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools_spec if tools_spec else None,
|
||||||
|
tool_choice="auto" if tools_spec else None,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta.content:
|
||||||
|
accumulated += delta.content
|
||||||
|
await _update(accumulated)
|
||||||
|
|
||||||
|
if delta.tool_calls:
|
||||||
|
for tc in delta.tool_calls:
|
||||||
|
idx = tc.index
|
||||||
|
if idx not in tool_chunks:
|
||||||
|
tool_chunks[idx] = {"id": tc.id or "", "name": "", "args": ""}
|
||||||
|
if tc.function:
|
||||||
|
if tc.function.name:
|
||||||
|
tool_chunks[idx]["name"] += tc.function.name
|
||||||
|
if tc.function.arguments:
|
||||||
|
tool_chunks[idx]["args"] += tc.function.arguments
|
||||||
|
|
||||||
|
if not tool_chunks:
|
||||||
|
await _finalize(accumulated)
|
||||||
|
return accumulated
|
||||||
|
|
||||||
|
# Show any intermediate prose before tool calls
|
||||||
|
if accumulated:
|
||||||
|
await _finalize(accumulated)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state["msg_id"] = None
|
||||||
|
|
||||||
|
tool_calls_list = [
|
||||||
|
{
|
||||||
|
"id": data["id"],
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": data["name"], "arguments": data["args"]},
|
||||||
|
}
|
||||||
|
for _, data in sorted(tool_chunks.items())
|
||||||
|
]
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": accumulated or None,
|
||||||
|
"tool_calls": tool_calls_list,
|
||||||
|
})
|
||||||
|
|
||||||
|
for tc in tool_calls_list:
|
||||||
|
result = await self._execute_tool_with_approval(tc, chat_id, bot)
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tc["id"],
|
||||||
|
"content": result,
|
||||||
|
})
|
||||||
|
|
||||||
|
# New placeholder for the next AI response
|
||||||
|
new_ph = await bot.send_message(chat_id=chat_id, text="…")
|
||||||
|
state["msg_id"] = new_ph.message_id
|
||||||
|
state["last_edit"] = 0.0
|
||||||
|
|
||||||
|
except litellm.BadRequestError:
|
||||||
|
# Provider doesn't support tool calls — retry without tools
|
||||||
|
accumulated = ""
|
||||||
|
state["last_edit"] = 0.0
|
||||||
|
stream = await litellm.acompletion(
|
||||||
|
**call_kwargs, messages=messages, stream=True
|
||||||
|
)
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
accumulated += chunk.choices[0].delta.content
|
||||||
|
await _update(accumulated)
|
||||||
|
await _finalize(accumulated)
|
||||||
|
return accumulated
|
||||||
|
|
||||||
|
return accumulated or "Error: tool-use loop exceeded maximum iterations."
|
||||||
|
|
||||||
|
# ── Tool approval via inline buttons ──────────────────────────────────────
|
||||||
|
|
||||||
|
async def _execute_tool_with_approval(
|
||||||
|
self, tool_call: dict, chat_id: int, bot: Bot
|
||||||
|
) -> str:
|
||||||
|
from pyra.plugins.registry import PluginRegistry
|
||||||
|
from pyra.security.injection import scan_response
|
||||||
|
|
||||||
|
tool_name = tool_call["function"]["name"]
|
||||||
|
args_raw = tool_call["function"]["arguments"]
|
||||||
|
try:
|
||||||
|
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return f"Error: invalid tool arguments for {tool_name}"
|
||||||
|
|
||||||
|
args_preview = json.dumps(args, indent=2)[:500]
|
||||||
|
call_id = uuid.uuid4().hex[:8]
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup([[
|
||||||
|
InlineKeyboardButton("✅ Approve", callback_data=f"approve:{call_id}"),
|
||||||
|
InlineKeyboardButton("❌ Deny", callback_data=f"deny:{call_id}"),
|
||||||
|
]])
|
||||||
|
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"Tool request: {tool_name}\n\n{args_preview}",
|
||||||
|
reply_markup=keyboard,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future: asyncio.Future[bool] = loop.create_future()
|
||||||
|
self._pending_approvals[call_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_approvals.pop(call_id, None)
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chat_id, text=f"Tool {tool_name}: approval timed out — denied."
|
||||||
|
)
|
||||||
|
return "Tool execution denied (timeout)."
|
||||||
|
|
||||||
|
if not approved:
|
||||||
|
return "Tool execution denied by user."
|
||||||
|
|
||||||
|
registry = PluginRegistry.instance()
|
||||||
|
tool = registry.find_tool(tool_name)
|
||||||
|
if tool is None:
|
||||||
|
return f"Error: tool '{tool_name}' not found in registry."
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = tool.handler(**args)
|
||||||
|
if not isinstance(result, str):
|
||||||
|
result = str(result)
|
||||||
|
except Exception as exc:
|
||||||
|
return f"Tool error: {exc}"
|
||||||
|
|
||||||
|
injection_warnings = scan_response(result)
|
||||||
|
if injection_warnings:
|
||||||
|
labels = ", ".join(w.pattern_label for w in injection_warnings)
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"Warning: tool result contains suspicious content ({labels}).",
|
||||||
|
)
|
||||||
|
|
||||||
|
return result[:4000]
|
||||||
|
|
||||||
|
# ── Approval callback ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle_approval_callback(self, update: Update) -> None:
|
||||||
|
query = update.callback_query
|
||||||
|
if query is None:
|
||||||
|
return
|
||||||
|
await query.answer()
|
||||||
|
data = query.data or ""
|
||||||
|
if ":" not in data:
|
||||||
|
return
|
||||||
|
action, call_id = data.split(":", 1)
|
||||||
|
future = self._pending_approvals.pop(call_id, None)
|
||||||
|
if future and not future.done():
|
||||||
|
future.set_result(action == "approve")
|
||||||
|
try:
|
||||||
|
await query.edit_message_reply_markup(reply_markup=None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_plugin() -> TelegramBotPlugin:
|
||||||
|
return TelegramBotPlugin()
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Pyra background daemon package."""
|
"""Pyra background daemon package."""
|
||||||
|
|
||||||
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
|
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
|
||||||
from pyra.daemon.events import publish, subscribe_forever
|
|
||||||
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
|
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
|
||||||
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||||
from pyra.daemon.service import detect_platform, install_service, uninstall_service
|
from pyra.daemon.service import detect_platform, install_service, uninstall_service
|
||||||
@@ -10,8 +9,6 @@ __all__ = [
|
|||||||
"run_foreground",
|
"run_foreground",
|
||||||
"start_background",
|
"start_background",
|
||||||
"PluginSupervisor",
|
"PluginSupervisor",
|
||||||
"publish",
|
|
||||||
"subscribe_forever",
|
|
||||||
"IpcClient",
|
"IpcClient",
|
||||||
"IpcServer",
|
"IpcServer",
|
||||||
"send_command",
|
"send_command",
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Async notification bus for inter-plugin communication in the daemon.
|
|
||||||
|
|
||||||
Plugins publish events to a shared asyncio.Queue; other plugins (e.g. messaging
|
|
||||||
bots) consume them via subscribe_forever(). No direct plugin-to-plugin imports
|
|
||||||
are needed — both sides just use this module.
|
|
||||||
|
|
||||||
Event shape (by convention):
|
|
||||||
{"type": "new_email", "priority": int, "from": str, "subject": str,
|
|
||||||
"summary": str, "uid": str, "folder": str}
|
|
||||||
{"type": "new_message", "bot": str, "user_id": str, "text": str}
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
_queue: asyncio.Queue[dict[str, Any]] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_queue() -> asyncio.Queue[dict[str, Any]]:
|
|
||||||
global _queue
|
|
||||||
if _queue is None:
|
|
||||||
_queue = asyncio.Queue(maxsize=200)
|
|
||||||
return _queue
|
|
||||||
|
|
||||||
|
|
||||||
async def publish(event: dict[str, Any]) -> None:
|
|
||||||
"""Emit an event. Drops silently if the queue is full (daemon is overloaded)."""
|
|
||||||
q = get_queue()
|
|
||||||
try:
|
|
||||||
q.put_nowait(event)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def subscribe_forever() -> AsyncGenerator[dict[str, Any], None]:
|
|
||||||
"""Async generator — yields events as they arrive. Intended for daemon tasks."""
|
|
||||||
q = get_queue()
|
|
||||||
while True:
|
|
||||||
yield await q.get()
|
|
||||||
|
|
||||||
|
|
||||||
def reset() -> None:
|
|
||||||
"""Discard the current queue and create a fresh one. FOR TESTS ONLY."""
|
|
||||||
global _queue
|
|
||||||
_queue = None
|
|
||||||
@@ -1,384 +0,0 @@
|
|||||||
"""Unit tests for the email plugin — pure-logic helpers, no network calls."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# Import helpers directly — they depend only on stdlib
|
|
||||||
from pyra.bundled_plugins.email.plugin import (
|
|
||||||
EmailMessage,
|
|
||||||
FilterRule,
|
|
||||||
_build_imap_search,
|
|
||||||
_decode_header,
|
|
||||||
_gmail_action_summary,
|
|
||||||
_gmail_criteria_summary,
|
|
||||||
_normalize_to_gmail,
|
|
||||||
_normalize_to_outlook,
|
|
||||||
_outlook_actions_summary,
|
|
||||||
_parse_raw_message,
|
|
||||||
_strip_html,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── _strip_html ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_strip_html_removes_tags():
|
|
||||||
result = _strip_html("<p>Hello <b>world</b></p>")
|
|
||||||
assert "<" not in result
|
|
||||||
assert "Hello" in result
|
|
||||||
assert "world" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_strip_html_decodes_entities():
|
|
||||||
result = _strip_html("<script> & "test"")
|
|
||||||
assert "<script>" in result
|
|
||||||
assert "&" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_strip_html_removes_style_and_script():
|
|
||||||
html = "<style>body{color:red}</style><script>alert(1)</script><p>Keep this</p>"
|
|
||||||
result = _strip_html(html)
|
|
||||||
assert "color" not in result
|
|
||||||
assert "alert" not in result
|
|
||||||
assert "Keep this" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_strip_html_plain_text_unchanged():
|
|
||||||
result = _strip_html("Hello, world!")
|
|
||||||
assert result == "Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
# ── _decode_header ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_decode_header_plain():
|
|
||||||
assert _decode_header("Hello") == "Hello"
|
|
||||||
|
|
||||||
|
|
||||||
def test_decode_header_encoded():
|
|
||||||
# RFC 2047 base64-encoded UTF-8
|
|
||||||
encoded = "=?utf-8?b?SGVsbG8gV29ybGQ=?="
|
|
||||||
assert _decode_header(encoded) == "Hello World"
|
|
||||||
|
|
||||||
|
|
||||||
def test_decode_header_empty():
|
|
||||||
assert _decode_header("") == ""
|
|
||||||
|
|
||||||
|
|
||||||
# ── _parse_raw_message ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _make_raw_email(
|
|
||||||
from_addr: str = "sender@example.com",
|
|
||||||
to_addr: str = "recipient@example.com",
|
|
||||||
subject: str = "Test Subject",
|
|
||||||
body: str = "Hello from test.",
|
|
||||||
message_id: str = "<test123@example.com>",
|
|
||||||
) -> bytes:
|
|
||||||
return (
|
|
||||||
f"From: {from_addr}\r\n"
|
|
||||||
f"To: {to_addr}\r\n"
|
|
||||||
f"Subject: {subject}\r\n"
|
|
||||||
f"Date: Mon, 01 Jan 2024 12:00:00 +0000\r\n"
|
|
||||||
f"Message-ID: {message_id}\r\n"
|
|
||||||
f"MIME-Version: 1.0\r\n"
|
|
||||||
f"Content-Type: text/plain; charset=utf-8\r\n"
|
|
||||||
f"\r\n"
|
|
||||||
f"{body}\r\n"
|
|
||||||
).encode()
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_raw_message_basic_fields():
|
|
||||||
raw = _make_raw_email()
|
|
||||||
msg = _parse_raw_message(raw, uid="42", folder="INBOX", is_read=False)
|
|
||||||
assert msg.uid == "42"
|
|
||||||
assert msg.folder == "INBOX"
|
|
||||||
assert msg.from_addr == "sender@example.com"
|
|
||||||
assert "recipient@example.com" in msg.to_addrs
|
|
||||||
assert msg.subject == "Test Subject"
|
|
||||||
assert msg.body_text == "Hello from test."
|
|
||||||
assert msg.is_read is False
|
|
||||||
assert msg.has_attachments is False
|
|
||||||
assert msg.attachments == []
|
|
||||||
assert msg.message_id == "<test123@example.com>"
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_raw_message_snippet_truncated():
|
|
||||||
long_body = "A" * 500
|
|
||||||
raw = _make_raw_email(body=long_body)
|
|
||||||
msg = _parse_raw_message(raw, uid="1", folder="INBOX", is_read=True)
|
|
||||||
assert len(msg.snippet) <= 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_raw_message_body_truncated_at_8000():
|
|
||||||
huge_body = "x" * 10000
|
|
||||||
raw = _make_raw_email(body=huge_body)
|
|
||||||
msg = _parse_raw_message(raw, uid="1", folder="INBOX", is_read=False)
|
|
||||||
assert len(msg.body_text) <= 8030 # 8000 + "[...truncated]"
|
|
||||||
assert "truncated" in msg.body_text
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_raw_message_html_stripped():
|
|
||||||
raw = _make_raw_email(body="<html><body><p>Plain text content</p></body></html>")
|
|
||||||
# Create HTML part manually
|
|
||||||
html_raw = (
|
|
||||||
"From: a@b.com\r\nTo: c@d.com\r\nSubject: Test\r\n"
|
|
||||||
"MIME-Version: 1.0\r\nContent-Type: text/html; charset=utf-8\r\n\r\n"
|
|
||||||
"<html><body><p>Plain text content</p></body></html>\r\n"
|
|
||||||
).encode()
|
|
||||||
msg = _parse_raw_message(html_raw, uid="1", folder="INBOX", is_read=False)
|
|
||||||
assert "<" not in msg.body_text
|
|
||||||
assert "Plain text content" in msg.body_text
|
|
||||||
|
|
||||||
|
|
||||||
# ── _build_imap_search ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_build_imap_search_unread():
|
|
||||||
from imap_tools import AND
|
|
||||||
criteria = _build_imap_search("unread invoices")
|
|
||||||
# Should produce an AND with seen=False
|
|
||||||
assert criteria is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_imap_search_from():
|
|
||||||
criteria = _build_imap_search("from:boss@company.com")
|
|
||||||
assert criteria is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_imap_search_subject():
|
|
||||||
criteria = _build_imap_search("subject: meeting notes")
|
|
||||||
assert criteria is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_imap_search_fallback():
|
|
||||||
criteria = _build_imap_search("random search terms")
|
|
||||||
assert criteria is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Gmail rule normalisation ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_normalize_to_gmail_from_condition():
|
|
||||||
criteria, action = _normalize_to_gmail({"from": "boss@company.com"}, {"mark_read": True})
|
|
||||||
assert criteria.get("from") == "boss@company.com"
|
|
||||||
assert "UNREAD" in action.get("removeLabelIds", [])
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_to_gmail_move_to():
|
|
||||||
criteria, action = _normalize_to_gmail({"subject": "invoice"}, {"move_to": "Bills"})
|
|
||||||
assert criteria.get("subject") == "invoice"
|
|
||||||
assert "Bills" in action.get("addLabelIds", [])
|
|
||||||
assert "INBOX" in action.get("removeLabelIds", [])
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_to_gmail_mark_important():
|
|
||||||
_, action = _normalize_to_gmail({}, {"mark_important": True})
|
|
||||||
assert "IMPORTANT" in action.get("addLabelIds", [])
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_to_gmail_forward():
|
|
||||||
_, action = _normalize_to_gmail({}, {"forward_to": "archive@example.com"})
|
|
||||||
assert action.get("forward") == "archive@example.com"
|
|
||||||
|
|
||||||
|
|
||||||
def test_gmail_criteria_summary_empty():
|
|
||||||
assert _gmail_criteria_summary({}) == "(any)"
|
|
||||||
|
|
||||||
|
|
||||||
def test_gmail_criteria_summary_from():
|
|
||||||
assert "from=boss" in _gmail_criteria_summary({"from": "boss@company.com"})
|
|
||||||
|
|
||||||
|
|
||||||
def test_gmail_action_summary_empty():
|
|
||||||
assert _gmail_action_summary({}) == "(no action)"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Outlook rule normalisation ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_normalize_to_outlook_from():
|
|
||||||
body = _normalize_to_outlook({"from": "a@b.com"}, {"move_to": "Work"})
|
|
||||||
from_addrs = body["conditions"].get("fromAddresses", [])
|
|
||||||
assert any("a@b.com" in str(a) for a in from_addrs)
|
|
||||||
assert body["actions"].get("moveToFolder") == "Work"
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_to_outlook_subject_contains():
|
|
||||||
body = _normalize_to_outlook({"subject": "invoice"}, {"mark_read": True})
|
|
||||||
assert "invoice" in body["conditions"].get("subjectContains", [])
|
|
||||||
assert body["actions"].get("markAsRead") is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_to_outlook_mark_important():
|
|
||||||
body = _normalize_to_outlook({}, {"mark_important": True})
|
|
||||||
assert body["actions"].get("markImportance") == "high"
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_to_outlook_delete():
|
|
||||||
body = _normalize_to_outlook({}, {"delete": True})
|
|
||||||
assert body["actions"].get("delete") is True
|
|
||||||
|
|
||||||
|
|
||||||
# ── email_move folder-not-found path ──────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_email_move_returns_error_when_folder_missing(tmp_pyra_home):
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
|
|
||||||
# Inject a mock provider with known folders
|
|
||||||
mock_provider = MagicMock()
|
|
||||||
mock_provider.list_folders.return_value = ["INBOX", "Sent", "Trash"]
|
|
||||||
plugin._provider_instance = mock_provider
|
|
||||||
|
|
||||||
result = plugin._tool_move("uid123", "NonExistent", "INBOX")
|
|
||||||
|
|
||||||
assert "does not exist" in result.lower()
|
|
||||||
assert "email_create_folder" in result
|
|
||||||
mock_provider.move_message.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_email_move_succeeds_when_folder_exists(tmp_pyra_home):
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
|
|
||||||
mock_provider = MagicMock()
|
|
||||||
mock_provider.list_folders.return_value = ["INBOX", "Work", "Newsletters"]
|
|
||||||
plugin._provider_instance = mock_provider
|
|
||||||
|
|
||||||
result = plugin._tool_move("uid456", "Work", "INBOX")
|
|
||||||
|
|
||||||
assert "moved" in result.lower()
|
|
||||||
mock_provider.move_message.assert_called_once_with("uid456", "INBOX", "Work")
|
|
||||||
|
|
||||||
|
|
||||||
# ── email_list_rules not-supported path ───────────────────────────────────────
|
|
||||||
|
|
||||||
def test_email_list_rules_not_supported(tmp_pyra_home):
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
mock_provider = MagicMock()
|
|
||||||
mock_provider.list_rules.side_effect = NotImplementedError
|
|
||||||
plugin._provider_instance = mock_provider
|
|
||||||
|
|
||||||
result = plugin._tool_list_rules()
|
|
||||||
assert "not supported" in result.lower()
|
|
||||||
|
|
||||||
|
|
||||||
# ── daemon/events integration ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_events_publish_and_subscribe():
|
|
||||||
from pyra.daemon import events
|
|
||||||
events.reset()
|
|
||||||
|
|
||||||
await events.publish({"type": "new_email", "subject": "Test"})
|
|
||||||
|
|
||||||
received = []
|
|
||||||
async for event in events.subscribe_forever():
|
|
||||||
received.append(event)
|
|
||||||
break # only need one
|
|
||||||
|
|
||||||
assert received[0]["type"] == "new_email"
|
|
||||||
assert received[0]["subject"] == "Test"
|
|
||||||
events.reset()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_events_queue_full_drops_silently():
|
|
||||||
from pyra.daemon import events
|
|
||||||
events.reset()
|
|
||||||
|
|
||||||
# Fill the queue
|
|
||||||
for i in range(200):
|
|
||||||
await events.publish({"n": i})
|
|
||||||
|
|
||||||
# This should not raise even though queue is full
|
|
||||||
await events.publish({"n": 999})
|
|
||||||
|
|
||||||
events.reset()
|
|
||||||
|
|
||||||
|
|
||||||
# ── ProtonMail Bridge connectivity check (mocked) ─────────────────────────────
|
|
||||||
|
|
||||||
def test_protonmail_setup_aborts_when_bridge_unreachable(tmp_pyra_home):
|
|
||||||
"""_setup_protonmail should abort gracefully when Bridge is not running."""
|
|
||||||
import socket
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
console = MagicMock()
|
|
||||||
vault_writer = MagicMock()
|
|
||||||
|
|
||||||
with patch("socket.create_connection", side_effect=ConnectionRefusedError):
|
|
||||||
plugin._setup_protonmail(console, vault_writer, "user@proton.me")
|
|
||||||
|
|
||||||
# Should not store any vault key if Bridge is unreachable
|
|
||||||
vault_writer.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ── messaging bot recommendation ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_check_messaging_bot_warns_when_no_bot(tmp_pyra_home):
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from pyra.config.schema import PyraConfig, ProviderConfig, PluginConfig
|
|
||||||
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
console = MagicMock()
|
|
||||||
|
|
||||||
cfg = PyraConfig(ai=ProviderConfig(provider_id="lmstudio", model="test"))
|
|
||||||
cfg.plugins = PluginConfig(enabled=[]) # no bots
|
|
||||||
|
|
||||||
with patch("pyra.bundled_plugins.email.plugin.EmailPlugin._load_settings", return_value={}), \
|
|
||||||
patch("pyra.config.manager.load_config", return_value=cfg):
|
|
||||||
plugin._check_messaging_bot(console)
|
|
||||||
|
|
||||||
# Should have printed something (Panel) recommending a bot
|
|
||||||
console.print.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tool list completeness ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_plugin_exposes_16_tools():
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
# on_load with no-op vault reader
|
|
||||||
plugin.on_load(lambda _: None)
|
|
||||||
tools = plugin.tools()
|
|
||||||
tool_names = [t.name for t in tools]
|
|
||||||
assert len(tools) == 16
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"email_list_folder", "email_read", "email_send", "email_reply",
|
|
||||||
"email_forward", "email_move", "email_delete", "email_mark_read",
|
|
||||||
"email_search", "email_list_folders", "email_create_folder",
|
|
||||||
"email_inbox_summary", "email_list_rules", "email_create_rule",
|
|
||||||
"email_delete_rule", "email_bulk_action",
|
|
||||||
}
|
|
||||||
assert set(tool_names) == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_tools_require_approval():
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
plugin.on_load(lambda _: None)
|
|
||||||
tools = {t.name: t for t in plugin.tools()}
|
|
||||||
|
|
||||||
for name in ["email_send", "email_reply", "email_forward", "email_move",
|
|
||||||
"email_delete", "email_create_folder", "email_create_rule",
|
|
||||||
"email_delete_rule", "email_bulk_action"]:
|
|
||||||
assert tools[name].requires_approval, f"{name} should require approval"
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_tools_no_approval():
|
|
||||||
from pyra.bundled_plugins.email.plugin import EmailPlugin
|
|
||||||
plugin = EmailPlugin()
|
|
||||||
plugin.on_load(lambda _: None)
|
|
||||||
tools = {t.name: t for t in plugin.tools()}
|
|
||||||
|
|
||||||
for name in ["email_list_folder", "email_read", "email_mark_read",
|
|
||||||
"email_search", "email_list_folders", "email_inbox_summary",
|
|
||||||
"email_list_rules"]:
|
|
||||||
assert not tools[name].requires_approval, f"{name} should NOT require approval"
|
|
||||||
@@ -0,0 +1,236 @@
|
|||||||
|
"""Unit tests for the telegram_bot bundled plugin.
|
||||||
|
|
||||||
|
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
|
||||||
|
and tool argument parsing. Handler integration (live Telegram) is not tested.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ── Import helpers straight from the bundled source ──────────────────────────
|
||||||
|
|
||||||
|
_PLUGIN_PATH = (
|
||||||
|
Path(__file__).parent.parent.parent
|
||||||
|
/ "src/pyra/bundled_plugins/telegram_bot/plugin.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_plugin_module():
|
||||||
|
"""Load plugin.py without requiring python-telegram-bot to be installed."""
|
||||||
|
# Provide stub modules so the top-level imports don't fail in unit tests
|
||||||
|
for mod_name in [
|
||||||
|
"bcrypt",
|
||||||
|
"telegram",
|
||||||
|
"telegram.ext",
|
||||||
|
"litellm",
|
||||||
|
]:
|
||||||
|
if mod_name not in sys.modules:
|
||||||
|
stub = MagicMock()
|
||||||
|
sys.modules[mod_name] = stub
|
||||||
|
# telegram.ext sub-attributes needed at import time
|
||||||
|
if mod_name == "telegram.ext":
|
||||||
|
stub.Application = MagicMock()
|
||||||
|
stub.CallbackQueryHandler = MagicMock()
|
||||||
|
stub.CommandHandler = MagicMock()
|
||||||
|
stub.ContextTypes = MagicMock()
|
||||||
|
stub.MessageHandler = MagicMock()
|
||||||
|
stub.filters = MagicMock()
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
|
||||||
|
assert spec and spec.loader
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod) # type: ignore[union-attr]
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
_mod = _load_plugin_module()
|
||||||
|
_RateLimiter = _mod._RateLimiter
|
||||||
|
_load_history = _mod._load_history
|
||||||
|
_save_history = _mod._save_history
|
||||||
|
TelegramBotPlugin = _mod.TelegramBotPlugin
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
def test_allows_up_to_limit(self):
|
||||||
|
rl = _RateLimiter(per_hour=5)
|
||||||
|
for _ in range(5):
|
||||||
|
assert rl.allow(user_id=1)
|
||||||
|
|
||||||
|
def test_blocks_over_limit(self):
|
||||||
|
rl = _RateLimiter(per_hour=3)
|
||||||
|
for _ in range(3):
|
||||||
|
rl.allow(user_id=1)
|
||||||
|
assert not rl.allow(user_id=1)
|
||||||
|
|
||||||
|
def test_independent_per_user(self):
|
||||||
|
rl = _RateLimiter(per_hour=2)
|
||||||
|
rl.allow(1)
|
||||||
|
rl.allow(1)
|
||||||
|
assert not rl.allow(1)
|
||||||
|
assert rl.allow(2)
|
||||||
|
|
||||||
|
def test_old_entries_expire(self):
|
||||||
|
rl = _RateLimiter(per_hour=1)
|
||||||
|
# Manually populate bucket with an old timestamp
|
||||||
|
from collections import deque
|
||||||
|
rl._buckets[99] = deque([time.monotonic() - 3601])
|
||||||
|
assert rl.allow(99) # old entry removed, new one fits
|
||||||
|
|
||||||
|
def test_multiple_users_independent(self):
|
||||||
|
rl = _RateLimiter(per_hour=1)
|
||||||
|
rl.allow(10)
|
||||||
|
assert not rl.allow(10)
|
||||||
|
assert rl.allow(20)
|
||||||
|
assert rl.allow(30)
|
||||||
|
|
||||||
|
|
||||||
|
# ── History DB ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestHistoryDB:
|
||||||
|
def test_empty_history(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
assert _load_history(42) == []
|
||||||
|
|
||||||
|
def test_save_and_load(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hi there"},
|
||||||
|
]
|
||||||
|
_save_history(42, msgs)
|
||||||
|
assert _load_history(42) == msgs
|
||||||
|
|
||||||
|
def test_save_trims_to_max(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
|
||||||
|
msgs = [{"role": "user", "content": str(i)} for i in range(10)]
|
||||||
|
_save_history(1, msgs)
|
||||||
|
loaded = _load_history(1)
|
||||||
|
assert len(loaded) == 3
|
||||||
|
# Should keep the most recent messages
|
||||||
|
assert loaded[-1]["content"] == "9"
|
||||||
|
|
||||||
|
def test_overwrite(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
_save_history(1, [{"role": "user", "content": "first"}])
|
||||||
|
_save_history(1, [{"role": "user", "content": "second"}])
|
||||||
|
assert _load_history(1)[0]["content"] == "second"
|
||||||
|
|
||||||
|
def test_separate_chats(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
_save_history(1, [{"role": "user", "content": "chat1"}])
|
||||||
|
_save_history(2, [{"role": "user", "content": "chat2"}])
|
||||||
|
assert _load_history(1)[0]["content"] == "chat1"
|
||||||
|
assert _load_history(2)[0]["content"] == "chat2"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPluginLifecycle:
|
||||||
|
def test_on_load_stores_vault_reader(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
reader = MagicMock(return_value=None)
|
||||||
|
plugin.on_load(reader)
|
||||||
|
assert plugin._vault_reader is reader
|
||||||
|
|
||||||
|
def test_daemon_tasks_returns_coroutine(self):
|
||||||
|
import inspect
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
plugin.on_load(MagicMock(return_value=None))
|
||||||
|
tasks = plugin.daemon_tasks()
|
||||||
|
assert len(tasks) == 1
|
||||||
|
assert inspect.iscoroutine(tasks[0])
|
||||||
|
tasks[0].close() # prevent "coroutine never awaited" warning
|
||||||
|
|
||||||
|
def test_get_plugin_factory(self):
|
||||||
|
plugin = _mod.get_plugin()
|
||||||
|
assert isinstance(plugin, TelegramBotPlugin)
|
||||||
|
assert plugin.name == "telegram_bot"
|
||||||
|
|
||||||
|
def test_config_fields(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
fields = plugin.config_fields()
|
||||||
|
assert any(f.key == "rate_limit" for f in fields)
|
||||||
|
|
||||||
|
def test_setup_mismatched_passphrase(self, capsys):
|
||||||
|
"""setup() exits cleanly when passphrases don't match."""
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
console = MagicMock()
|
||||||
|
vault_writer = MagicMock()
|
||||||
|
|
||||||
|
answers = iter(["fake-token", "user123", "pass1", "pass2"])
|
||||||
|
with patch("questionary.password", side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(answers))), \
|
||||||
|
patch("questionary.text", return_value=MagicMock(ask=lambda: next(answers))):
|
||||||
|
plugin.setup(console, vault_writer)
|
||||||
|
|
||||||
|
vault_writer.assert_not_called()
|
||||||
|
console.print.assert_called_with("[red]Passphrases do not match.[/red]")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth session state ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestAuthState:
|
||||||
|
def test_session_defaults(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
session = plugin._sessions.setdefault(
|
||||||
|
99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
|
||||||
|
)
|
||||||
|
assert not session["authenticated"]
|
||||||
|
assert not session["awaiting_passphrase"]
|
||||||
|
assert session["attempts"] == 0
|
||||||
|
|
||||||
|
def test_session_per_chat(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
|
||||||
|
plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
|
||||||
|
assert plugin._sessions[1]["authenticated"]
|
||||||
|
assert not plugin._sessions[2]["authenticated"]
|
||||||
|
assert plugin._sessions[2]["attempts"] == 1
|
||||||
|
|
||||||
|
def test_session_auth_flag(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
|
||||||
|
# Simulate successful auth
|
||||||
|
plugin._sessions[5]["authenticated"] = True
|
||||||
|
plugin._sessions[5]["awaiting_passphrase"] = False
|
||||||
|
assert plugin._sessions[5]["authenticated"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool argument parsing ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestToolArgParsing:
|
||||||
|
def test_valid_json_string(self):
|
||||||
|
args_raw = '{"query": "hello"}'
|
||||||
|
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||||
|
assert args == {"query": "hello"}
|
||||||
|
|
||||||
|
def test_dict_passthrough(self):
|
||||||
|
args_raw = {"query": "hello"}
|
||||||
|
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||||
|
assert args == {"query": "hello"}
|
||||||
|
|
||||||
|
def test_invalid_json_caught(self):
|
||||||
|
args_raw = "not json"
|
||||||
|
try:
|
||||||
|
json.loads(args_raw)
|
||||||
|
assert False, "Should have raised"
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_result_truncated_to_4000(self):
|
||||||
|
long_result = "x" * 5000
|
||||||
|
assert len(long_result[:4000]) == 4000
|
||||||
Reference in New Issue
Block a user