Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aba28293b7 | |||
| f59aa1a758 | |||
| 3f30b782d2 | |||
| bde0856979 | |||
| 4744cf819b | |||
| 1d5d0387d9 | |||
| db6ca6ee57 | |||
| cc24257ab0 | |||
| 68f9007ef0 | |||
| c41ad0afc6 | |||
| 3d3ce694b9 | |||
| d42b8b4a47 | |||
| 513871ef96 | |||
| eaed52006f | |||
| 0e052c4992 | |||
| 40aa934431 | |||
| 833d1445f0 | |||
| cb390ad6af | |||
| 01655124b5 | |||
| b3851a2715 | |||
| 019e8044a9 |
@@ -8,7 +8,8 @@ a plugin/integration system (Stage 2+) and an encrypted vault (Stage 3+).
|
|||||||
## Current Status
|
## Current Status
|
||||||
|
|
||||||
**Stage 3 — Memory Database: complete** (2026-05-18)
|
**Stage 3 — Memory Database: complete** (2026-05-18)
|
||||||
Next: Stage 4 — Vault Encryption
|
**Stage 6 — Daemon infrastructure: in progress** (`feat/daemon` branch)
|
||||||
|
Next: Stage 4 — Vault Encryption (skipped for now); messaging bots (Stage 6 remainder)
|
||||||
|
|
||||||
## Project Roadmap
|
## Project Roadmap
|
||||||
|
|
||||||
@@ -19,11 +20,11 @@ memory in `~/.pyra/memory/`, and hard security boundaries around the vault.
|
|||||||
### Stage 2 — Plugin Framework ✅ COMPLETE
|
### Stage 2 — Plugin Framework ✅ COMPLETE
|
||||||
- `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py`
|
- `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py`
|
||||||
- `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra
|
- `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra
|
||||||
- `src/pyra/daemon/` stub (CLI surface only)
|
- `src/pyra/daemon/` stub (CLI surface only; daemon itself is Stage 6)
|
||||||
- Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig`
|
- Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig`
|
||||||
- Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup
|
- Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup
|
||||||
- Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands
|
- Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands
|
||||||
- CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` stubs
|
- CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` (stubs at Stage 2; implemented in Stage 6)
|
||||||
|
|
||||||
### Stage 3 — Memory Database ✅ COMPLETE
|
### Stage 3 — Memory Database ✅ COMPLETE
|
||||||
- `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables
|
- `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables
|
||||||
@@ -99,6 +100,7 @@ the vault under namespaced keys (`plugin:{name}:{key}`).
|
|||||||
| `config/manager.py` | ruamel.yaml round-trip config read/write, chmod 600 enforced |
|
| `config/manager.py` | ruamel.yaml round-trip config read/write, chmod 600 enforced |
|
||||||
| `config/dirs.py` | `bootstrap()` — creates `~/.pyra/` tree, checks vault sentinel every startup |
|
| `config/dirs.py` | `bootstrap()` — creates `~/.pyra/` tree, checks vault sentinel every startup |
|
||||||
| `chat/session.py` | prompt_toolkit REPL loop, AI tool-use loop, plugin slash commands |
|
| `chat/session.py` | prompt_toolkit REPL loop, AI tool-use loop, plugin slash commands |
|
||||||
|
| `chat/planner.py` | `TaskPlanner` — multi-step plan approval loop, per-step AI execution and verification |
|
||||||
| `chat/renderer.py` | Streaming + non-streaming markdown via rich, injection warning panel |
|
| `chat/renderer.py` | Streaming + non-streaming markdown via rich, injection warning panel |
|
||||||
| `chat/history.py` | Conversation list, token budget trimming, tool message support |
|
| `chat/history.py` | Conversation list, token budget trimming, tool message support |
|
||||||
| `memory/database.py` | SQLite+FTS5 — `init_db()`, `upsert()`, `remove()`, `search()`, `list_all()`, `migrate_from_files()` |
|
| `memory/database.py` | SQLite+FTS5 — `init_db()`, `upsert()`, `remove()`, `search()`, `list_all()`, `migrate_from_files()` |
|
||||||
@@ -116,7 +118,11 @@ the vault under namespaced keys (`plugin:{name}:{key}`).
|
|||||||
| `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log |
|
| `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log |
|
||||||
| `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` |
|
| `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` |
|
||||||
| `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) |
|
| `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) |
|
||||||
| `daemon/__init__.py` | Daemon package stub (implementation in Stage 2.4) |
|
| `daemon/pid.py` | Atomic PID file — write, read, stale detection (POSIX + Windows), context manager |
|
||||||
|
| `daemon/ipc.py` | IPC transport — Unix socket chmod 600 + UID-check (Linux/macOS) or TCP loopback + port file (Windows); newline-delimited JSON protocol |
|
||||||
|
| `daemon/service.py` | OS service file generation + install/uninstall — launchd plist (macOS), systemd user unit (Linux), schtasks XML (Windows) |
|
||||||
|
| `daemon/core.py` | asyncio event loop entry point, `PluginSupervisor` (per-task restart, max 10×, 5s back-off, reload), IPC command dispatch, signal handling |
|
||||||
|
| `daemon/__init__.py` | Public daemon API exports |
|
||||||
|
|
||||||
### Runtime: `~/.pyra/`
|
### Runtime: `~/.pyra/`
|
||||||
|
|
||||||
@@ -243,7 +249,7 @@ uv pip install -e ".[all-plugins]" # Everything
|
|||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pytest tests/ -v # all unit + security tests (161 tests)
|
pytest tests/ -v # all unit + security tests
|
||||||
pytest tests/integration/test_lmstudio.py # requires LM Studio at localhost:1234
|
pytest tests/integration/test_lmstudio.py # requires LM Studio at localhost:1234
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -492,6 +498,40 @@ Dataclass: `MemoryFile(name, path, category, size_bytes, modified)`
|
|||||||
| `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` |
|
| `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` |
|
||||||
| `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing |
|
| `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing |
|
||||||
|
|
||||||
|
#### `daemon.core`
|
||||||
|
|
||||||
|
| Function | Signature | Purpose |
|
||||||
|
|----------|-----------|---------|
|
||||||
|
| `run_foreground` | `() -> None` | Entry point for `pyra daemon run` — loads config + plugins, writes PID file, runs asyncio loop |
|
||||||
|
| `start_background` | `() -> None` | Spawns `pyra daemon run` as a detached subprocess (`start_new_session` on POSIX, `DETACHED_PROCESS` on Windows) |
|
||||||
|
|
||||||
|
#### `daemon.pid`
|
||||||
|
|
||||||
|
| Function | Signature | Purpose |
|
||||||
|
|----------|-----------|---------|
|
||||||
|
| `resolve_pid_path` | `(cfg_path: str) -> Path` | Expand `~` and resolve to absolute Path |
|
||||||
|
|
||||||
|
#### `daemon.ipc`
|
||||||
|
|
||||||
|
| Function | Signature | Purpose |
|
||||||
|
|----------|-----------|---------|
|
||||||
|
| `send_command` | `(address, msg, timeout=5.0) -> IpcResponse` | Synchronous CLI helper — `asyncio.run(IpcClient.send(...))` |
|
||||||
|
| `get_socket_path` | `(cfg: str) -> Path` | Expand `~` and return Unix socket path |
|
||||||
|
| `is_unix_socket` | `() -> bool` | True on Linux/macOS (`sys.platform != 'nt'`) |
|
||||||
|
| `get_port_file_path` | `() -> Path` | Path to `~/.pyra/daemon.port` (Windows TCP port file) |
|
||||||
|
|
||||||
|
#### `daemon.service`
|
||||||
|
|
||||||
|
| Function | Signature | Purpose |
|
||||||
|
|----------|-----------|---------|
|
||||||
|
| `detect_platform` | `() -> Literal["macos","linux","windows"]` | Detect current OS |
|
||||||
|
| `find_pyra_executable` | `() -> str` | `shutil.which("pyra")` → sibling fallback → `sys.executable -m pyra` |
|
||||||
|
| `install_service` | `() -> None` | Generate + register OS service (reads config for log/pid paths) |
|
||||||
|
| `uninstall_service` | `() -> None` | Deregister OS service |
|
||||||
|
| `render_launchd_plist` | `(exe, log_file, pid_file) -> str` | macOS plist template |
|
||||||
|
| `render_systemd_unit` | `(exe, log_file) -> str` | Linux systemd unit template |
|
||||||
|
| `render_schtasks_xml` | `(exe) -> str` | Windows Task Scheduler XML template (write as UTF-16) |
|
||||||
|
|
||||||
#### `chat.renderer` — rendering functions and shared `console`
|
#### `chat.renderer` — rendering functions and shared `console`
|
||||||
|
|
||||||
Import `console` from here; do not create a second `rich.Console()` in new code.
|
Import `console` from here; do not create a second `rich.Console()` in new code.
|
||||||
@@ -514,7 +554,7 @@ Import `console` from here; do not create a second `rich.Console()` in new code.
|
|||||||
| `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` |
|
| `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` |
|
||||||
| `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` |
|
| `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` |
|
||||||
| `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` |
|
| `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` |
|
||||||
| `DaemonConfig` | `config.schema` | `daemon:` block |
|
| `DaemonConfig` | `config.schema` | `daemon:` block — `enabled`, `socket_path`, `log_file`, `pid_file`, `ipc_port` |
|
||||||
| `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` |
|
| `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` |
|
||||||
| `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` |
|
| `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` |
|
||||||
| `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget |
|
| `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget |
|
||||||
@@ -524,3 +564,6 @@ Import `console` from here; do not create a second `rich.Console()` in new code.
|
|||||||
| `Tool` | `plugins.base` | Dataclass — `name`, `description`, `parameters` (JSON Schema), `handler`, `requires_approval` |
|
| `Tool` | `plugins.base` | Dataclass — `name`, `description`, `parameters` (JSON Schema), `handler`, `requires_approval` |
|
||||||
| `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface |
|
| `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface |
|
||||||
| `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this |
|
| `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this |
|
||||||
|
| `TaskPlanner` | `chat.planner` | Multi-step plan runner; `make_tool_handler()` returns the callable wired into the chat session; presents plan for user approval, executes each step via litellm with up to 5 tool-use iterations, verifies output before proceeding |
|
||||||
|
| `PluginSupervisor` | `daemon.core` | asyncio supervisor — `add_task(name, factory)`, `start()`, `stop()`, `reload()`, `status()`; restarts crashed tasks up to 10× with 5s back-off |
|
||||||
|
| `PidFile` | `daemon.pid` | `write()` (atomic), `read()`, `is_stale()`, `remove()`, context manager; `PidFileError(OSError)` raised when live PID already exists |
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Pyra
|
# Pyra
|
||||||
|
|
||||||
A personal AI assistant CLI with vault-first security. Combines multi-provider AI chat with
|
A personal AI assistant CLI with vault-first security. Combines multi-provider AI chat,
|
||||||
long-term memory and (coming) automation skills.
|
long-term memory, and an extensible plugin system.
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -31,6 +31,17 @@ pyra chat # start talking
|
|||||||
| `pyra memory read <name>` | Read a memory file |
|
| `pyra memory read <name>` | Read a memory file |
|
||||||
| `pyra memory write <name> <content>` | Write a memory file |
|
| `pyra memory write <name> <content>` | Write a memory file |
|
||||||
| `pyra memory append <name> <content>` | Append to a memory file |
|
| `pyra memory append <name> <content>` | Append to a memory file |
|
||||||
|
| `pyra plugin list` | List installed and available plugins |
|
||||||
|
| `pyra plugin install <name>` | Install a bundled plugin |
|
||||||
|
| `pyra plugin enable <name>` | Enable an installed plugin |
|
||||||
|
| `pyra plugin disable <name>` | Disable a plugin (keeps it installed) |
|
||||||
|
| `pyra plugin setup <name>` | Run a plugin's credential setup wizard |
|
||||||
|
| `pyra daemon start` | Start the background daemon *(Stage 6, not yet implemented)* |
|
||||||
|
| `pyra daemon stop` | Stop the running daemon *(Stage 6, not yet implemented)* |
|
||||||
|
| `pyra daemon status` | Show daemon status *(Stage 6, not yet implemented)* |
|
||||||
|
| `pyra daemon restart` | Restart the daemon *(Stage 6, not yet implemented)* |
|
||||||
|
| `pyra daemon install` | Register Pyra as a system service *(Stage 6, not yet implemented)* |
|
||||||
|
| `pyra daemon uninstall` | Remove the system service *(Stage 6, not yet implemented)* |
|
||||||
|
|
||||||
### In-chat slash commands
|
### In-chat slash commands
|
||||||
|
|
||||||
@@ -38,6 +49,7 @@ pyra chat # start talking
|
|||||||
|---------|-------------|
|
|---------|-------------|
|
||||||
| `/help` | Show available commands |
|
| `/help` | Show available commands |
|
||||||
| `/memory list` | List memory files |
|
| `/memory list` | List memory files |
|
||||||
|
| `/config` | Open the configuration TUI |
|
||||||
| `/clear` | Clear conversation history |
|
| `/clear` | Clear conversation history |
|
||||||
| `/quit` or `/exit` | Exit Pyra |
|
| `/quit` or `/exit` | Exit Pyra |
|
||||||
|
|
||||||
@@ -48,16 +60,41 @@ pyra chat # start talking
|
|||||||
- **Prompt injection scanner** — warns on suspicious AI output, logs to `~/.pyra/security.log`
|
- **Prompt injection scanner** — warns on suspicious AI output, logs to `~/.pyra/security.log`
|
||||||
- **Path sandboxing** — the AI can only reference memory files by name; traversal is blocked
|
- **Path sandboxing** — the AI can only reference memory files by name; traversal is blocked
|
||||||
|
|
||||||
|
## Plugins
|
||||||
|
|
||||||
|
Pyra has an extensible plugin system. Bundled plugins are shipped with Pyra and installed on
|
||||||
|
demand; third-party plugins can be dropped into `~/.pyra/plugins/` directly.
|
||||||
|
|
||||||
|
Each plugin is a directory containing a `manifest.json` and a `plugin.py`. Plugin credentials
|
||||||
|
are stored in the vault under namespaced keys (`plugin:<name>:<key>`) — never in `config.yaml`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pyra plugin list # see what's available
|
||||||
|
pyra plugin install <name> # copy a bundled plugin to ~/.pyra/plugins/
|
||||||
|
pyra plugin setup <name> # enter credentials (stored in vault)
|
||||||
|
pyra plugin enable <name> # activate for the next chat session
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-step Planning
|
||||||
|
|
||||||
|
When given a complex task the AI can propose a **multi-step plan** using the built-in
|
||||||
|
`plan_and_execute` tool. Pyra prints the plan and asks for approval before executing
|
||||||
|
anything. Each step runs as a separate AI call with access to enabled plugin tools; each
|
||||||
|
result is verified before moving on to the next step. You can decline the plan or
|
||||||
|
interrupt at any point.
|
||||||
|
|
||||||
## Memory
|
## Memory
|
||||||
|
|
||||||
Pyra reads your memory files at the start of each session and injects them as context.
|
Pyra reads your memory files at the start of each session and injects them as context.
|
||||||
Files are plain Markdown stored in `~/.pyra/memory/`:
|
Files are plain Markdown stored in `~/.pyra/memory/`, indexed by a SQLite full-text search
|
||||||
|
database (`memory.db`) for fast in-chat lookup.
|
||||||
|
|
||||||
```
|
```
|
||||||
~/.pyra/memory/
|
~/.pyra/memory/
|
||||||
├── user/profile.md ← who you are
|
├── user/profile.md ← who you are
|
||||||
├── context/ ← ongoing projects
|
├── context/ ← ongoing projects
|
||||||
└── knowledge/ ← general notes
|
├── knowledge/ ← general notes
|
||||||
|
└── memory.db ← FTS5 search index (auto-managed)
|
||||||
```
|
```
|
||||||
|
|
||||||
## `~/.pyra/` Directory
|
## `~/.pyra/` Directory
|
||||||
@@ -67,15 +104,25 @@ Files are plain Markdown stored in `~/.pyra/memory/`:
|
|||||||
├── config.yaml ← provider + model (no secrets)
|
├── config.yaml ← provider + model (no secrets)
|
||||||
├── security.log ← injection event log
|
├── security.log ← injection event log
|
||||||
├── memory/ ← AI-readable long-term memory
|
├── memory/ ← AI-readable long-term memory
|
||||||
├── skills/ ← automation scripts (Stage 2)
|
│ └── memory.db ← SQLite FTS5 search index
|
||||||
|
├── plugins/ ← installed plugins
|
||||||
|
│ └── <name>/
|
||||||
|
│ ├── manifest.json
|
||||||
|
│ └── plugin.py
|
||||||
|
├── logs/ ← execution logs
|
||||||
|
│ ├── tool_executions.log
|
||||||
|
│ └── plugin_errors.log
|
||||||
└── vault/ ← secure, AI-inaccessible storage
|
└── vault/ ← secure, AI-inaccessible storage
|
||||||
└── secrets/api_keys.json
|
└── secrets/api_keys.json
|
||||||
```
|
```
|
||||||
|
|
||||||
## Roadmap
|
## Roadmap
|
||||||
|
|
||||||
- **Stage 1** (now): Core CLI, multi-provider chat, memory, vault security
|
- **Stage 1** ✅ Core CLI — multi-provider chat, memory, vault security
|
||||||
- **Stage 2**: Skills — shell/PowerShell/Python automations with user approval gates
|
- **Stage 2** ✅ Plugin Framework — extensible tools, slash commands, approval gates
|
||||||
- **Stage 3**: Vault encryption with `age`
|
- **Stage 3** ✅ Memory Database — SQLite + FTS5 full-text search index
|
||||||
- **Stage 4**: Security audit sub-agent
|
- **Stage 4** Vault Encryption — `age`-based encryption of `~/.pyra/vault/secrets/`
|
||||||
- **Stage 5**: Web UI, embedding-based memory search
|
- **Stage 5** Skills System — YAML-defined multi-plugin workflows with event triggers
|
||||||
|
- **Stage 6** Daemon + Messaging Bots — always-on asyncio daemon, Matrix/Telegram/Signal bots
|
||||||
|
- **Stage 7** Security Audit Sub-agent — automated scanning for injection, CVEs, permission drift
|
||||||
|
- **Stage 8** Web UI — optional local interface, embedding-based memory search
|
||||||
|
|||||||
+2
-2
@@ -28,7 +28,7 @@ dev = [
|
|||||||
]
|
]
|
||||||
nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"]
|
nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"]
|
||||||
matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"]
|
matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"]
|
||||||
telegram = ["python-telegram-bot>=21.0"]
|
telegram = ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
|
||||||
ssh = ["paramiko>=3.4.0"]
|
ssh = ["paramiko>=3.4.0"]
|
||||||
docker = ["docker>=7.0.0"]
|
docker = ["docker>=7.0.0"]
|
||||||
gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
|
gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
|
||||||
@@ -38,7 +38,7 @@ daemon = ["aiofiles>=23.0.0"]
|
|||||||
all-plugins = [
|
all-plugins = [
|
||||||
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
|
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
|
||||||
"matrix-nio>=0.24.0", "aiofiles>=23.0.0",
|
"matrix-nio>=0.24.0", "aiofiles>=23.0.0",
|
||||||
"python-telegram-bot>=21.0",
|
"python-telegram-bot>=21.0", "bcrypt>=4.0.0",
|
||||||
"paramiko>=3.4.0",
|
"paramiko>=3.4.0",
|
||||||
"docker>=7.0.0",
|
"docker>=7.0.0",
|
||||||
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
|
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"name": "telegram_bot",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "Remote Pyra chat over Telegram — full AI chat with tool approval via inline buttons",
|
||||||
|
"author": "pyra",
|
||||||
|
"requires": ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
|
||||||
|
}
|
||||||
@@ -0,0 +1,665 @@
|
|||||||
|
"""Telegram bot plugin — remote Pyra chat over Telegram.
|
||||||
|
|
||||||
|
Runs as a daemon task (long-polling). Each chat session requires passphrase
|
||||||
|
authentication and is rate-limited to 20 messages/hour. Incoming messages are
|
||||||
|
injection-scanned before reaching the AI. Tool calls are approved via Telegram
|
||||||
|
inline keyboard buttons (2-minute timeout). AI responses are streamed
|
||||||
|
progressively by editing a placeholder message.
|
||||||
|
|
||||||
|
Vault keys used:
|
||||||
|
plugin:telegram_bot:token — bot token from @BotFather
|
||||||
|
plugin:telegram_bot:allowed_users — comma-separated Telegram user IDs
|
||||||
|
plugin:telegram_bot:passphrase_hash — bcrypt hash of the session passphrase
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
import litellm
|
||||||
|
from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup, Update
|
||||||
|
from telegram.ext import (
|
||||||
|
Application,
|
||||||
|
CallbackQueryHandler,
|
||||||
|
CommandHandler,
|
||||||
|
ContextTypes,
|
||||||
|
MessageHandler,
|
||||||
|
filters,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pyra.plugins.base import BasePlugin, ConfigField
|
||||||
|
from pyra.utils.paths import pyra_home
|
||||||
|
|
||||||
|
_log = logging.getLogger("pyra.plugin.telegram_bot")
|
||||||
|
|
||||||
|
_HISTORY_DB = pyra_home() / "telegram_history.db"
|
||||||
|
_MAX_HISTORY = 40 # messages kept per chat
|
||||||
|
_RATE_LIMIT = 20 # messages per hour per user
|
||||||
|
_APPROVAL_TIMEOUT = 120 # seconds to wait for inline button press
|
||||||
|
_EDIT_INTERVAL = 1.5 # minimum seconds between progressive message edits
|
||||||
|
_MAX_TOOL_ITER = 10
|
||||||
|
_MAX_MSG_LEN = 4096 # Telegram hard limit
|
||||||
|
|
||||||
|
|
||||||
|
# ── SQLite history ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _open_db() -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(_HISTORY_DB)
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
chat_id INTEGER PRIMARY KEY,
|
||||||
|
history TEXT NOT NULL DEFAULT '[]',
|
||||||
|
updated REAL NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
os.chmod(_HISTORY_DB, 0o600)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def _load_history(chat_id: int) -> list[dict]:
|
||||||
|
conn = _open_db()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT history FROM sessions WHERE chat_id = ?", (chat_id,)
|
||||||
|
).fetchone()
|
||||||
|
conn.close()
|
||||||
|
return json.loads(row[0]) if row else []
|
||||||
|
|
||||||
|
|
||||||
|
def _save_history(chat_id: int, messages: list[dict]) -> None:
|
||||||
|
trimmed = messages[-_MAX_HISTORY:]
|
||||||
|
conn = _open_db()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO sessions (chat_id, history, updated) VALUES (?,?,?)",
|
||||||
|
(chat_id, json.dumps(trimmed), time.time()),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _RateLimiter:
|
||||||
|
def __init__(self, per_hour: int = _RATE_LIMIT) -> None:
|
||||||
|
self._buckets: dict[int, deque] = {}
|
||||||
|
self._limit = per_hour
|
||||||
|
|
||||||
|
def allow(self, user_id: int) -> bool:
|
||||||
|
now = time.monotonic()
|
||||||
|
bucket = self._buckets.setdefault(user_id, deque())
|
||||||
|
cutoff = now - 3600
|
||||||
|
while bucket and bucket[0] < cutoff:
|
||||||
|
bucket.popleft()
|
||||||
|
if len(bucket) >= self._limit:
|
||||||
|
return False
|
||||||
|
bucket.append(now)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TelegramBotPlugin(BasePlugin):
|
||||||
|
name = "telegram_bot"
|
||||||
|
description = "Remote Pyra chat over Telegram (daemon task, long-polling)"
|
||||||
|
version = "1.0.0"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._vault_reader: Callable[[str], str | None] | None = None
|
||||||
|
self._rate_limiter = _RateLimiter()
|
||||||
|
# chat_id -> {authenticated, awaiting_passphrase, attempts}
|
||||||
|
self._sessions: dict[int, dict] = {}
|
||||||
|
# short call_id -> asyncio.Future[bool]
|
||||||
|
self._pending_approvals: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
# ── Plugin lifecycle ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def on_load(self, vault_reader: Callable[[str], str | None]) -> None:
|
||||||
|
self._vault_reader = vault_reader
|
||||||
|
|
||||||
|
def setup(self, console: Any, vault_writer: Callable[[str, str], None]) -> None:
|
||||||
|
import questionary
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.rule import Rule
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(Panel(
|
||||||
|
"[bold]Telegram Bot Setup Wizard[/bold]\n\n"
|
||||||
|
"This wizard connects Pyra to Telegram so you can chat with your\n"
|
||||||
|
"assistant from anywhere. You will need Telegram open on your phone\n"
|
||||||
|
"or desktop to complete the next steps.",
|
||||||
|
border_style="cyan",
|
||||||
|
))
|
||||||
|
|
||||||
|
# ── Step 1: Create bot ────────────────────────────────────────────────
|
||||||
|
console.print()
|
||||||
|
console.print(Rule("[bold cyan]Step 1 / 5[/bold cyan] Create your Telegram bot"))
|
||||||
|
console.print()
|
||||||
|
console.print(
|
||||||
|
" 1. Open Telegram and search for [bold]@BotFather[/bold]\n"
|
||||||
|
" 2. Send [bold]/newbot[/bold] and follow the prompts\n"
|
||||||
|
" 3. Choose a display name (e.g. [dim]My Pyra Assistant[/dim])\n"
|
||||||
|
" 4. Choose a username ending in [bold]bot[/bold] "
|
||||||
|
"(e.g. [dim]my_pyra_bot[/dim])\n"
|
||||||
|
" 5. BotFather replies with a token that looks like:\n"
|
||||||
|
" [dim]123456789:AABBccDDeeFFggHHiiJJkkLL[/dim]"
|
||||||
|
)
|
||||||
|
console.print()
|
||||||
|
questionary.press_any_key_to_continue(
|
||||||
|
" Press any key when you have your token ready ..."
|
||||||
|
).ask()
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
token = questionary.password(" Bot token:").ask()
|
||||||
|
if not token or not token.strip():
|
||||||
|
console.print("[dim]Setup cancelled.[/dim]")
|
||||||
|
return
|
||||||
|
token = token.strip()
|
||||||
|
|
||||||
|
# ── Step 2: Find user ID ──────────────────────────────────────────────
|
||||||
|
console.print()
|
||||||
|
console.print(Rule("[bold cyan]Step 2 / 5[/bold cyan] Find your Telegram user ID"))
|
||||||
|
console.print()
|
||||||
|
console.print(
|
||||||
|
" Your user ID is a permanent number that identifies your account.\n"
|
||||||
|
" It never changes, even if you change your username.\n\n"
|
||||||
|
" 1. Search for [bold]@userinfobot[/bold] in Telegram\n"
|
||||||
|
" 2. Send any message (e.g. [dim]/start[/dim])\n"
|
||||||
|
" 3. Copy the [bold]Id:[/bold] number from the reply "
|
||||||
|
"(e.g. [dim]123456789[/dim])"
|
||||||
|
)
|
||||||
|
console.print()
|
||||||
|
questionary.press_any_key_to_continue(
|
||||||
|
" Press any key when you have your user ID ready ..."
|
||||||
|
).ask()
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
allowed = questionary.text(
|
||||||
|
" Allowed Telegram user IDs (comma-separated, leave blank to allow anyone):"
|
||||||
|
).ask()
|
||||||
|
if allowed is None:
|
||||||
|
console.print("[dim]Setup cancelled.[/dim]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Step 3: Session passphrase ────────────────────────────────────────
|
||||||
|
console.print()
|
||||||
|
console.print(Rule("[bold cyan]Step 3 / 5[/bold cyan] Set a session passphrase"))
|
||||||
|
console.print()
|
||||||
|
console.print(
|
||||||
|
" The passphrase is an extra layer of security. Every new chat\n"
|
||||||
|
" session must pass this challenge before Pyra responds — even\n"
|
||||||
|
" if someone else gains access to your Telegram account."
|
||||||
|
)
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
passphrase = questionary.password(" Session passphrase:").ask()
|
||||||
|
if not passphrase:
|
||||||
|
console.print("[dim]Setup cancelled.[/dim]")
|
||||||
|
return
|
||||||
|
|
||||||
|
confirm = questionary.password(" Confirm passphrase:").ask()
|
||||||
|
if passphrase != confirm:
|
||||||
|
console.print("[red]Passphrases do not match. Run setup again to retry.[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Step 4: Save to vault ─────────────────────────────────────────────
|
||||||
|
console.print()
|
||||||
|
console.print(Rule("[bold cyan]Step 4 / 5[/bold cyan] Saving configuration"))
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
pw_hash = bcrypt.hashpw(passphrase.encode(), bcrypt.gensalt()).decode()
|
||||||
|
vault_writer("plugin:telegram_bot:token", token)
|
||||||
|
vault_writer("plugin:telegram_bot:allowed_users", (allowed or "").strip())
|
||||||
|
vault_writer("plugin:telegram_bot:passphrase_hash", pw_hash)
|
||||||
|
|
||||||
|
allowed_display = (allowed or "").strip() or "[dim](any user — consider restricting)[/dim]"
|
||||||
|
console.print(f" [green]✓[/green] Bot token stored in vault")
|
||||||
|
console.print(f" [green]✓[/green] Allowed users: {allowed_display}")
|
||||||
|
console.print(f" [green]✓[/green] Passphrase stored as bcrypt hash")
|
||||||
|
|
||||||
|
# ── Step 5: Done ──────────────────────────────────────────────────────
|
||||||
|
console.print()
|
||||||
|
console.print(Rule("[bold cyan]Step 5 / 5[/bold cyan] Configuration complete"))
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
def config_fields(self) -> list[ConfigField]:
|
||||||
|
return [
|
||||||
|
ConfigField(
|
||||||
|
"rate_limit",
|
||||||
|
"Rate limit (messages/hour)",
|
||||||
|
"text",
|
||||||
|
str(_RATE_LIMIT),
|
||||||
|
description="Maximum messages per hour per Telegram user",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def daemon_tasks(self) -> list:
|
||||||
|
return [self._run_polling()]
|
||||||
|
|
||||||
|
# ── Daemon task ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_polling(self) -> None:
|
||||||
|
assert self._vault_reader is not None
|
||||||
|
|
||||||
|
token = self._vault_reader("plugin:telegram_bot:token")
|
||||||
|
if not token:
|
||||||
|
_log.error(
|
||||||
|
"Telegram bot token not set. Run `pyra plugin setup telegram_bot`."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
passphrase_hash = self._vault_reader("plugin:telegram_bot:passphrase_hash") or ""
|
||||||
|
allowed_str = self._vault_reader("plugin:telegram_bot:allowed_users") or ""
|
||||||
|
allowed_users: set[int] = {
|
||||||
|
int(uid.strip())
|
||||||
|
for uid in allowed_str.split(",")
|
||||||
|
if uid.strip().isdigit()
|
||||||
|
}
|
||||||
|
|
||||||
|
app = Application.builder().token(token).build()
|
||||||
|
plugin = self # closure reference
|
||||||
|
|
||||||
|
async def _on_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
if update.effective_user and update.effective_user.id in allowed_users:
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Pyra is online. Send any message to authenticate."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _on_message(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
await plugin._handle_message(update, ctx, allowed_users, passphrase_hash)
|
||||||
|
|
||||||
|
async def _on_approval(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
await plugin._handle_approval_callback(update)
|
||||||
|
|
||||||
|
app.add_handler(CommandHandler("start", _on_start))
|
||||||
|
app.add_handler(CallbackQueryHandler(_on_approval))
|
||||||
|
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, _on_message))
|
||||||
|
|
||||||
|
_log.info("Telegram bot starting (long-polling).")
|
||||||
|
await app.initialize()
|
||||||
|
try:
|
||||||
|
await app.start()
|
||||||
|
await app.updater.start_polling(drop_pending_updates=True)
|
||||||
|
_log.info("Telegram bot is polling for updates.")
|
||||||
|
await asyncio.Event().wait() # block until CancelledError
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_log.info("Telegram bot shutting down.")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await app.updater.stop()
|
||||||
|
await app.stop()
|
||||||
|
await app.shutdown()
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("Error during Telegram bot shutdown: %s", exc)
|
||||||
|
|
||||||
|
# ── Message handler ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle_message(
|
||||||
|
self,
|
||||||
|
update: Update,
|
||||||
|
ctx: ContextTypes.DEFAULT_TYPE,
|
||||||
|
allowed_users: set[int],
|
||||||
|
passphrase_hash: str,
|
||||||
|
) -> None:
|
||||||
|
if update.effective_user is None or update.message is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
user_id = update.effective_user.id
|
||||||
|
chat_id = update.effective_chat.id if update.effective_chat else user_id
|
||||||
|
|
||||||
|
# Allowlist — silently ignore unknown senders
|
||||||
|
if allowed_users and user_id not in allowed_users:
|
||||||
|
return
|
||||||
|
|
||||||
|
text = (update.message.text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
|
||||||
|
session = self._sessions.setdefault(
|
||||||
|
chat_id,
|
||||||
|
{"authenticated": False, "awaiting_passphrase": False, "attempts": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Passphrase authentication ─────────────────────────────────────────
|
||||||
|
if not session["authenticated"]:
|
||||||
|
if not session["awaiting_passphrase"]:
|
||||||
|
session["awaiting_passphrase"] = True
|
||||||
|
session["attempts"] = 0
|
||||||
|
await update.message.reply_text("Enter your passphrase to continue:")
|
||||||
|
return
|
||||||
|
|
||||||
|
if passphrase_hash and bcrypt.checkpw(text.encode(), passphrase_hash.encode()):
|
||||||
|
session["authenticated"] = True
|
||||||
|
session["awaiting_passphrase"] = False
|
||||||
|
session["attempts"] = 0
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Authenticated. How can I help you?\n"
|
||||||
|
"Send /start at any time to check bot status."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
session["attempts"] += 1
|
||||||
|
remaining = 3 - session["attempts"]
|
||||||
|
if remaining <= 0:
|
||||||
|
session["awaiting_passphrase"] = False
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Too many failed attempts. Send any message to try again."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await update.message.reply_text(
|
||||||
|
f"Wrong passphrase. {remaining} attempt(s) left."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Rate limit ────────────────────────────────────────────────────────
|
||||||
|
if not self._rate_limiter.allow(user_id):
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Rate limit reached (20 messages/hour). Try again later."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Injection scan ────────────────────────────────────────────────────
|
||||||
|
from pyra.security.injection import scan_response
|
||||||
|
|
||||||
|
warnings = scan_response(text)
|
||||||
|
if warnings:
|
||||||
|
labels = ", ".join(w.pattern_label for w in warnings)
|
||||||
|
_log.warning("Injection in Telegram message (user %d): %s", user_id, labels)
|
||||||
|
await update.message.reply_text(
|
||||||
|
"Your message was blocked: injection pattern detected."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Load history + system context ─────────────────────────────────────
|
||||||
|
history = _load_history(chat_id)
|
||||||
|
if not history:
|
||||||
|
try:
|
||||||
|
from pyra.memory.reader import load_context_for_session
|
||||||
|
|
||||||
|
ctx_text = load_context_for_session()
|
||||||
|
if ctx_text:
|
||||||
|
history = [{"role": "system", "content": ctx_text}]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
history.append({"role": "user", "content": text})
|
||||||
|
|
||||||
|
try:
|
||||||
|
await ctx.bot.send_chat_action(chat_id=chat_id, action="typing")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── AI response ───────────────────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
reply = await self._ai_chat(chat_id, history, ctx.bot)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.error("AI error (chat %d): %s", chat_id, exc, exc_info=True)
|
||||||
|
await ctx.bot.send_message(chat_id=chat_id, text=f"AI error: {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
history.append({"role": "assistant", "content": reply})
|
||||||
|
_save_history(chat_id, history)
|
||||||
|
|
||||||
|
# ── AI streaming + tool-use loop ──────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _ai_chat(self, chat_id: int, messages: list[dict], bot: Bot) -> str:
|
||||||
|
from pyra.config.manager import load_config
|
||||||
|
from pyra.plugins.registry import PluginRegistry
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
from pyra.vault.reader import get_key
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
provider = get_provider(cfg.ai.provider_id)
|
||||||
|
api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local"
|
||||||
|
|
||||||
|
call_kwargs: dict[str, Any] = {
|
||||||
|
"model": f"{provider.litellm_prefix}{cfg.ai.model}",
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
base_url = cfg.ai.base_url or provider.base_url
|
||||||
|
if base_url:
|
||||||
|
call_kwargs["api_base"] = base_url
|
||||||
|
|
||||||
|
litellm.suppress_debug_info = True
|
||||||
|
|
||||||
|
registry = PluginRegistry.instance()
|
||||||
|
tools_spec = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": t.name,
|
||||||
|
"description": t.description,
|
||||||
|
"parameters": t.parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for t in registry.get_all_tools()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mutable state shared with helpers below
|
||||||
|
state: dict[str, Any] = {"msg_id": None, "last_edit": 0.0}
|
||||||
|
|
||||||
|
placeholder = await bot.send_message(chat_id=chat_id, text="…")
|
||||||
|
state["msg_id"] = placeholder.message_id
|
||||||
|
|
||||||
|
async def _update(text: str) -> None:
|
||||||
|
if not state["msg_id"]:
|
||||||
|
return
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - state["last_edit"] < _EDIT_INTERVAL:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=state["msg_id"],
|
||||||
|
text=text[:_MAX_MSG_LEN],
|
||||||
|
)
|
||||||
|
state["last_edit"] = now
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _finalize(text: str) -> None:
|
||||||
|
if not state["msg_id"]:
|
||||||
|
return
|
||||||
|
if text:
|
||||||
|
try:
|
||||||
|
await bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=state["msg_id"],
|
||||||
|
text=text[:_MAX_MSG_LEN],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state["msg_id"] = None
|
||||||
|
|
||||||
|
accumulated = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _iter in range(_MAX_TOOL_ITER):
|
||||||
|
tool_chunks: dict[int, dict] = {}
|
||||||
|
accumulated = ""
|
||||||
|
|
||||||
|
stream = await litellm.acompletion(
|
||||||
|
**call_kwargs,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools_spec if tools_spec else None,
|
||||||
|
tool_choice="auto" if tools_spec else None,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta.content:
|
||||||
|
accumulated += delta.content
|
||||||
|
await _update(accumulated)
|
||||||
|
|
||||||
|
if delta.tool_calls:
|
||||||
|
for tc in delta.tool_calls:
|
||||||
|
idx = tc.index
|
||||||
|
if idx not in tool_chunks:
|
||||||
|
tool_chunks[idx] = {"id": tc.id or "", "name": "", "args": ""}
|
||||||
|
if tc.function:
|
||||||
|
if tc.function.name:
|
||||||
|
tool_chunks[idx]["name"] += tc.function.name
|
||||||
|
if tc.function.arguments:
|
||||||
|
tool_chunks[idx]["args"] += tc.function.arguments
|
||||||
|
|
||||||
|
if not tool_chunks:
|
||||||
|
await _finalize(accumulated)
|
||||||
|
return accumulated
|
||||||
|
|
||||||
|
# Show any intermediate prose before tool calls
|
||||||
|
if accumulated:
|
||||||
|
await _finalize(accumulated)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state["msg_id"] = None
|
||||||
|
|
||||||
|
tool_calls_list = [
|
||||||
|
{
|
||||||
|
"id": data["id"],
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": data["name"], "arguments": data["args"]},
|
||||||
|
}
|
||||||
|
for _, data in sorted(tool_chunks.items())
|
||||||
|
]
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": accumulated or None,
|
||||||
|
"tool_calls": tool_calls_list,
|
||||||
|
})
|
||||||
|
|
||||||
|
for tc in tool_calls_list:
|
||||||
|
result = await self._execute_tool_with_approval(tc, chat_id, bot)
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tc["id"],
|
||||||
|
"content": result,
|
||||||
|
})
|
||||||
|
|
||||||
|
# New placeholder for the next AI response
|
||||||
|
new_ph = await bot.send_message(chat_id=chat_id, text="…")
|
||||||
|
state["msg_id"] = new_ph.message_id
|
||||||
|
state["last_edit"] = 0.0
|
||||||
|
|
||||||
|
except litellm.BadRequestError:
|
||||||
|
# Provider doesn't support tool calls — retry without tools
|
||||||
|
accumulated = ""
|
||||||
|
state["last_edit"] = 0.0
|
||||||
|
stream = await litellm.acompletion(
|
||||||
|
**call_kwargs, messages=messages, stream=True
|
||||||
|
)
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
accumulated += chunk.choices[0].delta.content
|
||||||
|
await _update(accumulated)
|
||||||
|
await _finalize(accumulated)
|
||||||
|
return accumulated
|
||||||
|
|
||||||
|
return accumulated or "Error: tool-use loop exceeded maximum iterations."
|
||||||
|
|
||||||
|
# ── Tool approval via inline buttons ──────────────────────────────────────
|
||||||
|
|
||||||
|
async def _execute_tool_with_approval(
|
||||||
|
self, tool_call: dict, chat_id: int, bot: Bot
|
||||||
|
) -> str:
|
||||||
|
from pyra.plugins.registry import PluginRegistry
|
||||||
|
from pyra.security.injection import scan_response
|
||||||
|
|
||||||
|
tool_name = tool_call["function"]["name"]
|
||||||
|
args_raw = tool_call["function"]["arguments"]
|
||||||
|
try:
|
||||||
|
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return f"Error: invalid tool arguments for {tool_name}"
|
||||||
|
|
||||||
|
args_preview = json.dumps(args, indent=2)[:500]
|
||||||
|
call_id = uuid.uuid4().hex[:8]
|
||||||
|
|
||||||
|
keyboard = InlineKeyboardMarkup([[
|
||||||
|
InlineKeyboardButton("✅ Approve", callback_data=f"approve:{call_id}"),
|
||||||
|
InlineKeyboardButton("❌ Deny", callback_data=f"deny:{call_id}"),
|
||||||
|
]])
|
||||||
|
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"Tool request: {tool_name}\n\n{args_preview}",
|
||||||
|
reply_markup=keyboard,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future: asyncio.Future[bool] = loop.create_future()
|
||||||
|
self._pending_approvals[call_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_approvals.pop(call_id, None)
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chat_id, text=f"Tool {tool_name}: approval timed out — denied."
|
||||||
|
)
|
||||||
|
return "Tool execution denied (timeout)."
|
||||||
|
|
||||||
|
if not approved:
|
||||||
|
return "Tool execution denied by user."
|
||||||
|
|
||||||
|
registry = PluginRegistry.instance()
|
||||||
|
tool = registry.find_tool(tool_name)
|
||||||
|
if tool is None:
|
||||||
|
return f"Error: tool '{tool_name}' not found in registry."
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = tool.handler(**args)
|
||||||
|
if not isinstance(result, str):
|
||||||
|
result = str(result)
|
||||||
|
except Exception as exc:
|
||||||
|
return f"Tool error: {exc}"
|
||||||
|
|
||||||
|
injection_warnings = scan_response(result)
|
||||||
|
if injection_warnings:
|
||||||
|
labels = ", ".join(w.pattern_label for w in injection_warnings)
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"Warning: tool result contains suspicious content ({labels}).",
|
||||||
|
)
|
||||||
|
|
||||||
|
return result[:4000]
|
||||||
|
|
||||||
|
# ── Approval callback ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle_approval_callback(self, update: Update) -> None:
|
||||||
|
query = update.callback_query
|
||||||
|
if query is None:
|
||||||
|
return
|
||||||
|
await query.answer()
|
||||||
|
data = query.data or ""
|
||||||
|
if ":" not in data:
|
||||||
|
return
|
||||||
|
action, call_id = data.split(":", 1)
|
||||||
|
future = self._pending_approvals.pop(call_id, None)
|
||||||
|
if future and not future.done():
|
||||||
|
future.set_result(action == "approve")
|
||||||
|
try:
|
||||||
|
await query.edit_message_reply_markup(reply_markup=None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_plugin() -> TelegramBotPlugin:
|
||||||
|
return TelegramBotPlugin()
|
||||||
@@ -25,6 +25,7 @@ from pyra.plugins.executor import ToolExecutor
|
|||||||
from pyra.plugins.registry import PluginRegistry
|
from pyra.plugins.registry import PluginRegistry
|
||||||
from pyra.security.injection import scan_response
|
from pyra.security.injection import scan_response
|
||||||
from pyra.setup.providers import get_provider
|
from pyra.setup.providers import get_provider
|
||||||
|
from pyra.setup.wizard import fetch_loaded_models
|
||||||
from pyra.utils.paths import pyra_home
|
from pyra.utils.paths import pyra_home
|
||||||
|
|
||||||
_HISTORY_FILE = pyra_home() / ".chat_history"
|
_HISTORY_FILE = pyra_home() / ".chat_history"
|
||||||
@@ -176,6 +177,16 @@ def start_chat() -> None:
|
|||||||
"[dim]Type /help for commands, /quit to exit.[/dim]"
|
"[dim]Type /help for commands, /quit to exit.[/dim]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if provider.group == "Local":
|
||||||
|
loaded = fetch_loaded_models(provider)
|
||||||
|
if not loaded:
|
||||||
|
render_info(f"No model currently loaded in {provider.display_name}.")
|
||||||
|
elif cfg.ai.model not in loaded:
|
||||||
|
render_info(
|
||||||
|
f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
|
||||||
|
f"Loaded: {', '.join(loaded)}"
|
||||||
|
)
|
||||||
|
|
||||||
_flags: dict = {"use_tools": True}
|
_flags: dict = {"use_tools": True}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
+158
-13
@@ -165,6 +165,7 @@ def plugin_list() -> None:
|
|||||||
@click.argument("name")
|
@click.argument("name")
|
||||||
def plugin_install(name: str) -> None:
|
def plugin_install(name: str) -> None:
|
||||||
"""Install a bundled plugin to ~/.pyra/plugins/."""
|
"""Install a bundled plugin to ~/.pyra/plugins/."""
|
||||||
|
import questionary
|
||||||
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
|
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
|
||||||
from pyra.utils.paths import pyra_home
|
from pyra.utils.paths import pyra_home
|
||||||
|
|
||||||
@@ -173,12 +174,55 @@ def plugin_install(name: str) -> None:
|
|||||||
try:
|
try:
|
||||||
install_bundled_plugin(name, bundled_dir, plugins_dir)
|
install_bundled_plugin(name, bundled_dir, plugins_dir)
|
||||||
console.print(f"[green]Installed:[/green] {name}")
|
console.print(f"[green]Installed:[/green] {name}")
|
||||||
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
|
|
||||||
console.print(f" Confirm: [dim]pyra plugin setup {name}[/dim]")
|
|
||||||
except FileNotFoundError as exc:
|
except FileNotFoundError as exc:
|
||||||
console.print(f"[red]Error:[/red] {exc}")
|
console.print(f"[red]Error:[/red] {exc}")
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
console.print(f"[red]Install failed:[/red] {exc}")
|
console.print(f"[red]Install failed:[/red] {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
configure_now = questionary.confirm(f"Configure {name} now?", default=True).ask()
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
return
|
||||||
|
if not configure_now:
|
||||||
|
console.print(f" Configure later: [dim]pyra plugin setup {name}[/dim]")
|
||||||
|
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
|
||||||
|
return
|
||||||
|
|
||||||
|
from pyra.plugins.loader import load_plugin_by_name
|
||||||
|
from pyra.vault.writer import set_key as _set_key
|
||||||
|
|
||||||
|
p = load_plugin_by_name(name, plugins_dir)
|
||||||
|
if p is None:
|
||||||
|
console.print(f"[red]Could not load {name} for setup.[/red]")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
p.setup(console, _set_key)
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
console.print("[dim]Setup cancelled.[/dim]")
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[red]Setup error:[/red] {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
enable_now = questionary.confirm(f"Enable {name} now?", default=True).ask()
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
enable_now = False
|
||||||
|
if enable_now:
|
||||||
|
from pyra.config.manager import load_config, save_config
|
||||||
|
try:
|
||||||
|
cfg = load_config()
|
||||||
|
if name not in cfg.plugins.enabled:
|
||||||
|
cfg.plugins.enabled.append(name)
|
||||||
|
save_config(cfg)
|
||||||
|
console.print(f"[green]Enabled:[/green] {name}")
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[yellow]Could not enable automatically:[/yellow] {exc}")
|
||||||
|
console.print(f" Enable manually: [dim]pyra plugin enable {name}[/dim]")
|
||||||
|
|
||||||
|
console.print(f"[dim]Run [bold]pyra daemon start[/bold] to bring {name} online.[/dim]")
|
||||||
|
|
||||||
|
|
||||||
@plugin.command("enable")
|
@plugin.command("enable")
|
||||||
@@ -266,43 +310,144 @@ def daemon() -> None:
|
|||||||
_bootstrap_or_exit()
|
_bootstrap_or_exit()
|
||||||
|
|
||||||
|
|
||||||
|
@daemon.command("run", hidden=True)
|
||||||
|
def daemon_run() -> None:
|
||||||
|
"""Run daemon in foreground (used by service manager)."""
|
||||||
|
from pyra.daemon.core import run_foreground
|
||||||
|
run_foreground()
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("start")
|
@daemon.command("start")
|
||||||
def daemon_start() -> None:
|
def daemon_start() -> None:
|
||||||
"""Start the Pyra daemon in the background."""
|
"""Start the Pyra daemon in the background."""
|
||||||
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]")
|
from pyra.daemon.core import start_background
|
||||||
|
try:
|
||||||
|
start_background()
|
||||||
|
except FileNotFoundError:
|
||||||
|
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("stop")
|
@daemon.command("stop")
|
||||||
def daemon_stop() -> None:
|
def daemon_stop() -> None:
|
||||||
"""Stop the running Pyra daemon."""
|
"""Stop the running Pyra daemon."""
|
||||||
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]")
|
_daemon_ipc("stop", success_msg="Daemon stopped.")
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("status")
|
@daemon.command("status")
|
||||||
def daemon_status() -> None:
|
def daemon_status() -> None:
|
||||||
"""Show daemon status."""
|
"""Show daemon status."""
|
||||||
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]")
|
_daemon_ipc("status")
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("restart")
|
@daemon.command("restart")
|
||||||
def daemon_restart() -> None:
|
def daemon_restart() -> None:
|
||||||
"""Restart the Pyra daemon."""
|
"""Restart the Pyra daemon."""
|
||||||
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]")
|
import time
|
||||||
|
from pyra.daemon.core import start_background
|
||||||
|
_daemon_ipc("stop", success_msg=None)
|
||||||
|
time.sleep(1.5)
|
||||||
|
try:
|
||||||
|
start_background()
|
||||||
|
except FileNotFoundError:
|
||||||
|
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("install")
|
@daemon.command("install")
|
||||||
def daemon_install() -> None:
|
def daemon_install() -> None:
|
||||||
"""Install Pyra as a system service (launchd/systemd)."""
|
"""Install Pyra as a system service (launchd/systemd/schtasks)."""
|
||||||
console.print("[yellow]Daemon service install (Stage 2.4) is not yet implemented.[/yellow]")
|
from pyra.daemon.service import detect_platform, install_service
|
||||||
|
try:
|
||||||
|
install_service()
|
||||||
|
console.print(f"[green]Service installed[/green] ({detect_platform()}).")
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[red]Install failed:[/red] {exc}")
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("uninstall")
|
@daemon.command("uninstall")
|
||||||
def daemon_uninstall() -> None:
|
def daemon_uninstall() -> None:
|
||||||
"""Remove the Pyra system service."""
|
"""Remove the Pyra system service."""
|
||||||
console.print("[yellow]Daemon service uninstall (Stage 2.4) is not yet implemented.[/yellow]")
|
from pyra.daemon.service import uninstall_service
|
||||||
|
try:
|
||||||
|
uninstall_service()
|
||||||
|
console.print("[green]Service removed.[/green]")
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[red]Uninstall failed:[/red] {exc}")
|
||||||
|
|
||||||
|
|
||||||
@daemon.command("run", hidden=True)
|
def _daemon_ipc(cmd: str, *, success_msg: str | None = None) -> None:
|
||||||
def daemon_run() -> None:
|
"""Send a command to the running daemon via IPC and render the response."""
|
||||||
"""Run daemon in foreground (used by service manager)."""
|
from pyra.config.manager import load_config
|
||||||
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]")
|
from pyra.daemon.ipc import (
|
||||||
|
get_socket_path,
|
||||||
|
is_unix_socket,
|
||||||
|
get_port_file_path,
|
||||||
|
send_command,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cfg = load_config()
|
||||||
|
except FileNotFoundError:
|
||||||
|
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if is_unix_socket():
|
||||||
|
address = get_socket_path(cfg.daemon.socket_path)
|
||||||
|
else:
|
||||||
|
port = _read_windows_port()
|
||||||
|
if port is None:
|
||||||
|
console.print("[yellow]Daemon is not running.[/yellow]")
|
||||||
|
return
|
||||||
|
address = ("127.0.0.1", port)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = send_command(address, {"cmd": cmd})
|
||||||
|
except (ConnectionRefusedError, FileNotFoundError, OSError):
|
||||||
|
console.print("[yellow]Daemon is not running.[/yellow]")
|
||||||
|
return
|
||||||
|
except ConnectionResetError:
|
||||||
|
console.print("[red]Permission denied:[/red] daemon rejected connection.")
|
||||||
|
return
|
||||||
|
except TimeoutError:
|
||||||
|
console.print("[red]Daemon did not respond in time.[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not resp.get("ok"):
|
||||||
|
console.print(f"[red]Error:[/red] {resp.get('data', {}).get('error', 'unknown')}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if cmd == "status":
|
||||||
|
_render_daemon_status(resp["data"])
|
||||||
|
elif success_msg:
|
||||||
|
console.print(f"[green]{success_msg}[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_windows_port() -> int | None:
|
||||||
|
from pyra.daemon.ipc import get_port_file_path
|
||||||
|
try:
|
||||||
|
return int(get_port_file_path().read_text().strip())
|
||||||
|
except (FileNotFoundError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _render_daemon_status(data: dict) -> None:
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
uptime = data.get("uptime", 0.0)
|
||||||
|
pid = data.get("pid", "?")
|
||||||
|
tasks = data.get("tasks", [])
|
||||||
|
|
||||||
|
hours, rem = divmod(int(uptime), 3600)
|
||||||
|
mins, secs = divmod(rem, 60)
|
||||||
|
uptime_str = f"{hours}h {mins}m {secs}s" if hours else f"{mins}m {secs}s"
|
||||||
|
|
||||||
|
console.print(f"[bold green]Daemon running[/bold green] — PID {pid}, uptime {uptime_str}")
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
table = Table("Task", "Alive", "Restarts", "Last error", show_header=True)
|
||||||
|
for t in tasks:
|
||||||
|
alive = "[green]yes[/green]" if t.get("alive") else "[red]no[/red]"
|
||||||
|
error = t.get("last_error") or "—"
|
||||||
|
table.add_row(t.get("name", "?"), alive, str(t.get("restart_count", 0)), error)
|
||||||
|
console.print(table)
|
||||||
|
else:
|
||||||
|
console.print("[dim]No plugin tasks registered.[/dim]")
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class DaemonConfig(BaseModel):
|
|||||||
socket_path: str = "~/.pyra/daemon.sock"
|
socket_path: str = "~/.pyra/daemon.sock"
|
||||||
log_file: str = "~/.pyra/daemon.log"
|
log_file: str = "~/.pyra/daemon.log"
|
||||||
pid_file: str = "~/.pyra/daemon.pid"
|
pid_file: str = "~/.pyra/daemon.pid"
|
||||||
|
ipc_port: int = 0 # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port
|
||||||
|
|
||||||
|
|
||||||
class PyraConfig(BaseModel):
|
class PyraConfig(BaseModel):
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
"""Pyra background daemon package."""
|
||||||
|
|
||||||
|
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
|
||||||
|
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
|
||||||
|
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||||
|
from pyra.daemon.service import detect_platform, install_service, uninstall_service
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"run_foreground",
|
||||||
|
"start_background",
|
||||||
|
"PluginSupervisor",
|
||||||
|
"IpcClient",
|
||||||
|
"IpcServer",
|
||||||
|
"send_command",
|
||||||
|
"PidFile",
|
||||||
|
"PidFileError",
|
||||||
|
"resolve_pid_path",
|
||||||
|
"detect_platform",
|
||||||
|
"install_service",
|
||||||
|
"uninstall_service",
|
||||||
|
]
|
||||||
|
|||||||
@@ -0,0 +1,313 @@
|
|||||||
|
"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Coroutine
|
||||||
|
|
||||||
|
from pyra.utils.paths import pyra_home, safe_chmod
|
||||||
|
|
||||||
|
|
||||||
|
_log = logging.getLogger("pyra.daemon")
|
||||||
|
_start_time: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin task supervisor ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskRecord:
|
||||||
|
name: str
|
||||||
|
coro_factory: Callable[[], Coroutine] # type: ignore[type-arg]
|
||||||
|
task: asyncio.Task | None = field(default=None, repr=False)
|
||||||
|
restart_count: int = 0
|
||||||
|
last_error: str | None = None
|
||||||
|
|
||||||
|
def is_alive(self) -> bool:
|
||||||
|
return self.task is not None and not self.task.done()
|
||||||
|
|
||||||
|
|
||||||
|
class PluginSupervisor:
|
||||||
|
_RESTART_DELAY: float = 5.0
|
||||||
|
_MAX_RESTARTS: int = 10
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._records: list[TaskRecord] = []
|
||||||
|
self._shutdown = asyncio.Event()
|
||||||
|
|
||||||
|
def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
|
||||||
|
self._records.append(TaskRecord(name=name, coro_factory=factory))
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
for record in self._records:
|
||||||
|
record.task = asyncio.create_task(
|
||||||
|
self._supervise(record), name=record.name
|
||||||
|
)
|
||||||
|
_log.info("Supervisor started with %d plugin task(s).", len(self._records))
|
||||||
|
|
||||||
|
async def run_until_shutdown(self) -> None:
|
||||||
|
await self._shutdown.wait()
|
||||||
|
_log.info("Shutdown requested — stopping supervisor.")
|
||||||
|
|
||||||
|
def request_shutdown(self) -> None:
|
||||||
|
self._shutdown.set()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
for record in self._records:
|
||||||
|
if record.task and not record.task.done():
|
||||||
|
record.task.cancel()
|
||||||
|
tasks = [r.task for r in self._records if r.task]
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
async def reload(self) -> None:
|
||||||
|
"""Cancel all running tasks and restart them with fresh coroutines."""
|
||||||
|
for record in self._records:
|
||||||
|
if record.task and not record.task.done():
|
||||||
|
record.task.cancel()
|
||||||
|
try:
|
||||||
|
await record.task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
record.restart_count = 0
|
||||||
|
record.last_error = None
|
||||||
|
record.task = asyncio.create_task(
|
||||||
|
self._supervise(record), name=record.name
|
||||||
|
)
|
||||||
|
_log.info("Reloaded %d plugin task(s).", len(self._records))
|
||||||
|
|
||||||
|
def status(self) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": r.name,
|
||||||
|
"alive": r.is_alive(),
|
||||||
|
"restart_count": r.restart_count,
|
||||||
|
"last_error": r.last_error,
|
||||||
|
}
|
||||||
|
for r in self._records
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _supervise(self, record: TaskRecord) -> None:
|
||||||
|
while not self._shutdown.is_set():
|
||||||
|
try:
|
||||||
|
await record.coro_factory()
|
||||||
|
_log.info("Plugin task %s completed normally.", record.name)
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
record.restart_count += 1
|
||||||
|
record.last_error = f"{type(exc).__name__}: {exc}"
|
||||||
|
_log.error(
|
||||||
|
"Plugin task %s crashed (restart #%d): %s",
|
||||||
|
record.name, record.restart_count, exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
if record.restart_count >= self._MAX_RESTARTS:
|
||||||
|
_log.critical(
|
||||||
|
"Plugin task %s exceeded max restarts (%d). Giving up.",
|
||||||
|
record.name, self._MAX_RESTARTS,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.sleep(self._RESTART_DELAY),
|
||||||
|
timeout=self._RESTART_DELAY + 1,
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# ── IPC command dispatch ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_ipc_handler(supervisor: PluginSupervisor):
|
||||||
|
async def handler(msg: dict) -> dict:
|
||||||
|
cmd = msg.get("cmd", "")
|
||||||
|
match cmd:
|
||||||
|
case "ping":
|
||||||
|
return {"ok": True, "data": {"pong": True}}
|
||||||
|
case "status":
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"data": {
|
||||||
|
"uptime": time.monotonic() - _start_time,
|
||||||
|
"pid": os.getpid(),
|
||||||
|
"tasks": supervisor.status(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case "stop":
|
||||||
|
supervisor.request_shutdown()
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
case "reload":
|
||||||
|
await supervisor.reload()
|
||||||
|
return {"ok": True, "data": {"tasks_reloaded": len(supervisor._records)}}
|
||||||
|
case _:
|
||||||
|
return {"ok": False, "data": {"error": f"unknown command: {cmd}"}}
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main async entrypoint ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None:
|
||||||
|
from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket
|
||||||
|
|
||||||
|
# Install signal handlers now that the event loop is running.
|
||||||
|
_install_signal_handlers(supervisor)
|
||||||
|
|
||||||
|
if is_unix_socket():
|
||||||
|
address = get_socket_path(cfg.daemon.socket_path)
|
||||||
|
else:
|
||||||
|
address = ("127.0.0.1", cfg.daemon.ipc_port)
|
||||||
|
|
||||||
|
server = IpcServer(address, _make_ipc_handler(supervisor))
|
||||||
|
|
||||||
|
await supervisor.start()
|
||||||
|
|
||||||
|
async with asyncio.TaskGroup() as tg:
|
||||||
|
tg.create_task(server.start(), name="ipc_server")
|
||||||
|
tg.create_task(supervisor.run_until_shutdown(), name="shutdown_waiter")
|
||||||
|
|
||||||
|
await server.stop()
|
||||||
|
await supervisor.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Foreground entry point (pyra daemon run) ──────────────────────────────────
|
||||||
|
|
||||||
|
def run_foreground() -> None:
|
||||||
|
"""Run the daemon in the foreground. Called by `pyra daemon run`."""
|
||||||
|
from pyra.config.manager import load_config
|
||||||
|
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||||
|
from pyra.plugins.registry import PluginRegistry
|
||||||
|
|
||||||
|
global _start_time
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
_setup_logging(cfg.daemon.log_file)
|
||||||
|
pid_path = resolve_pid_path(cfg.daemon.pid_file)
|
||||||
|
pid_file = PidFile(pid_path)
|
||||||
|
|
||||||
|
existing = pid_file.read()
|
||||||
|
if existing is not None and not pid_file.is_stale():
|
||||||
|
_log.error("Daemon already running (PID %d). Exiting.", existing)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
registry = PluginRegistry()
|
||||||
|
from pyra.utils.paths import pyra_home as _pyra_home
|
||||||
|
plugins_dir = _pyra_home() / "plugins"
|
||||||
|
if plugins_dir.exists():
|
||||||
|
registry.load_all(plugins_dir, cfg.plugins.enabled)
|
||||||
|
|
||||||
|
supervisor = PluginSupervisor()
|
||||||
|
for name, factory in registry.get_daemon_task_factories():
|
||||||
|
supervisor.add_task(name, factory)
|
||||||
|
|
||||||
|
_start_time = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pid_file:
|
||||||
|
_log.info("Pyra daemon starting (PID %d).", os.getpid())
|
||||||
|
try:
|
||||||
|
asyncio.run(_run_daemon(cfg, supervisor))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
_log.info("Pyra daemon stopped.")
|
||||||
|
except PidFileError as exc:
|
||||||
|
_log.error("Could not acquire PID file: %s", exc)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Background spawn (pyra daemon start) ─────────────────────────────────────
|
||||||
|
|
||||||
|
def start_background() -> None:
|
||||||
|
"""Spawn `pyra daemon run` as a detached background process."""
|
||||||
|
from pyra.config.manager import load_config
|
||||||
|
from pyra.daemon.pid import PidFile, resolve_pid_path
|
||||||
|
from pyra.daemon.service import find_pyra_executable
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
pid_path = resolve_pid_path(cfg.daemon.pid_file)
|
||||||
|
pid_file = PidFile(pid_path)
|
||||||
|
|
||||||
|
existing = pid_file.read()
|
||||||
|
if existing is not None and not pid_file.is_stale():
|
||||||
|
from pyra.chat.renderer import console
|
||||||
|
console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
exe = find_pyra_executable()
|
||||||
|
log_path = Path(cfg.daemon.log_file).expanduser()
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(log_path, "a") as log_fh:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
DETACHED_PROCESS = 0x00000008
|
||||||
|
CREATE_NEW_PROCESS_GROUP = 0x00000200
|
||||||
|
subprocess.Popen(
|
||||||
|
[exe, "daemon", "run"],
|
||||||
|
creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP,
|
||||||
|
stdout=log_fh,
|
||||||
|
stderr=log_fh,
|
||||||
|
close_fds=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
subprocess.Popen(
|
||||||
|
[exe, "daemon", "run"],
|
||||||
|
start_new_session=True,
|
||||||
|
stdout=log_fh,
|
||||||
|
stderr=log_fh,
|
||||||
|
stdin=subprocess.DEVNULL,
|
||||||
|
close_fds=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pyra.chat.renderer import console
|
||||||
|
|
||||||
|
# Poll the PID file for up to 3 seconds to confirm startup.
|
||||||
|
for _ in range(30):
|
||||||
|
time.sleep(0.1)
|
||||||
|
pid = pid_file.read()
|
||||||
|
if pid is not None:
|
||||||
|
console.print(f"[green]Daemon started (PID {pid}).[/green]")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Signal handling ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _install_signal_handlers(supervisor: PluginSupervisor) -> None:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
signal.signal(signal.SIGTERM, lambda *_: supervisor.request_shutdown())
|
||||||
|
return
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
loop.add_signal_handler(signal.SIGTERM, supervisor.request_shutdown)
|
||||||
|
loop.add_signal_handler(signal.SIGHUP, supervisor.request_shutdown)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Logging setup ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _setup_logging(log_file_str: str) -> None:
|
||||||
|
log_path = Path(log_file_str).expanduser()
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
handler = logging.handlers.RotatingFileHandler(
|
||||||
|
log_path, maxBytes=5 * 1024 * 1024, backupCount=3
|
||||||
|
)
|
||||||
|
handler.setFormatter(
|
||||||
|
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||||
|
)
|
||||||
|
|
||||||
|
root = logging.getLogger("pyra")
|
||||||
|
root.addHandler(handler)
|
||||||
|
root.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
safe_chmod(log_path, 0o600)
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
"""IPC transport for the Pyra daemon.
|
||||||
|
|
||||||
|
Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked).
|
||||||
|
Windows: TCP loopback on an OS-assigned port; actual port written to
|
||||||
|
~/.pyra/daemon.port so clients can connect without knowing it ahead
|
||||||
|
of time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
|
||||||
|
# ── Protocol types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
IpcMessage = dict[str, Any] # must have "cmd" key
|
||||||
|
IpcResponse = dict[str, Any] # must have "ok" and "data" keys
|
||||||
|
|
||||||
|
|
||||||
|
# ── Encode / decode ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def encode_message(msg: IpcMessage) -> bytes:
|
||||||
|
return (json.dumps(msg) + "\n").encode()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_message(line: bytes) -> IpcMessage:
|
||||||
|
try:
|
||||||
|
return json.loads(line.rstrip(b"\n"))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Invalid IPC message: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── Address helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def is_unix_socket() -> bool:
|
||||||
|
return sys.platform != "win32"
|
||||||
|
|
||||||
|
|
||||||
|
def get_socket_path(cfg_socket_path: str) -> Path:
|
||||||
|
"""Expand ~ and return the Unix socket path."""
|
||||||
|
return Path(cfg_socket_path).expanduser()
|
||||||
|
|
||||||
|
|
||||||
|
def get_port_file_path() -> Path:
|
||||||
|
from pyra.utils.paths import pyra_home
|
||||||
|
return pyra_home() / "daemon.port"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_windows_port() -> int | None:
|
||||||
|
port_file = get_port_file_path()
|
||||||
|
try:
|
||||||
|
return int(port_file.read_text().strip())
|
||||||
|
except (FileNotFoundError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Server ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class IpcServer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
address: Path | tuple[str, int],
|
||||||
|
handler: Callable[[IpcMessage], Awaitable[IpcResponse]],
|
||||||
|
) -> None:
|
||||||
|
self._address = address
|
||||||
|
self._handler = handler
|
||||||
|
self._server: asyncio.Server | None = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if is_unix_socket():
|
||||||
|
assert isinstance(self._address, Path)
|
||||||
|
sock_path = self._address
|
||||||
|
if sock_path.exists():
|
||||||
|
sock_path.unlink()
|
||||||
|
self._server = await asyncio.start_unix_server(
|
||||||
|
self._handle_client, path=str(sock_path)
|
||||||
|
)
|
||||||
|
os.chmod(sock_path, 0o600)
|
||||||
|
else:
|
||||||
|
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
|
||||||
|
self._server = await asyncio.start_server(
|
||||||
|
self._handle_client, host=host, port=port
|
||||||
|
)
|
||||||
|
actual_port = self._server.sockets[0].getsockname()[1]
|
||||||
|
port_file = get_port_file_path()
|
||||||
|
port_file.write_text(str(actual_port))
|
||||||
|
|
||||||
|
await self._server.start_serving()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
if self._server is not None:
|
||||||
|
self._server.close()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._server.wait_closed(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
if is_unix_socket() and isinstance(self._address, Path):
|
||||||
|
try:
|
||||||
|
self._address.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
port_file = get_port_file_path()
|
||||||
|
try:
|
||||||
|
port_file.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _handle_client(
|
||||||
|
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
if is_unix_socket() and not self._check_peer_uid(writer):
|
||||||
|
writer.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout=5.0)
|
||||||
|
if not line:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = decode_message(line)
|
||||||
|
except ValueError:
|
||||||
|
resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}}
|
||||||
|
else:
|
||||||
|
resp = await self._handler(msg)
|
||||||
|
|
||||||
|
writer.write(encode_message(resp))
|
||||||
|
await writer.drain()
|
||||||
|
except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool:
|
||||||
|
"""Return True if the peer's UID matches ours. Falls back to True on error."""
|
||||||
|
try:
|
||||||
|
peer_uid = _get_peer_uid(writer)
|
||||||
|
if peer_uid is None:
|
||||||
|
return True # can't determine — allow (socket perms are the guard)
|
||||||
|
return peer_uid == os.getuid()
|
||||||
|
except Exception:
|
||||||
|
return True # don't crash the server on unexpected errors
|
||||||
|
|
||||||
|
|
||||||
|
# ── Client ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class IpcClient:
|
||||||
|
def __init__(self, address: Path | tuple[str, int]) -> None:
|
||||||
|
self._address = address
|
||||||
|
|
||||||
|
async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse:
|
||||||
|
if is_unix_socket():
|
||||||
|
assert isinstance(self._address, Path)
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_unix_connection(str(self._address)), timeout=timeout
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_connection(host, port), timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.write(encode_message(msg))
|
||||||
|
await writer.drain()
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout=timeout)
|
||||||
|
return decode_message(line)
|
||||||
|
finally:
|
||||||
|
writer.close()
|
||||||
|
try:
|
||||||
|
await writer.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def send_command(
|
||||||
|
address: Path | tuple[str, int],
|
||||||
|
msg: IpcMessage,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
) -> IpcResponse:
|
||||||
|
"""Synchronous wrapper around IpcClient.send() for CLI callers."""
|
||||||
|
return asyncio.run(IpcClient(address).send(msg, timeout=timeout))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Peer UID detection ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None:
|
||||||
|
"""Return the connecting peer's UID, or None if unavailable."""
|
||||||
|
try:
|
||||||
|
sock = writer.get_extra_info("socket")
|
||||||
|
if sock is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sys.platform == "linux":
|
||||||
|
# SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; }
|
||||||
|
SO_PEERCRED = 17
|
||||||
|
cred = sock.getsockopt(
|
||||||
|
socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i")
|
||||||
|
)
|
||||||
|
_pid, uid, _gid = struct.unpack("3i", cred)
|
||||||
|
return uid
|
||||||
|
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
return _macos_peer_uid(sock.fileno())
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def socket_module(): # lazy import to avoid top-level import on non-Unix
|
||||||
|
import socket
|
||||||
|
return socket
|
||||||
|
|
||||||
|
|
||||||
|
def _macos_peer_uid(fd: int) -> int | None:
|
||||||
|
"""Use getpeereid(2) via ctypes to retrieve the peer UID on macOS."""
|
||||||
|
import ctypes
|
||||||
|
import ctypes.util
|
||||||
|
|
||||||
|
libc_name = ctypes.util.find_library("c")
|
||||||
|
if not libc_name:
|
||||||
|
return None
|
||||||
|
libc = ctypes.CDLL(libc_name)
|
||||||
|
|
||||||
|
euid = ctypes.c_uint32(0)
|
||||||
|
egid = ctypes.c_uint32(0)
|
||||||
|
if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0:
|
||||||
|
return None
|
||||||
|
return euid.value
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
"""PID file management for the Pyra daemon."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class PidFileError(OSError):
|
||||||
|
"""Raised when a PID file operation fails due to a live conflicting process."""
|
||||||
|
|
||||||
|
|
||||||
|
class PidFile:
|
||||||
|
def __init__(self, path: Path) -> None:
|
||||||
|
self._path = path
|
||||||
|
|
||||||
|
def write(self) -> None:
|
||||||
|
"""Write the current PID atomically.
|
||||||
|
|
||||||
|
Raises PidFileError if a non-stale PID file already exists.
|
||||||
|
"""
|
||||||
|
existing = self.read()
|
||||||
|
if existing is not None and not self.is_stale():
|
||||||
|
raise PidFileError(
|
||||||
|
f"Daemon already running with PID {existing} "
|
||||||
|
f"(PID file: {self._path})"
|
||||||
|
)
|
||||||
|
tmp = self._path.with_suffix(".pid.tmp")
|
||||||
|
tmp.write_text(str(os.getpid()))
|
||||||
|
tmp.replace(self._path)
|
||||||
|
|
||||||
|
def read(self) -> int | None:
|
||||||
|
"""Return the PID from the file, or None if the file is absent or unreadable."""
|
||||||
|
try:
|
||||||
|
return int(self._path.read_text().strip())
|
||||||
|
except (FileNotFoundError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_stale(self) -> bool:
|
||||||
|
"""True when the PID file exists but the process no longer runs."""
|
||||||
|
pid = self.read()
|
||||||
|
if pid is None:
|
||||||
|
return False
|
||||||
|
return not _process_is_alive(pid)
|
||||||
|
|
||||||
|
def remove(self) -> None:
|
||||||
|
"""Delete the PID file, ignoring FileNotFoundError."""
|
||||||
|
try:
|
||||||
|
self._path.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self) -> "PidFile":
|
||||||
|
self.write()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *_: object) -> None:
|
||||||
|
self.remove()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_pid_path(cfg_path: str) -> Path:
|
||||||
|
"""Expand ~ and return an absolute Path."""
|
||||||
|
return Path(cfg_path).expanduser().resolve()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Platform-specific process liveness check ─────────────────────────────────
|
||||||
|
|
||||||
|
def _process_is_alive(pid: int) -> bool:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
return _win_process_is_alive(pid)
|
||||||
|
return _posix_process_is_alive(pid)
|
||||||
|
|
||||||
|
|
||||||
|
def _posix_process_is_alive(pid: int) -> bool:
|
||||||
|
try:
|
||||||
|
os.kill(pid, 0)
|
||||||
|
return True
|
||||||
|
except ProcessLookupError:
|
||||||
|
return False
|
||||||
|
except PermissionError:
|
||||||
|
# Process exists but is owned by another user — still alive.
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _win_process_is_alive(pid: int) -> bool:
|
||||||
|
import ctypes
|
||||||
|
|
||||||
|
SYNCHRONIZE = 0x00100000
|
||||||
|
handle = ctypes.windll.kernel32.OpenProcess(SYNCHRONIZE, False, pid) # type: ignore[attr-defined]
|
||||||
|
if handle == 0:
|
||||||
|
return False
|
||||||
|
ctypes.windll.kernel32.CloseHandle(handle) # type: ignore[attr-defined]
|
||||||
|
return True
|
||||||
@@ -0,0 +1,212 @@
|
|||||||
|
"""OS-specific service file generation and install/uninstall for the Pyra daemon."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pyra.utils.paths import safe_chmod
|
||||||
|
|
||||||
|
|
||||||
|
def detect_platform() -> Literal["macos", "linux", "windows"]:
|
||||||
|
s = platform.system()
|
||||||
|
if s == "Darwin":
|
||||||
|
return "macos"
|
||||||
|
if s == "Linux":
|
||||||
|
return "linux"
|
||||||
|
if s == "Windows":
|
||||||
|
return "windows"
|
||||||
|
raise RuntimeError(f"Unsupported platform: {s}")
|
||||||
|
|
||||||
|
|
||||||
|
def find_pyra_executable() -> str:
|
||||||
|
"""Return the full path to the active pyra executable.
|
||||||
|
|
||||||
|
Tries, in order:
|
||||||
|
1. shutil.which("pyra") — works when pyra is on PATH (activated venv)
|
||||||
|
2. sys.executable's sibling "pyra" script — covers editable installs
|
||||||
|
3. Fallback: sys.executable -m pyra
|
||||||
|
"""
|
||||||
|
found = shutil.which("pyra")
|
||||||
|
if found:
|
||||||
|
return found
|
||||||
|
|
||||||
|
sibling = Path(sys.executable).parent / "pyra"
|
||||||
|
if sibling.exists():
|
||||||
|
return str(sibling)
|
||||||
|
|
||||||
|
return f"{sys.executable} -m pyra"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template generators ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def render_launchd_plist(exe: str, log_file: str, pid_file: str) -> str:
|
||||||
|
log = str(Path(log_file).expanduser())
|
||||||
|
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
|
||||||
|
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
|
<plist version="1.0">
|
||||||
|
<dict>
|
||||||
|
<key>Label</key>
|
||||||
|
<string>com.pyra.daemon</string>
|
||||||
|
<key>ProgramArguments</key>
|
||||||
|
<array>
|
||||||
|
<string>{exe}</string>
|
||||||
|
<string>daemon</string>
|
||||||
|
<string>run</string>
|
||||||
|
</array>
|
||||||
|
<key>RunAtLoad</key>
|
||||||
|
<true/>
|
||||||
|
<key>KeepAlive</key>
|
||||||
|
<true/>
|
||||||
|
<key>StandardOutPath</key>
|
||||||
|
<string>{log}</string>
|
||||||
|
<key>StandardErrorPath</key>
|
||||||
|
<string>{log}</string>
|
||||||
|
<key>ProcessType</key>
|
||||||
|
<string>Background</string>
|
||||||
|
</dict>
|
||||||
|
</plist>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def render_systemd_unit(exe: str, log_file: str) -> str:
|
||||||
|
log = str(Path(log_file).expanduser())
|
||||||
|
return f"""[Unit]
|
||||||
|
Description=Pyra Personal AI Assistant Daemon
|
||||||
|
After=default.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
ExecStart={exe} daemon run
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5s
|
||||||
|
StandardOutput=append:{log}
|
||||||
|
StandardError=append:{log}
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=default.target
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def render_schtasks_xml(exe: str) -> str:
|
||||||
|
return f"""<?xml version="1.0" encoding="UTF-16"?>
|
||||||
|
<Task version="1.2" xmlns="http://schemas.microsoft.com/windows/2004/02/mit/task">
|
||||||
|
<RegistrationInfo>
|
||||||
|
<Description>Pyra Personal AI Assistant Daemon</Description>
|
||||||
|
</RegistrationInfo>
|
||||||
|
<Triggers>
|
||||||
|
<LogonTrigger>
|
||||||
|
<Enabled>true</Enabled>
|
||||||
|
</LogonTrigger>
|
||||||
|
</Triggers>
|
||||||
|
<Settings>
|
||||||
|
<MultipleInstancesPolicy>IgnoreNew</MultipleInstancesPolicy>
|
||||||
|
<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>
|
||||||
|
<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>
|
||||||
|
<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>
|
||||||
|
<RestartOnFailure>
|
||||||
|
<Interval>PT1M</Interval>
|
||||||
|
<Count>999</Count>
|
||||||
|
</RestartOnFailure>
|
||||||
|
</Settings>
|
||||||
|
<Actions Context="Author">
|
||||||
|
<Exec>
|
||||||
|
<Command>{exe}</Command>
|
||||||
|
<Arguments>daemon run</Arguments>
|
||||||
|
</Exec>
|
||||||
|
</Actions>
|
||||||
|
</Task>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Install / uninstall ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def install_service() -> None:
|
||||||
|
"""Generate and register the OS service for the current platform."""
|
||||||
|
from pyra.config.manager import load_config
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
exe = find_pyra_executable()
|
||||||
|
plat = detect_platform()
|
||||||
|
|
||||||
|
if plat == "macos":
|
||||||
|
_install_launchd(exe, cfg.daemon.log_file, cfg.daemon.pid_file)
|
||||||
|
elif plat == "linux":
|
||||||
|
_install_systemd(exe, cfg.daemon.log_file)
|
||||||
|
else:
|
||||||
|
_install_windows(exe)
|
||||||
|
|
||||||
|
|
||||||
|
def uninstall_service() -> None:
|
||||||
|
"""Deregister the OS service for the current platform."""
|
||||||
|
plat = detect_platform()
|
||||||
|
if plat == "macos":
|
||||||
|
_uninstall_launchd()
|
||||||
|
elif plat == "linux":
|
||||||
|
_uninstall_systemd()
|
||||||
|
else:
|
||||||
|
_uninstall_windows()
|
||||||
|
|
||||||
|
|
||||||
|
# ── macOS launchd ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PLIST_PATH = Path.home() / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
|
||||||
|
|
||||||
|
|
||||||
|
def _install_launchd(exe: str, log_file: str, pid_file: str) -> None:
|
||||||
|
_PLIST_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
_PLIST_PATH.write_text(render_launchd_plist(exe, log_file, pid_file))
|
||||||
|
safe_chmod(_PLIST_PATH, 0o644) # launchd requires 644, not 600
|
||||||
|
subprocess.run(["launchctl", "load", str(_PLIST_PATH)], check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _uninstall_launchd() -> None:
|
||||||
|
if _PLIST_PATH.exists():
|
||||||
|
subprocess.run(["launchctl", "unload", str(_PLIST_PATH)], check=False)
|
||||||
|
_PLIST_PATH.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Linux systemd ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SYSTEMD_UNIT = Path.home() / ".config" / "systemd" / "user" / "pyra.service"
|
||||||
|
|
||||||
|
|
||||||
|
def _install_systemd(exe: str, log_file: str) -> None:
|
||||||
|
_SYSTEMD_UNIT.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
_SYSTEMD_UNIT.write_text(render_systemd_unit(exe, log_file))
|
||||||
|
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||||
|
subprocess.run(["systemctl", "--user", "enable", "pyra"], check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _uninstall_systemd() -> None:
|
||||||
|
subprocess.run(
|
||||||
|
["systemctl", "--user", "disable", "--now", "pyra"], check=False
|
||||||
|
)
|
||||||
|
if _SYSTEMD_UNIT.exists():
|
||||||
|
_SYSTEMD_UNIT.unlink()
|
||||||
|
subprocess.run(["systemctl", "--user", "daemon-reload"], check=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Windows Task Scheduler ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _install_windows(exe: str) -> None:
|
||||||
|
from pyra.utils.paths import pyra_home
|
||||||
|
|
||||||
|
xml_path = pyra_home() / "daemon_task.xml"
|
||||||
|
# schtasks /Create /XML requires UTF-16 encoding
|
||||||
|
xml_path.write_text(render_schtasks_xml(exe), encoding="utf-16")
|
||||||
|
subprocess.run(
|
||||||
|
["schtasks", "/Create", "/TN", "PyraAssistant", "/XML", str(xml_path), "/F"],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _uninstall_windows() -> None:
|
||||||
|
subprocess.run(
|
||||||
|
["schtasks", "/Delete", "/TN", "PyraAssistant", "/F"], check=False
|
||||||
|
)
|
||||||
@@ -75,6 +75,32 @@ class PluginRegistry:
|
|||||||
pass
|
pass
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
def get_daemon_task_factories(
|
||||||
|
self,
|
||||||
|
) -> list[tuple[str, Callable[[], Coroutine]]]: # type: ignore[type-arg]
|
||||||
|
"""Return (name, factory) pairs for all plugin daemon tasks.
|
||||||
|
|
||||||
|
Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine,
|
||||||
|
enabling the supervisor to restart crashed tasks without changing the plugin
|
||||||
|
protocol.
|
||||||
|
"""
|
||||||
|
factories: list[tuple[str, Callable[[], Coroutine]]] = [] # type: ignore[type-arg]
|
||||||
|
for plugin in self._plugins.values():
|
||||||
|
try:
|
||||||
|
initial = plugin.daemon_tasks()
|
||||||
|
n_tasks = len(initial)
|
||||||
|
for c in initial:
|
||||||
|
c.close() # prevent "coroutine never awaited" RuntimeWarning
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
for i in range(n_tasks):
|
||||||
|
name = f"{plugin.name}.task_{i}"
|
||||||
|
# Capture plugin and index by value so each closure is independent.
|
||||||
|
def _factory(p=plugin, idx=i) -> Coroutine: # type: ignore[type-arg]
|
||||||
|
return p.daemon_tasks()[idx]
|
||||||
|
factories.append((name, _factory))
|
||||||
|
return factories
|
||||||
|
|
||||||
def find_tool(self, name: str) -> Tool | None:
|
def find_tool(self, name: str) -> Tool | None:
|
||||||
return self._tools.get(name)
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
|||||||
+252
-10
@@ -1,12 +1,16 @@
|
|||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import questionary
|
import questionary
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from pyra.config.manager import save_config
|
from pyra.config.manager import load_config, save_config
|
||||||
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
|
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
|
||||||
from pyra.setup.providers import PROVIDERS, Provider, get_provider
|
from pyra.setup.providers import PROVIDERS, Provider, get_provider
|
||||||
|
from pyra.utils.paths import pyra_home, safe_chmod
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -20,6 +24,72 @@ _USE_CASE_PLUGINS: dict[str, list[str]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_DRAFT_FILE = "setup.draft.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _draft_path():
|
||||||
|
return pyra_home() / _DRAFT_FILE
|
||||||
|
|
||||||
|
|
||||||
|
def _save_draft(state: dict) -> None:
|
||||||
|
path = _draft_path()
|
||||||
|
path.write_text(json.dumps(state, indent=2))
|
||||||
|
safe_chmod(path, 0o600)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_draft() -> dict | None:
|
||||||
|
path = _draft_path()
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(path.read_text())
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_draft() -> None:
|
||||||
|
with contextlib.suppress(FileNotFoundError):
|
||||||
|
_draft_path().unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def _mark_step_done(state: dict, step: str) -> None:
|
||||||
|
state.setdefault("completed_steps", [])
|
||||||
|
if step not in state["completed_steps"]:
|
||||||
|
state["completed_steps"].append(step)
|
||||||
|
|
||||||
|
|
||||||
|
def _offer_resume(draft: dict) -> bool:
|
||||||
|
"""Show a summary of the incomplete setup and ask Resume / Start fresh."""
|
||||||
|
completed = draft.get("completed_steps", [])
|
||||||
|
step_display = {
|
||||||
|
"profile": f"Profile: {draft.get('user_name', '?')}",
|
||||||
|
"provider": f"Provider: {draft.get('provider_id', '?')}",
|
||||||
|
"model": f"Model: {draft.get('model', '?')}",
|
||||||
|
"api_key": "API key: stored in vault",
|
||||||
|
"connection": "Connection: test passed",
|
||||||
|
}
|
||||||
|
lines = ["[bold]An incomplete setup was found.[/bold]\n"]
|
||||||
|
for step, label in step_display.items():
|
||||||
|
if step in completed:
|
||||||
|
lines.append(f" [green]✓[/green] {label}")
|
||||||
|
else:
|
||||||
|
lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]")
|
||||||
|
|
||||||
|
console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow"))
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
action = questionary.select(
|
||||||
|
"What would you like to do?",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice("Resume from where you left off", value="resume"),
|
||||||
|
questionary.Choice("Start fresh", value="fresh"),
|
||||||
|
],
|
||||||
|
).ask()
|
||||||
|
if action is None:
|
||||||
|
raise SystemExit(0)
|
||||||
|
return action == "resume"
|
||||||
|
|
||||||
|
|
||||||
def run_setup() -> None:
|
def run_setup() -> None:
|
||||||
console.print(Panel(
|
console.print(Panel(
|
||||||
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
|
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
|
||||||
@@ -28,16 +98,63 @@ def run_setup() -> None:
|
|||||||
))
|
))
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
|
state: dict = {}
|
||||||
|
draft = _load_draft()
|
||||||
|
if draft:
|
||||||
|
if _offer_resume(draft):
|
||||||
|
state = draft
|
||||||
|
else:
|
||||||
|
_delete_draft()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ── Step 1: profile ────────────────────────────────────────────────
|
||||||
|
if "profile" in state.get("completed_steps", []):
|
||||||
|
user_name = state["user_name"]
|
||||||
|
purpose = state["purpose"]
|
||||||
|
use_cases = state["use_cases"]
|
||||||
|
console.print(f" [dim]✓ Profile: {user_name}[/dim]")
|
||||||
|
else:
|
||||||
user_name, purpose, use_cases = _collect_user_profile()
|
user_name, purpose, use_cases = _collect_user_profile()
|
||||||
|
state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
|
||||||
|
_mark_step_done(state, "profile")
|
||||||
|
_save_draft(state)
|
||||||
|
|
||||||
|
# ── Step 2: provider ───────────────────────────────────────────────
|
||||||
|
if "provider" in state.get("completed_steps", []):
|
||||||
|
provider = get_provider(state["provider_id"])
|
||||||
|
console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
|
||||||
|
else:
|
||||||
provider = _choose_provider()
|
provider = _choose_provider()
|
||||||
|
state.update(provider_id=provider.id)
|
||||||
|
_mark_step_done(state, "provider")
|
||||||
|
_save_draft(state)
|
||||||
|
|
||||||
|
# ── Step 3: model ──────────────────────────────────────────────────
|
||||||
|
if "model" in state.get("completed_steps", []):
|
||||||
|
model = state["model"]
|
||||||
|
console.print(f" [dim]✓ Model: {model}[/dim]")
|
||||||
|
else:
|
||||||
model = _choose_model(provider)
|
model = _choose_model(provider)
|
||||||
|
state.update(model=model)
|
||||||
|
_mark_step_done(state, "model")
|
||||||
|
_save_draft(state)
|
||||||
|
|
||||||
if provider.requires_key:
|
# ── Step 4: API key ────────────────────────────────────────────────
|
||||||
|
if "api_key" not in state.get("completed_steps", []) and provider.requires_key:
|
||||||
|
from pyra.vault.reader import get_key as _get_key
|
||||||
|
if not _get_key(provider.id):
|
||||||
_collect_api_key(provider)
|
_collect_api_key(provider)
|
||||||
|
_mark_step_done(state, "api_key")
|
||||||
|
_save_draft(state)
|
||||||
|
|
||||||
_test_connection(provider, model)
|
# ── Step 5: connection test ────────────────────────────────────────
|
||||||
|
if "connection" not in state.get("completed_steps", []):
|
||||||
|
model = _test_connection(provider, model)
|
||||||
|
state["model"] = model
|
||||||
|
_mark_step_done(state, "connection")
|
||||||
|
_save_draft(state)
|
||||||
|
|
||||||
|
# ── Finalise ───────────────────────────────────────────────────────
|
||||||
cfg = PyraConfig(
|
cfg = PyraConfig(
|
||||||
ai=ProviderConfig(
|
ai=ProviderConfig(
|
||||||
provider_id=provider.id,
|
provider_id=provider.id,
|
||||||
@@ -47,8 +164,10 @@ def run_setup() -> None:
|
|||||||
general=GeneralConfig(user_name=user_name, purpose=purpose),
|
general=GeneralConfig(user_name=user_name, purpose=purpose),
|
||||||
)
|
)
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
|
_delete_draft()
|
||||||
|
|
||||||
_suggest_plugins(use_cases)
|
_suggest_plugins(use_cases)
|
||||||
|
_offer_telegram_setup_if_selected(use_cases)
|
||||||
|
|
||||||
console.print()
|
console.print()
|
||||||
console.print(Panel(
|
console.print(Panel(
|
||||||
@@ -59,6 +178,14 @@ def run_setup() -> None:
|
|||||||
border_style="green",
|
border_style="green",
|
||||||
))
|
))
|
||||||
|
|
||||||
|
except SystemExit:
|
||||||
|
if state.get("completed_steps"):
|
||||||
|
console.print()
|
||||||
|
console.print(
|
||||||
|
" [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _collect_user_profile() -> tuple[str, str, list[str]]:
|
def _collect_user_profile() -> tuple[str, str, list[str]]:
|
||||||
console.print("[bold]Let's personalise your setup.[/bold]")
|
console.print("[bold]Let's personalise your setup.[/bold]")
|
||||||
@@ -112,6 +239,68 @@ def _suggest_plugins(use_cases: list[str]) -> None:
|
|||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def _offer_telegram_setup_if_selected(use_cases: list[str]) -> None:
|
||||||
|
"""If the user chose 'Communication bots', offer to install and configure telegram_bot."""
|
||||||
|
relevant = any(
|
||||||
|
"telegram_bot" in _USE_CASE_PLUGINS.get(uc, [])
|
||||||
|
for uc in use_cases
|
||||||
|
)
|
||||||
|
if not relevant:
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
try:
|
||||||
|
answer = questionary.confirm(
|
||||||
|
"Set up the Telegram bot for remote access to Pyra?", default=True
|
||||||
|
).ask()
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
return
|
||||||
|
if not answer:
|
||||||
|
console.print(
|
||||||
|
" [dim]You can do this later: pyra plugin install telegram_bot[/dim]"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
|
||||||
|
from pyra.plugins.loader import load_plugin_by_name
|
||||||
|
from pyra.utils.paths import pyra_home
|
||||||
|
from pyra.vault.writer import set_key
|
||||||
|
|
||||||
|
bundled_dir = get_bundled_plugins_dir()
|
||||||
|
plugins_dir = pyra_home() / "plugins"
|
||||||
|
|
||||||
|
try:
|
||||||
|
install_bundled_plugin("telegram_bot", bundled_dir, plugins_dir)
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[red]Could not install telegram_bot:[/red] {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
p = load_plugin_by_name("telegram_bot", plugins_dir)
|
||||||
|
if p is None:
|
||||||
|
console.print("[red]Could not load telegram_bot for setup.[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
p.setup(console, set_key)
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
console.print("[dim]Telegram setup skipped.[/dim]")
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(f"[red]Telegram setup error:[/red] {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
cfg = load_config()
|
||||||
|
if "telegram_bot" not in cfg.plugins.enabled:
|
||||||
|
cfg.plugins.enabled.append("telegram_bot")
|
||||||
|
save_config(cfg)
|
||||||
|
console.print("[green]Telegram bot enabled.[/green]")
|
||||||
|
except Exception:
|
||||||
|
console.print(
|
||||||
|
" [dim]Enable manually: pyra plugin enable telegram_bot[/dim]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _choose_provider() -> Provider:
|
def _choose_provider() -> Provider:
|
||||||
local = [p for p in PROVIDERS if p.group == "Local"]
|
local = [p for p in PROVIDERS if p.group == "Local"]
|
||||||
cloud = [p for p in PROVIDERS if p.group == "Cloud"]
|
cloud = [p for p in PROVIDERS if p.group == "Cloud"]
|
||||||
@@ -135,6 +324,7 @@ def _choose_provider() -> Provider:
|
|||||||
|
|
||||||
if provider.connectivity_check:
|
if provider.connectivity_check:
|
||||||
_check_local_server(provider)
|
_check_local_server(provider)
|
||||||
|
_show_local_model_status(provider)
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
@@ -250,8 +440,44 @@ def _check_local_server(provider: Provider) -> None:
|
|||||||
# "retry" → loop
|
# "retry" → loop
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_loaded_models(provider: Provider) -> list[str]:
|
||||||
|
"""Return models currently loaded in RAM from a local provider's API."""
|
||||||
|
if not provider.base_url:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
if provider.id == "ollama":
|
||||||
|
resp = httpx.get(f"{provider.base_url}/api/ps", timeout=3.0)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return [m["name"] for m in resp.json().get("models", [])]
|
||||||
|
elif provider.id == "lmstudio":
|
||||||
|
resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return [
|
||||||
|
m["id"] for m in resp.json().get("data", [])
|
||||||
|
if m.get("state") == "loaded"
|
||||||
|
]
|
||||||
|
else: # llamacpp — /models returns only the active loaded model
|
||||||
|
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return [m["id"] for m in resp.json().get("data", [])]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _show_local_model_status(provider: Provider) -> None:
|
||||||
|
"""Print a one-line status showing which models are currently loaded."""
|
||||||
|
models = fetch_loaded_models(provider)
|
||||||
|
if not models:
|
||||||
|
console.print(" [yellow]No model currently loaded[/yellow]")
|
||||||
|
elif len(models) == 1:
|
||||||
|
console.print(f" [green]Loaded model:[/green] {models[0]}")
|
||||||
|
else:
|
||||||
|
names = ", ".join(models)
|
||||||
|
console.print(f" [green]{len(models)} models loaded:[/green] {names}")
|
||||||
|
|
||||||
|
|
||||||
def _fetch_local_models(provider: Provider) -> list[str]:
|
def _fetch_local_models(provider: Provider) -> list[str]:
|
||||||
"""Return currently loaded/available models from a local provider's API."""
|
"""Return currently loaded models from a local provider's API."""
|
||||||
if not provider.base_url:
|
if not provider.base_url:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
@@ -259,6 +485,13 @@ def _fetch_local_models(provider: Provider) -> list[str]:
|
|||||||
resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0)
|
resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return [m["name"] for m in resp.json().get("models", [])]
|
return [m["name"] for m in resp.json().get("models", [])]
|
||||||
|
elif provider.id == "lmstudio":
|
||||||
|
resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return [
|
||||||
|
m["id"] for m in resp.json().get("data", [])
|
||||||
|
if m.get("state") == "loaded"
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
|
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
@@ -361,7 +594,7 @@ def _collect_api_key(provider: Provider) -> None:
|
|||||||
console.print(" [green]✓ Key stored in vault[/green]")
|
console.print(" [green]✓ Key stored in vault[/green]")
|
||||||
|
|
||||||
|
|
||||||
def _test_connection(provider: Provider, model: str) -> None:
|
def _test_connection(provider: Provider, model: str) -> str:
|
||||||
from pyra.vault.reader import get_key
|
from pyra.vault.reader import get_key
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -381,7 +614,7 @@ def _test_connection(provider: Provider, model: str) -> None:
|
|||||||
|
|
||||||
litellm.completion(**kwargs)
|
litellm.completion(**kwargs)
|
||||||
console.print("[green]✓ Connection OK[/green]")
|
console.print("[green]✓ Connection OK[/green]")
|
||||||
return
|
return model
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
label, hint = _classify_error(exc)
|
label, hint = _classify_error(exc)
|
||||||
@@ -393,8 +626,15 @@ def _test_connection(provider: Provider, model: str) -> None:
|
|||||||
border_style="red",
|
border_style="red",
|
||||||
))
|
))
|
||||||
|
|
||||||
is_auth_error = "AuthenticationError" in type(exc).__name__
|
exc_name = type(exc).__name__
|
||||||
|
is_auth_error = "AuthenticationError" in exc_name
|
||||||
|
is_model_error = any(
|
||||||
|
kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest")
|
||||||
|
)
|
||||||
|
|
||||||
choices = [questionary.Choice("Retry", value="retry")]
|
choices = [questionary.Choice("Retry", value="retry")]
|
||||||
|
if is_model_error:
|
||||||
|
choices.append(questionary.Choice("Change model", value="change_model"))
|
||||||
if provider.requires_key and is_auth_error:
|
if provider.requires_key and is_auth_error:
|
||||||
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
|
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
|
||||||
choices += [
|
choices += [
|
||||||
@@ -413,7 +653,9 @@ def _test_connection(provider: Provider, model: str) -> None:
|
|||||||
console.print(
|
console.print(
|
||||||
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
|
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
|
||||||
)
|
)
|
||||||
return
|
return model
|
||||||
if action == "rekey":
|
if action == "change_model":
|
||||||
|
model = _choose_model(provider)
|
||||||
|
elif action == "rekey":
|
||||||
_collect_api_key(provider)
|
_collect_api_key(provider)
|
||||||
# "retry" or after "rekey" → loop
|
# loop → retry (with possibly new model or key)
|
||||||
|
|||||||
@@ -137,3 +137,47 @@ def test_main_skips_setup_when_config_exists(tmp_pyra_home, monkeypatch):
|
|||||||
def test_config_slash_command_registered():
|
def test_config_slash_command_registered():
|
||||||
from pyra.chat.session import _STATIC_COMMANDS
|
from pyra.chat.session import _STATIC_COMMANDS
|
||||||
assert "/config" in _STATIC_COMMANDS
|
assert "/config" in _STATIC_COMMANDS
|
||||||
|
|
||||||
|
|
||||||
|
# ── plugin install ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_plugin_install_decline_setup(tmp_pyra_home, monkeypatch):
|
||||||
|
"""Declining 'Configure now?' shows manual instructions and exits cleanly."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
import questionary
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"pyra.plugins.install.install_bundled_plugin", lambda *a, **kw: None
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
questionary, "confirm",
|
||||||
|
lambda *a, **kw: MagicMock(ask=lambda: False),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(main, ["plugin", "install", "telegram_bot"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Installed" in result.output
|
||||||
|
assert "Configure later" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_plugin_install_error_does_not_prompt(tmp_pyra_home, monkeypatch):
|
||||||
|
"""If install fails, the configure prompt is never shown."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
import questionary
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"pyra.plugins.install.install_bundled_plugin",
|
||||||
|
lambda *a, **kw: (_ for _ in ()).throw(FileNotFoundError("not found")),
|
||||||
|
)
|
||||||
|
confirm_calls = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
questionary, "confirm",
|
||||||
|
lambda *a, **kw: confirm_calls.append(1) or MagicMock(ask=lambda: False),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(main, ["plugin", "install", "telegram_bot"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Error" in result.output
|
||||||
|
assert len(confirm_calls) == 0 # prompt never reached
|
||||||
|
|||||||
@@ -0,0 +1,226 @@
|
|||||||
|
"""Unit tests for the daemon core — PluginSupervisor and IPC handler dispatch."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pyra.daemon.core import PluginSupervisor, _make_ipc_handler
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _drain(n: int = 20) -> None:
|
||||||
|
"""Yield to the event loop n times to let scheduled tasks run."""
|
||||||
|
for _ in range(n):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PluginSupervisor — lifecycle ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_supervisor_empty_starts_and_stops_cleanly() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
await sup.start()
|
||||||
|
await sup.stop()
|
||||||
|
assert sup.status() == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervisor_runs_task_to_completion() -> None:
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def task():
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
sup._RESTART_DELAY = 0.0
|
||||||
|
sup.add_task("t", task)
|
||||||
|
await sup.start()
|
||||||
|
|
||||||
|
await asyncio.wait_for(done.wait(), timeout=1.0)
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
assert sup._records[0].restart_count == 0
|
||||||
|
assert sup._records[0].last_error is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervisor_restarts_crashed_task() -> None:
|
||||||
|
call_count = 0
|
||||||
|
completed = asyncio.Event()
|
||||||
|
|
||||||
|
async def flaky():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
raise RuntimeError("first call fails")
|
||||||
|
completed.set()
|
||||||
|
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
sup._RESTART_DELAY = 0.0
|
||||||
|
sup.add_task("flaky", flaky)
|
||||||
|
await sup.start()
|
||||||
|
|
||||||
|
await asyncio.wait_for(completed.wait(), timeout=1.0)
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
assert sup._records[0].restart_count == 1
|
||||||
|
assert "RuntimeError" in (sup._records[0].last_error or "")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervisor_gives_up_after_max_restarts() -> None:
|
||||||
|
async def always_fails():
|
||||||
|
raise ValueError("always")
|
||||||
|
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
sup._RESTART_DELAY = 0.0
|
||||||
|
sup._MAX_RESTARTS = 3
|
||||||
|
sup.add_task("failing", always_fails)
|
||||||
|
await sup.start()
|
||||||
|
|
||||||
|
# Allow enough iterations for 3 restarts + give-up.
|
||||||
|
for _ in range(200):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if sup._records[0].task and sup._records[0].task.done():
|
||||||
|
break
|
||||||
|
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
assert sup._records[0].restart_count == 3
|
||||||
|
assert sup._records[0].last_error is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── PluginSupervisor — status ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_supervisor_status_returns_correct_shape() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
sup._RESTART_DELAY = 0.0
|
||||||
|
|
||||||
|
async def noop():
|
||||||
|
pass
|
||||||
|
|
||||||
|
sup.add_task("noop", noop)
|
||||||
|
await sup.start()
|
||||||
|
await _drain()
|
||||||
|
|
||||||
|
statuses = sup.status()
|
||||||
|
assert len(statuses) == 1
|
||||||
|
s = statuses[0]
|
||||||
|
assert set(s.keys()) == {"name", "alive", "restart_count", "last_error"}
|
||||||
|
assert s["name"] == "noop"
|
||||||
|
assert isinstance(s["alive"], bool)
|
||||||
|
assert isinstance(s["restart_count"], int)
|
||||||
|
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervisor_status_empty_when_no_tasks() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
await sup.start()
|
||||||
|
assert sup.status() == []
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ── PluginSupervisor — reload ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_supervisor_reload_restarts_tasks() -> None:
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def counting():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
# Hang until cancelled so reload can cancel it.
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
sup._RESTART_DELAY = 0.0
|
||||||
|
sup.add_task("c", counting)
|
||||||
|
await sup.start()
|
||||||
|
|
||||||
|
await _drain()
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
await sup.reload()
|
||||||
|
await _drain()
|
||||||
|
|
||||||
|
# After reload, the task should have been restarted (called a second time).
|
||||||
|
assert call_count == 2
|
||||||
|
assert sup._records[0].restart_count == 0 # reset by reload
|
||||||
|
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervisor_reload_resets_restart_count() -> None:
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def flaky():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count <= 2:
|
||||||
|
raise RuntimeError("crash")
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
sup._RESTART_DELAY = 0.0
|
||||||
|
sup.add_task("f", flaky)
|
||||||
|
await sup.start()
|
||||||
|
|
||||||
|
# Wait for 2 crashes to accumulate.
|
||||||
|
for _ in range(200):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if sup._records[0].restart_count >= 2:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert sup._records[0].restart_count == 2
|
||||||
|
|
||||||
|
await sup.reload()
|
||||||
|
# Reload must reset the counter.
|
||||||
|
assert sup._records[0].restart_count == 0
|
||||||
|
|
||||||
|
await sup.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ── IPC command handler ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_ipc_handler_ping() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
handler = _make_ipc_handler(sup)
|
||||||
|
resp = await handler({"cmd": "ping"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert resp["data"]["pong"] is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ipc_handler_status_shape() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
handler = _make_ipc_handler(sup)
|
||||||
|
resp = await handler({"cmd": "status"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert "uptime" in resp["data"]
|
||||||
|
assert "pid" in resp["data"]
|
||||||
|
assert "tasks" in resp["data"]
|
||||||
|
assert isinstance(resp["data"]["tasks"], list)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ipc_handler_stop_signals_shutdown() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
handler = _make_ipc_handler(sup)
|
||||||
|
assert not sup._shutdown.is_set()
|
||||||
|
resp = await handler({"cmd": "stop"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert sup._shutdown.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ipc_handler_reload_returns_task_count() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
handler = _make_ipc_handler(sup)
|
||||||
|
resp = await handler({"cmd": "reload"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert resp["data"]["tasks_reloaded"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ipc_handler_unknown_command() -> None:
|
||||||
|
sup = PluginSupervisor()
|
||||||
|
handler = _make_ipc_handler(sup)
|
||||||
|
resp = await handler({"cmd": "bogus"})
|
||||||
|
assert resp["ok"] is False
|
||||||
|
assert "error" in resp["data"]
|
||||||
|
assert "bogus" in resp["data"]["error"]
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
"""Unit tests for the IPC layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pyra.daemon.ipc import (
|
||||||
|
IpcClient,
|
||||||
|
IpcMessage,
|
||||||
|
IpcResponse,
|
||||||
|
IpcServer,
|
||||||
|
decode_message,
|
||||||
|
encode_message,
|
||||||
|
is_unix_socket,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sock_path():
|
||||||
|
"""Short socket path that fits within macOS's 104-char AF_UNIX limit."""
|
||||||
|
with tempfile.TemporaryDirectory(dir="/tmp") as d:
|
||||||
|
yield Path(d) / "t.sock"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Protocol encode / decode ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_encode_appends_newline() -> None:
|
||||||
|
data = encode_message({"cmd": "ping"})
|
||||||
|
assert data.endswith(b"\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_is_valid_json() -> None:
|
||||||
|
import json
|
||||||
|
data = encode_message({"cmd": "status", "extra": 42})
|
||||||
|
assert json.loads(data) == {"cmd": "status", "extra": 42}
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_roundtrip() -> None:
|
||||||
|
msg: IpcMessage = {"cmd": "stop"}
|
||||||
|
assert decode_message(encode_message(msg)) == msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_strips_newline() -> None:
|
||||||
|
assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_raises_on_bad_json() -> None:
|
||||||
|
with pytest.raises(ValueError, match="Invalid IPC message"):
|
||||||
|
decode_message(b"not json\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_raises_on_empty_line() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
decode_message(b"\n")
|
||||||
|
|
||||||
|
|
||||||
|
# ── is_unix_socket ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_is_unix_socket_matches_platform() -> None:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
assert not is_unix_socket()
|
||||||
|
else:
|
||||||
|
assert is_unix_socket()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Server + client roundtrip (Unix only) ─────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_server_client_ping(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {"pong": True}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
resp = await IpcClient(sock_path).send({"cmd": "ping"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert resp["data"]["pong"] is True
|
||||||
|
finally:
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
reader, writer = await asyncio.open_unix_connection(str(sock_path))
|
||||||
|
writer.write(b"not valid json\n")
|
||||||
|
await writer.drain()
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout=3.0)
|
||||||
|
resp = decode_message(line)
|
||||||
|
assert resp["ok"] is False
|
||||||
|
assert "error" in resp["data"]
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_handler_response_returned_to_client(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
if msg.get("cmd") == "status":
|
||||||
|
return {"ok": True, "data": {"uptime": 99.0}}
|
||||||
|
return {"ok": False, "data": {"error": "unknown"}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
resp = await IpcClient(sock_path).send({"cmd": "status"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert resp["data"]["uptime"] == 99.0
|
||||||
|
|
||||||
|
resp2 = await IpcClient(sock_path).send({"cmd": "bogus"})
|
||||||
|
assert resp2["ok"] is False
|
||||||
|
finally:
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_client_raises_when_no_server(sock_path: Path) -> None:
|
||||||
|
client = IpcClient(sock_path)
|
||||||
|
with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)):
|
||||||
|
await client.send({"cmd": "ping"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_socket_file_chmod_600(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
mode = oct(sock_path.stat().st_mode & 0o777)
|
||||||
|
assert mode == oct(0o600), f"Expected 0o600, got {mode}"
|
||||||
|
finally:
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_stop_removes_socket_file(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
assert sock_path.exists()
|
||||||
|
await server.stop()
|
||||||
|
assert not sock_path.exists()
|
||||||
@@ -0,0 +1,103 @@
|
|||||||
|
"""Unit tests for daemon PID file management."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_creates_file(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
p.write()
|
||||||
|
assert (tmp_path / "daemon.pid").exists()
|
||||||
|
assert int((tmp_path / "daemon.pid").read_text().strip()) == os.getpid()
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_when_absent(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
assert p.read() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_pid_when_present(tmp_path: Path) -> None:
|
||||||
|
pid_file = tmp_path / "daemon.pid"
|
||||||
|
pid_file.write_text("12345")
|
||||||
|
p = PidFile(pid_file)
|
||||||
|
assert p.read() == 12345
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_returns_none_on_bad_content(tmp_path: Path) -> None:
|
||||||
|
pid_file = tmp_path / "daemon.pid"
|
||||||
|
pid_file.write_text("not-a-number")
|
||||||
|
p = PidFile(pid_file)
|
||||||
|
assert p.read() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_stale_false_for_self(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
p.write()
|
||||||
|
assert not p.is_stale()
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_stale_true_for_dead_pid(tmp_path: Path) -> None:
|
||||||
|
pid_file = tmp_path / "daemon.pid"
|
||||||
|
pid_file.write_text("999999999") # unrealistically large PID
|
||||||
|
p = PidFile(pid_file)
|
||||||
|
assert p.is_stale()
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_stale_false_when_file_absent(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
assert not p.is_stale()
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_deletes_file(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
p.write()
|
||||||
|
p.remove()
|
||||||
|
assert not (tmp_path / "daemon.pid").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_is_idempotent(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
p.remove() # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_manager_writes_and_removes(tmp_path: Path) -> None:
|
||||||
|
pid_file = tmp_path / "daemon.pid"
|
||||||
|
p = PidFile(pid_file)
|
||||||
|
with p:
|
||||||
|
assert pid_file.exists()
|
||||||
|
assert int(pid_file.read_text().strip()) == os.getpid()
|
||||||
|
assert not pid_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_raises_when_live_pid_exists(tmp_path: Path) -> None:
|
||||||
|
p = PidFile(tmp_path / "daemon.pid")
|
||||||
|
p.write() # writes self PID (which is alive)
|
||||||
|
p2 = PidFile(tmp_path / "daemon.pid")
|
||||||
|
with pytest.raises(PidFileError):
|
||||||
|
p2.write()
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_succeeds_over_stale_pid(tmp_path: Path) -> None:
|
||||||
|
pid_file = tmp_path / "daemon.pid"
|
||||||
|
pid_file.write_text("999999999") # stale
|
||||||
|
p = PidFile(pid_file)
|
||||||
|
p.write() # should not raise
|
||||||
|
assert int(pid_file.read_text().strip()) == os.getpid()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_pid_path_expands_tilde() -> None:
|
||||||
|
result = resolve_pid_path("~/.pyra/daemon.pid")
|
||||||
|
assert not str(result).startswith("~")
|
||||||
|
assert result.is_absolute()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_pid_path_absolute_unchanged(tmp_path: Path) -> None:
|
||||||
|
path = tmp_path / "daemon.pid"
|
||||||
|
result = resolve_pid_path(str(path))
|
||||||
|
assert result == path
|
||||||
@@ -0,0 +1,189 @@
|
|||||||
|
"""Unit tests for daemon service file generation and platform detection."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pyra.daemon.service import (
|
||||||
|
detect_platform,
|
||||||
|
find_pyra_executable,
|
||||||
|
render_launchd_plist,
|
||||||
|
render_systemd_unit,
|
||||||
|
render_schtasks_xml,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template rendering ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_render_launchd_plist_contains_exe() -> None:
|
||||||
|
xml = render_launchd_plist("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
|
||||||
|
assert "/usr/local/bin/pyra" in xml
|
||||||
|
assert "<string>daemon</string>" in xml
|
||||||
|
assert "<string>run</string>" in xml
|
||||||
|
assert "com.pyra.daemon" in xml
|
||||||
|
assert "<true/>" in xml # KeepAlive and RunAtLoad
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_launchd_plist_expands_log_tilde() -> None:
|
||||||
|
xml = render_launchd_plist("/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
|
||||||
|
assert "~" not in xml
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_systemd_unit_contains_exe() -> None:
|
||||||
|
unit = render_systemd_unit("/usr/local/bin/pyra", "~/.pyra/daemon.log")
|
||||||
|
assert "ExecStart=/usr/local/bin/pyra daemon run" in unit
|
||||||
|
assert "Restart=on-failure" in unit
|
||||||
|
assert "Type=simple" in unit
|
||||||
|
assert "WantedBy=default.target" in unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_systemd_unit_expands_log_tilde() -> None:
|
||||||
|
unit = render_systemd_unit("/bin/pyra", "~/.pyra/daemon.log")
|
||||||
|
assert "~" not in unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_schtasks_xml_contains_exe() -> None:
|
||||||
|
xml = render_schtasks_xml("C:\\Users\\test\\pyra.exe")
|
||||||
|
assert "C:\\Users\\test\\pyra.exe" in xml
|
||||||
|
assert "LogonTrigger" in xml
|
||||||
|
assert "daemon run" in xml
|
||||||
|
assert "IgnoreNew" in xml
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_schtasks_xml_no_time_limit() -> None:
|
||||||
|
xml = render_schtasks_xml("pyra.exe")
|
||||||
|
assert "PT0S" in xml # ExecutionTimeLimit=PT0S means unlimited
|
||||||
|
|
||||||
|
|
||||||
|
# ── Platform detection ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_detect_platform_returns_known_value() -> None:
|
||||||
|
result = detect_platform()
|
||||||
|
assert result in ("macos", "linux", "windows")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("system,expected", [
|
||||||
|
("Darwin", "macos"),
|
||||||
|
("Linux", "linux"),
|
||||||
|
("Windows", "windows"),
|
||||||
|
])
|
||||||
|
def test_detect_platform_maps_correctly(system: str, expected: str) -> None:
|
||||||
|
with patch("platform.system", return_value=system):
|
||||||
|
assert detect_platform() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_platform_raises_on_unknown() -> None:
|
||||||
|
with patch("platform.system", return_value="FreeBSD"):
|
||||||
|
with pytest.raises(RuntimeError, match="Unsupported platform"):
|
||||||
|
detect_platform()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Executable detection ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_find_pyra_executable_returns_string() -> None:
|
||||||
|
result = find_pyra_executable()
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_pyra_executable_uses_which_when_available(tmp_path: Path) -> None:
|
||||||
|
fake_pyra = tmp_path / "pyra"
|
||||||
|
fake_pyra.touch()
|
||||||
|
with patch("shutil.which", return_value=str(fake_pyra)):
|
||||||
|
assert find_pyra_executable() == str(fake_pyra)
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_pyra_executable_falls_back_to_sibling(tmp_path: Path) -> None:
|
||||||
|
fake_python = tmp_path / "python3"
|
||||||
|
fake_pyra = tmp_path / "pyra"
|
||||||
|
fake_pyra.touch()
|
||||||
|
with patch("shutil.which", return_value=None):
|
||||||
|
with patch("sys.executable", str(fake_python)):
|
||||||
|
assert find_pyra_executable() == str(fake_pyra)
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_pyra_executable_falls_back_to_module(tmp_path: Path) -> None:
|
||||||
|
fake_python = tmp_path / "python3"
|
||||||
|
with patch("shutil.which", return_value=None):
|
||||||
|
with patch("sys.executable", str(fake_python)):
|
||||||
|
result = find_pyra_executable()
|
||||||
|
assert result == f"{fake_python} -m pyra"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Install / uninstall (subprocess mocked) ───────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="launchd install is macOS-only")
|
||||||
|
def test_install_launchd_writes_plist_and_calls_launchctl(
|
||||||
|
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
import pyra.daemon.service as svc
|
||||||
|
|
||||||
|
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
|
||||||
|
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
|
||||||
|
|
||||||
|
calls: list[list[str]] = []
|
||||||
|
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
|
||||||
|
|
||||||
|
svc._install_launchd("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
|
||||||
|
|
||||||
|
assert plist_path.exists()
|
||||||
|
assert "com.pyra.daemon" in plist_path.read_text()
|
||||||
|
assert any("launchctl" in c[0] for c in calls)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="systemd install is Linux-only")
|
||||||
|
def test_install_systemd_writes_unit_and_calls_systemctl(
|
||||||
|
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
import pyra.daemon.service as svc
|
||||||
|
|
||||||
|
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
|
||||||
|
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
|
||||||
|
|
||||||
|
calls: list[list[str]] = []
|
||||||
|
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
|
||||||
|
|
||||||
|
svc._install_systemd("/usr/local/bin/pyra", "~/.pyra/daemon.log")
|
||||||
|
|
||||||
|
assert unit_path.exists()
|
||||||
|
assert "ExecStart" in unit_path.read_text()
|
||||||
|
assert any("systemctl" in c[0] for c in calls)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="launchd uninstall is macOS-only")
|
||||||
|
def test_uninstall_launchd_removes_plist(
|
||||||
|
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
import pyra.daemon.service as svc
|
||||||
|
|
||||||
|
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
|
||||||
|
plist_path.parent.mkdir(parents=True)
|
||||||
|
plist_path.write_text("<plist/>")
|
||||||
|
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
|
||||||
|
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
|
||||||
|
|
||||||
|
svc._uninstall_launchd()
|
||||||
|
|
||||||
|
assert not plist_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="systemd uninstall is Linux-only")
|
||||||
|
def test_uninstall_systemd_removes_unit(
|
||||||
|
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
import pyra.daemon.service as svc
|
||||||
|
|
||||||
|
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
|
||||||
|
unit_path.parent.mkdir(parents=True)
|
||||||
|
unit_path.write_text("[Service]")
|
||||||
|
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
|
||||||
|
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
|
||||||
|
|
||||||
|
svc._uninstall_systemd()
|
||||||
|
|
||||||
|
assert not unit_path.exists()
|
||||||
@@ -96,14 +96,34 @@ def test_suggest_plugins_multiple_categories(monkeypatch):
|
|||||||
|
|
||||||
# ── _fetch_local_models ────────────────────────────────────────────────────────
|
# ── _fetch_local_models ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def test_fetch_local_models_lmstudio_returns_model_ids(monkeypatch):
|
def test_fetch_local_models_lmstudio_returns_loaded_model_ids(monkeypatch):
|
||||||
import pyra.setup.wizard as wiz
|
import pyra.setup.wizard as wiz
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.json.return_value = {"data": [{"id": "gemma-4b"}, {"id": "llama3"}]}
|
mock_resp.json.return_value = {
|
||||||
|
"data": [
|
||||||
|
{"id": "gemma-4b", "state": "loaded"},
|
||||||
|
{"id": "llama3", "state": "not_loaded"},
|
||||||
|
]
|
||||||
|
}
|
||||||
mock_resp.raise_for_status = lambda: None
|
mock_resp.raise_for_status = lambda: None
|
||||||
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
|
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
|
||||||
from pyra.setup.providers import get_provider
|
from pyra.setup.providers import get_provider
|
||||||
assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b", "llama3"]
|
assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_local_models_lmstudio_filters_unloaded(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"data": [
|
||||||
|
{"id": "model-a", "state": "not_loaded"},
|
||||||
|
{"id": "model-b", "state": "not_loaded"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_resp.raise_for_status = lambda: None
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
assert wiz._fetch_local_models(get_provider("lmstudio")) == []
|
||||||
|
|
||||||
|
|
||||||
def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
|
def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
|
||||||
@@ -329,3 +349,181 @@ def test_check_local_server_continue_returns(monkeypatch):
|
|||||||
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
|
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
|
||||||
|
|
||||||
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
|
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
|
||||||
|
|
||||||
|
|
||||||
|
# ── draft persistence ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_save_and_load_draft(tmp_pyra_home):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
state = {"completed_steps": ["profile"], "user_name": "Alice"}
|
||||||
|
wiz._save_draft(state)
|
||||||
|
loaded = wiz._load_draft()
|
||||||
|
assert loaded == state
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
assert wiz._load_draft() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_draft_removes_file(tmp_pyra_home):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
wiz._save_draft({"completed_steps": []})
|
||||||
|
wiz._delete_draft()
|
||||||
|
assert wiz._load_draft() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_draft_is_idempotent(tmp_pyra_home):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
wiz._delete_draft()
|
||||||
|
wiz._delete_draft()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mark_step_done_appends_once(tmp_pyra_home):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
state = {}
|
||||||
|
wiz._mark_step_done(state, "profile")
|
||||||
|
wiz._mark_step_done(state, "profile")
|
||||||
|
assert state["completed_steps"].count("profile") == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_draft_file_has_correct_permissions(tmp_pyra_home):
|
||||||
|
import os
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
wiz._save_draft({"completed_steps": []})
|
||||||
|
path = wiz._draft_path()
|
||||||
|
if os.name != "nt":
|
||||||
|
mode = oct(path.stat().st_mode)[-3:]
|
||||||
|
assert mode == "600"
|
||||||
|
|
||||||
|
|
||||||
|
# ── fetch_loaded_models ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"models": [{"name": "llama3:latest"}, {"name": "mistral"}]}
|
||||||
|
mock_resp.raise_for_status = lambda: None
|
||||||
|
calls = []
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
|
||||||
|
result = wiz.fetch_loaded_models(get_provider("ollama"))
|
||||||
|
assert result == ["llama3:latest", "mistral"]
|
||||||
|
assert any("/api/ps" in u for u in calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_loaded_models_lmstudio_uses_beta_api_and_filters(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"data": [
|
||||||
|
{"id": "gemma-4b", "state": "loaded"},
|
||||||
|
{"id": "llama3", "state": "not_loaded"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_resp.raise_for_status = lambda: None
|
||||||
|
calls = []
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
|
||||||
|
result = wiz.fetch_loaded_models(get_provider("lmstudio"))
|
||||||
|
assert result == ["gemma-4b"]
|
||||||
|
assert any("/api/v0/models" in u for u in calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_loaded_models_lmstudio_filters_unloaded(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"data": [
|
||||||
|
{"id": "model-a", "state": "not_loaded"},
|
||||||
|
{"id": "model-b", "state": "not_loaded"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_resp.raise_for_status = lambda: None
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: mock_resp)
|
||||||
|
assert wiz.fetch_loaded_models(get_provider("lmstudio")) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_loaded_models_returns_empty_on_error(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused")))
|
||||||
|
assert wiz.fetch_loaded_models(get_provider("ollama")) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_loaded_models_returns_empty_when_no_base_url():
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import Provider
|
||||||
|
provider = Provider(
|
||||||
|
id="test", display_name="Test", requires_key=False,
|
||||||
|
default_model="x", litellm_prefix="openai/", group="Local",
|
||||||
|
)
|
||||||
|
assert wiz.fetch_loaded_models(provider) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _show_local_model_status ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_show_local_model_status_one_model(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"])
|
||||||
|
printed = []
|
||||||
|
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||||
|
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||||
|
assert any("gemma-4b" in s for s in printed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_show_local_model_status_none(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: [])
|
||||||
|
printed = []
|
||||||
|
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||||
|
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||||
|
assert any("No model" in s for s in printed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_show_local_model_status_multiple(monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["a", "b", "c"])
|
||||||
|
printed = []
|
||||||
|
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
|
||||||
|
wiz._show_local_model_status(get_provider("lmstudio"))
|
||||||
|
combined = " ".join(printed)
|
||||||
|
assert "3" in combined or ("a" in combined and "b" in combined)
|
||||||
|
|
||||||
|
|
||||||
|
# ── _test_connection model re-entry ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
|
||||||
|
import pyra.setup.wizard as wiz
|
||||||
|
from pyra.setup.providers import get_provider
|
||||||
|
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
class FakeNotFound(Exception):
|
||||||
|
pass
|
||||||
|
FakeNotFound.__module__ = "litellm.exceptions"
|
||||||
|
FakeNotFound.__name__ = "NotFoundError"
|
||||||
|
|
||||||
|
def fake_completion(**kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
raise FakeNotFound("model not found")
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
monkeypatch.setattr(litellm, "completion", fake_completion)
|
||||||
|
monkeypatch.setattr(wiz.console, "print", _silent_print)
|
||||||
|
monkeypatch.setattr(wiz.questionary, "select",
|
||||||
|
lambda *a, **kw: MagicMock(ask=lambda: "change_model"))
|
||||||
|
monkeypatch.setattr(wiz.questionary, "text",
|
||||||
|
lambda *a, **kw: MagicMock(ask=lambda: "new-model"))
|
||||||
|
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
|
||||||
|
monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")
|
||||||
|
|
||||||
|
provider = get_provider("anthropic")
|
||||||
|
result = wiz._test_connection(provider, "old-model")
|
||||||
|
assert result == "new-model"
|
||||||
|
assert call_count["n"] == 2
|
||||||
|
|||||||
@@ -0,0 +1,281 @@
|
|||||||
|
"""Unit tests for the telegram_bot bundled plugin.
|
||||||
|
|
||||||
|
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
|
||||||
|
and tool argument parsing. Handler integration (live Telegram) is not tested.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ── Import helpers straight from the bundled source ──────────────────────────
|
||||||
|
|
||||||
|
_PLUGIN_PATH = (
|
||||||
|
Path(__file__).parent.parent.parent
|
||||||
|
/ "src/pyra/bundled_plugins/telegram_bot/plugin.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_plugin_module():
|
||||||
|
"""Load plugin.py without requiring python-telegram-bot to be installed."""
|
||||||
|
# Provide stub modules so the top-level imports don't fail in unit tests
|
||||||
|
for mod_name in [
|
||||||
|
"bcrypt",
|
||||||
|
"telegram",
|
||||||
|
"telegram.ext",
|
||||||
|
"litellm",
|
||||||
|
]:
|
||||||
|
if mod_name not in sys.modules:
|
||||||
|
stub = MagicMock()
|
||||||
|
sys.modules[mod_name] = stub
|
||||||
|
# telegram.ext sub-attributes needed at import time
|
||||||
|
if mod_name == "telegram.ext":
|
||||||
|
stub.Application = MagicMock()
|
||||||
|
stub.CallbackQueryHandler = MagicMock()
|
||||||
|
stub.CommandHandler = MagicMock()
|
||||||
|
stub.ContextTypes = MagicMock()
|
||||||
|
stub.MessageHandler = MagicMock()
|
||||||
|
stub.filters = MagicMock()
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
|
||||||
|
assert spec and spec.loader
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod) # type: ignore[union-attr]
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
_mod = _load_plugin_module()
|
||||||
|
_RateLimiter = _mod._RateLimiter
|
||||||
|
_load_history = _mod._load_history
|
||||||
|
_save_history = _mod._save_history
|
||||||
|
TelegramBotPlugin = _mod.TelegramBotPlugin
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
def test_allows_up_to_limit(self):
|
||||||
|
rl = _RateLimiter(per_hour=5)
|
||||||
|
for _ in range(5):
|
||||||
|
assert rl.allow(user_id=1)
|
||||||
|
|
||||||
|
def test_blocks_over_limit(self):
|
||||||
|
rl = _RateLimiter(per_hour=3)
|
||||||
|
for _ in range(3):
|
||||||
|
rl.allow(user_id=1)
|
||||||
|
assert not rl.allow(user_id=1)
|
||||||
|
|
||||||
|
def test_independent_per_user(self):
|
||||||
|
rl = _RateLimiter(per_hour=2)
|
||||||
|
rl.allow(1)
|
||||||
|
rl.allow(1)
|
||||||
|
assert not rl.allow(1)
|
||||||
|
assert rl.allow(2)
|
||||||
|
|
||||||
|
def test_old_entries_expire(self):
|
||||||
|
rl = _RateLimiter(per_hour=1)
|
||||||
|
# Manually populate bucket with an old timestamp
|
||||||
|
from collections import deque
|
||||||
|
rl._buckets[99] = deque([time.monotonic() - 3601])
|
||||||
|
assert rl.allow(99) # old entry removed, new one fits
|
||||||
|
|
||||||
|
def test_multiple_users_independent(self):
|
||||||
|
rl = _RateLimiter(per_hour=1)
|
||||||
|
rl.allow(10)
|
||||||
|
assert not rl.allow(10)
|
||||||
|
assert rl.allow(20)
|
||||||
|
assert rl.allow(30)
|
||||||
|
|
||||||
|
|
||||||
|
# ── History DB ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestHistoryDB:
|
||||||
|
def test_empty_history(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
assert _load_history(42) == []
|
||||||
|
|
||||||
|
def test_save_and_load(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hi there"},
|
||||||
|
]
|
||||||
|
_save_history(42, msgs)
|
||||||
|
assert _load_history(42) == msgs
|
||||||
|
|
||||||
|
def test_save_trims_to_max(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
|
||||||
|
msgs = [{"role": "user", "content": str(i)} for i in range(10)]
|
||||||
|
_save_history(1, msgs)
|
||||||
|
loaded = _load_history(1)
|
||||||
|
assert len(loaded) == 3
|
||||||
|
# Should keep the most recent messages
|
||||||
|
assert loaded[-1]["content"] == "9"
|
||||||
|
|
||||||
|
def test_overwrite(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
_save_history(1, [{"role": "user", "content": "first"}])
|
||||||
|
_save_history(1, [{"role": "user", "content": "second"}])
|
||||||
|
assert _load_history(1)[0]["content"] == "second"
|
||||||
|
|
||||||
|
def test_separate_chats(self, tmp_path, monkeypatch):
|
||||||
|
db_path = tmp_path / "tg.db"
|
||||||
|
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
|
||||||
|
_save_history(1, [{"role": "user", "content": "chat1"}])
|
||||||
|
_save_history(2, [{"role": "user", "content": "chat2"}])
|
||||||
|
assert _load_history(1)[0]["content"] == "chat1"
|
||||||
|
assert _load_history(2)[0]["content"] == "chat2"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPluginLifecycle:
|
||||||
|
def test_on_load_stores_vault_reader(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
reader = MagicMock(return_value=None)
|
||||||
|
plugin.on_load(reader)
|
||||||
|
assert plugin._vault_reader is reader
|
||||||
|
|
||||||
|
def test_daemon_tasks_returns_coroutine(self):
|
||||||
|
import inspect
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
plugin.on_load(MagicMock(return_value=None))
|
||||||
|
tasks = plugin.daemon_tasks()
|
||||||
|
assert len(tasks) == 1
|
||||||
|
assert inspect.iscoroutine(tasks[0])
|
||||||
|
tasks[0].close() # prevent "coroutine never awaited" warning
|
||||||
|
|
||||||
|
def test_get_plugin_factory(self):
|
||||||
|
plugin = _mod.get_plugin()
|
||||||
|
assert isinstance(plugin, TelegramBotPlugin)
|
||||||
|
assert plugin.name == "telegram_bot"
|
||||||
|
|
||||||
|
def test_config_fields(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
fields = plugin.config_fields()
|
||||||
|
assert any(f.key == "rate_limit" for f in fields)
|
||||||
|
|
||||||
|
def _patch_setup(self, token, allowed, pass1, pass2):
|
||||||
|
"""Return a context manager that patches all questionary calls used by setup()."""
|
||||||
|
pw_answers = iter([token, pass1, pass2])
|
||||||
|
return (
|
||||||
|
patch("questionary.password",
|
||||||
|
side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(pw_answers))),
|
||||||
|
patch("questionary.text",
|
||||||
|
return_value=MagicMock(ask=lambda: allowed)),
|
||||||
|
patch("questionary.press_any_key_to_continue",
|
||||||
|
return_value=MagicMock(ask=lambda: None)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_setup_mismatched_passphrase(self):
|
||||||
|
"""setup() writes nothing to the vault when passphrases don't match."""
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
console = MagicMock()
|
||||||
|
vault_writer = MagicMock()
|
||||||
|
|
||||||
|
pw_patch, text_patch, pakc_patch = self._patch_setup(
|
||||||
|
"fake-token", "123456789", "pass1", "pass2"
|
||||||
|
)
|
||||||
|
with pw_patch, text_patch, pakc_patch:
|
||||||
|
plugin.setup(console, vault_writer)
|
||||||
|
|
||||||
|
vault_writer.assert_not_called()
|
||||||
|
console.print.assert_called_with(
|
||||||
|
"[red]Passphrases do not match. Run setup again to retry.[/red]"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_setup_happy_path(self):
|
||||||
|
"""setup() writes all three vault keys when credentials are valid."""
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
console = MagicMock()
|
||||||
|
vault_writer = MagicMock()
|
||||||
|
|
||||||
|
pw_patch, text_patch, pakc_patch = self._patch_setup(
|
||||||
|
"real-token", "111222333", "s3cr3t", "s3cr3t"
|
||||||
|
)
|
||||||
|
with pw_patch, text_patch, pakc_patch:
|
||||||
|
plugin.setup(console, vault_writer)
|
||||||
|
|
||||||
|
calls = {call[0][0]: call[0][1] for call in vault_writer.call_args_list}
|
||||||
|
assert calls.get("plugin:telegram_bot:token") == "real-token"
|
||||||
|
assert calls.get("plugin:telegram_bot:allowed_users") == "111222333"
|
||||||
|
assert "plugin:telegram_bot:passphrase_hash" in calls
|
||||||
|
|
||||||
|
def test_setup_cancelled_on_empty_token(self):
|
||||||
|
"""setup() exits without writing if the token prompt is cancelled."""
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
console = MagicMock()
|
||||||
|
vault_writer = MagicMock()
|
||||||
|
|
||||||
|
with patch("questionary.password", return_value=MagicMock(ask=lambda: None)), \
|
||||||
|
patch("questionary.press_any_key_to_continue",
|
||||||
|
return_value=MagicMock(ask=lambda: None)):
|
||||||
|
plugin.setup(console, vault_writer)
|
||||||
|
|
||||||
|
vault_writer.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth session state ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestAuthState:
|
||||||
|
def test_session_defaults(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
session = plugin._sessions.setdefault(
|
||||||
|
99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
|
||||||
|
)
|
||||||
|
assert not session["authenticated"]
|
||||||
|
assert not session["awaiting_passphrase"]
|
||||||
|
assert session["attempts"] == 0
|
||||||
|
|
||||||
|
def test_session_per_chat(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
|
||||||
|
plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
|
||||||
|
assert plugin._sessions[1]["authenticated"]
|
||||||
|
assert not plugin._sessions[2]["authenticated"]
|
||||||
|
assert plugin._sessions[2]["attempts"] == 1
|
||||||
|
|
||||||
|
def test_session_auth_flag(self):
|
||||||
|
plugin = TelegramBotPlugin()
|
||||||
|
plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
|
||||||
|
# Simulate successful auth
|
||||||
|
plugin._sessions[5]["authenticated"] = True
|
||||||
|
plugin._sessions[5]["awaiting_passphrase"] = False
|
||||||
|
assert plugin._sessions[5]["authenticated"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool argument parsing ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestToolArgParsing:
|
||||||
|
def test_valid_json_string(self):
|
||||||
|
args_raw = '{"query": "hello"}'
|
||||||
|
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||||
|
assert args == {"query": "hello"}
|
||||||
|
|
||||||
|
def test_dict_passthrough(self):
|
||||||
|
args_raw = {"query": "hello"}
|
||||||
|
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||||
|
assert args == {"query": "hello"}
|
||||||
|
|
||||||
|
def test_invalid_json_caught(self):
|
||||||
|
args_raw = "not json"
|
||||||
|
try:
|
||||||
|
json.loads(args_raw)
|
||||||
|
assert False, "Should have raised"
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_result_truncated_to_4000(self):
|
||||||
|
long_result = "x" * 5000
|
||||||
|
assert len(long_result[:4000]) == 4000
|
||||||
Reference in New Issue
Block a user