23 Commits

Author SHA1 Message Date
curo1305 aba28293b7 feat(cli,wizard): auto-offer telegram_bot setup after install; integrate into pyra setup
cli.py — plugin_install() now asks "Configure now?" after a successful install,
runs the plugin's setup wizard, and offers to enable inline. Failing to install
short-circuits before the prompt is shown.

wizard.py — _offer_telegram_setup_if_selected() runs install + wizard + enable
automatically at the end of pyra setup when the user selected "Communication bots".
Adds load_config import (was missing alongside save_config).

Tests: test_plugin_install_decline_setup, test_plugin_install_error_does_not_prompt.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:15:29 +02:00
curo1305 f59aa1a758 feat(plugin/telegram_bot): replace bare prompts with 5-step guided setup wizard
Step 1 — create bot via @BotFather (instructions + press-any-key pause)
Step 2 — find Telegram user ID via @userinfobot (instructions + pause)
Step 3 — set session passphrase with security explanation
Step 4 — save all three vault keys, print ✓ confirmations
Step 5 — configuration complete marker

Adds setup cancellation on empty token, updated tests: happy path, mismatch,
and cancel all covered; press_any_key_to_continue calls properly patched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:15:22 +02:00
curo1305 3f30b782d2 feat(plugin/telegram_bot): add Telegram bot plugin — remote Pyra chat over Telegram
Runs as a supervised daemon task (asyncio long-polling). Features:
- Sender allowlist + bcrypt passphrase challenge (3 attempts before lockout)
- Rate limiting: 20 messages/hour per user (in-memory, resets on restart)
- Injection scanning on every incoming message (pyra.security.injection)
- Full AI chat with litellm streaming → progressive Telegram message editing
- Tool-use loop (up to 10 iterations) with inline-button approval (120s timeout)
- Conversation history persisted per chat_id in ~/.pyra/telegram_history.db
- Memory context loaded from ~/.pyra/memory/ as system prompt on first message

Vault keys: plugin:telegram_bot:token, allowed_users, passphrase_hash
Deps: python-telegram-bot>=21.0, bcrypt>=4.0.0 (added to telegram + all-plugins extras)

22 new unit tests covering rate limiter, history DB, plugin lifecycle, and auth state.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 22:45:36 +02:00
curo1305 bde0856979 feat(daemon): Stage 6 daemon infrastructure
Always-on asyncio daemon with IPC socket, OS service install/uninstall
(launchd/systemd/schtasks), and plugin task supervisor.

- daemon/pid.py: atomic PID file, stale detection (POSIX + Windows)
- daemon/ipc.py: Unix socket (chmod 600, UID-checked) on Linux/macOS;
  TCP loopback + port file on Windows; newline-delimited JSON protocol
- daemon/service.py: launchd plist, systemd user unit, schtasks XML;
  auto-detects platform; finds pyra executable via shutil.which
- daemon/core.py: asyncio event loop, PluginSupervisor (per-task
  restart up to 10x with 5s back-off, reload), IPC command dispatch,
  SIGTERM/SIGHUP signal handling via get_running_loop()
- cli.py: all 7 daemon stubs replaced with real commands
- 376 tests passing (13 new supervisor + IPC handler tests)
2026-05-19 16:14:51 +02:00
curo1305 4744cf819b docs: update CLAUDE.md for Stage 6 daemon infrastructure
- Current Status: add Stage 6 daemon infrastructure in progress
- Architecture table: expand daemon/__init__.py stub to all 5 daemon modules
- Code Inventory: add daemon.core, daemon.pid, daemon.ipc, daemon.service
  sections with function signatures and purposes
- Internal classes: add PluginSupervisor and PidFile; expand DaemonConfig

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 16:10:49 +02:00
curo1305 1d5d0387d9 test(daemon): add supervisor and IPC handler tests
13 async tests covering: supervisor lifecycle (start/stop), task
completion, crash-and-restart, max-restart enforcement, status shape,
reload (task restart + counter reset), and IPC handler dispatch for all
4 commands plus unknown commands.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:57:49 +02:00
curo1305 db6ca6ee57 feat(daemon): implement reload, fix PID race condition
- PluginSupervisor.reload(): cancels all running plugin tasks, resets
  restart counters, and re-creates them with fresh coroutines
- IPC reload command now calls supervisor.reload() instead of being a stub
- run_foreground(): wrap PID file acquisition in try/except PidFileError
  to produce a clean error if two daemon starts race on the PID file

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:54:56 +02:00
curo1305 cc24257ab0 fix(daemon): close discarded coroutines in get_daemon_task_factories
The initial daemon_tasks() call to count tasks created coroutines that
were immediately discarded, triggering RuntimeWarning "coroutine never
awaited". Explicitly close them after counting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:53:06 +02:00
curo1305 68f9007ef0 fix(daemon): install signal handlers inside running event loop
_install_signal_handlers() was called before asyncio.run(), registering
handlers on a throwaway loop that asyncio.get_event_loop() created — so
SIGTERM would never reach the supervisor. Move the call into _run_daemon()
and switch to asyncio.get_running_loop() so handlers are registered on the
actual running loop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:52:11 +02:00
curo1305 c41ad0afc6 feat(daemon): wire up all 7 daemon CLI commands
start/run/stop/status/restart/install/uninstall now call the real daemon
modules instead of printing stub messages. Includes a Rich status table
for `pyra daemon status` and friendly error messages when config is missing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:38:50 +02:00
curo1305 3d3ce694b9 feat(daemon): add core asyncio daemon, supervisor, and registry factories
- core.py: asyncio event loop entry point, PluginSupervisor with per-task
  restart (up to 10 times, 5s back-off), IPC dispatch, signal handling
  (SIGTERM/SIGHUP on POSIX), RotatingFileHandler, start_background() helper
- daemon/__init__.py: export public API
- plugins/registry.py: add get_daemon_task_factories() so supervisor can
  restart crashed tasks by re-calling plugin.daemon_tasks()[i]
- config/schema.py: add DaemonConfig.ipc_port for Windows TCP loopback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:33:57 +02:00
curo1305 d42b8b4a47 feat(daemon): add IPC transport module
Newline-delimited JSON over Unix socket (macOS/Linux, chmod 600, UID-checked
via SO_PEERCRED/getpeereid) with TCP loopback fallback on Windows. Port written
to ~/.pyra/daemon.port for Windows clients. Sync send_command() wrapper for CLI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:26:25 +02:00
curo1305 513871ef96 feat(daemon): add OS service install/uninstall module
Generates launchd plist (macOS), systemd user unit (Linux), and Task
Scheduler XML (Windows). Auto-detects platform; finds pyra executable
via shutil.which with venv-sibling fallback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:23:11 +02:00
curo1305 eaed52006f feat(daemon): add PID file management module
Atomic write-then-rename, stale-PID detection via os.kill on POSIX and
ctypes.OpenProcess on Windows, context manager for cleanup on exit.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:22:04 +02:00
curo1305 0e052c4992 fix(setup): correct LM Studio loaded state value to "loaded" not "loaded_instance"
Querying the live /api/v0/models endpoint shows LM Studio uses state="loaded"
for in-memory models (not "loaded_instance"), so the filter never matched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:00:01 +02:00
curo1305 40aa934431 fix(setup): filter LM Studio models by state == "loaded_instance"
LM Studio's /v1/models returns all downloaded models, not just loaded
ones. Use /api/v0/models with state filtering in both fetch_loaded_models()
and _fetch_local_models() so only RAM-resident models are shown as loaded.
This also restores the _choose_model() fallback that offers downloaded-but-
unloaded models when nothing is active in LM Studio.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 14:54:06 +02:00
curo1305 833d1445f0 feat(setup,chat): detect actually-loaded local model via provider-specific API
For Ollama, /api/tags returns all installed models, not running ones.
Add fetch_loaded_models() using /api/ps for Ollama (and /v1/models for
LM Studio/llama.cpp, which already return only loaded models).

_show_local_model_status() now calls fetch_loaded_models() so the
setup wizard correctly shows only in-memory models for Ollama.

At chat session startup, local providers warn when the configured model
is not currently loaded, or when nothing is loaded at all.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 14:46:54 +02:00
curo1305 cb390ad6af docs: update README and CLAUDE.md to reflect current state
Add daemon subcommands to README command table (Stage 6 stubs), add
Multi-step Planning section, add chat/planner.py to CLAUDE.md
architecture table, add TaskPlanner to internal classes inventory,
and remove stale test count.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 14:35:44 +02:00
curo1305 01655124b5 Update description 2026-05-19 14:28:37 +02:00
curo1305 b3851a2715 test: add tests for draft persistence, model status, and model re-entry
10 new tests covering:
- _save_draft / _load_draft / _delete_draft / _mark_step_done helpers
- draft file permissions (chmod 600)
- _show_local_model_status with zero, one, and multiple loaded models
- _test_connection change_model path (model error → change → retry succeeds)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:43:53 +02:00
curo1305 019e8044a9 feat(setup): model re-entry, status indicator, and resumable setup wizard
- _test_connection() now returns the (possibly changed) model name and
  offers a "Change model" option when the error is model-related
- _show_local_model_status() prints which models are currently loaded
  immediately after selecting a local provider
- Draft persistence: each completed wizard step is saved to
  ~/.pyra/setup.draft.json (chmod 600); on the next run a yellow panel
  summarises progress and offers [Resume / Start fresh]; draft is
  deleted on successful completion or Ctrl-C with no completed steps

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:43:49 +02:00
curo1305 efc589cc56 test: add tests for _classify_error and _check_local_server retry behaviour
_classify_error: covers all litellm error types (auth, not-found, rate-limit,
service-unavailable, connection, timeout, bad-request), httpx connect and
timeout errors, and generic fallback — using dynamically constructed fake
exception classes to avoid importing litellm in tests.

_check_local_server: covers success, retry-then-success, abort (SystemExit),
and continue-anyway paths via monkeypatched httpx and questionary.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:27:04 +02:00
curo1305 9a392410e7 feat(setup): error controller with retry loops in setup wizard
Add _classify_error() that maps litellm/httpx exceptions to human-readable
labels and resolution hints without requiring a top-level litellm import.

_check_local_server() now loops with Retry / Continue anyway / Abort instead
of printing a one-shot warning and silently continuing.

_test_connection() now loops with Retry / Re-enter API key (auth errors only) /
Skip test / Abort instead of printing the raw exception string and falling through.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:26:58 +02:00
22 changed files with 3637 additions and 88 deletions
+49 -6
View File
@@ -8,7 +8,8 @@ a plugin/integration system (Stage 2+) and an encrypted vault (Stage 3+).
## Current Status ## Current Status
**Stage 3 — Memory Database: complete** (2026-05-18) **Stage 3 — Memory Database: complete** (2026-05-18)
Next: Stage 4Vault Encryption **Stage 6Daemon infrastructure: in progress** (`feat/daemon` branch)
Next: Stage 4 — Vault Encryption (skipped for now); messaging bots (Stage 6 remainder)
## Project Roadmap ## Project Roadmap
@@ -19,11 +20,11 @@ memory in `~/.pyra/memory/`, and hard security boundaries around the vault.
### Stage 2 — Plugin Framework ✅ COMPLETE ### Stage 2 — Plugin Framework ✅ COMPLETE
- `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py` - `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py`
- `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra - `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra
- `src/pyra/daemon/` stub (CLI surface only) - `src/pyra/daemon/` stub (CLI surface only; daemon itself is Stage 6)
- Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig` - Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig`
- Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup - Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup
- Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands - Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands
- CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` stubs - CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` (stubs at Stage 2; implemented in Stage 6)
### Stage 3 — Memory Database ✅ COMPLETE ### Stage 3 — Memory Database ✅ COMPLETE
- `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables - `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables
@@ -99,6 +100,7 @@ the vault under namespaced keys (`plugin:{name}:{key}`).
| `config/manager.py` | ruamel.yaml round-trip config read/write, chmod 600 enforced | | `config/manager.py` | ruamel.yaml round-trip config read/write, chmod 600 enforced |
| `config/dirs.py` | `bootstrap()` — creates `~/.pyra/` tree, checks vault sentinel every startup | | `config/dirs.py` | `bootstrap()` — creates `~/.pyra/` tree, checks vault sentinel every startup |
| `chat/session.py` | prompt_toolkit REPL loop, AI tool-use loop, plugin slash commands | | `chat/session.py` | prompt_toolkit REPL loop, AI tool-use loop, plugin slash commands |
| `chat/planner.py` | `TaskPlanner` — multi-step plan approval loop, per-step AI execution and verification |
| `chat/renderer.py` | Streaming + non-streaming markdown via rich, injection warning panel | | `chat/renderer.py` | Streaming + non-streaming markdown via rich, injection warning panel |
| `chat/history.py` | Conversation list, token budget trimming, tool message support | | `chat/history.py` | Conversation list, token budget trimming, tool message support |
| `memory/database.py` | SQLite+FTS5 — `init_db()`, `upsert()`, `remove()`, `search()`, `list_all()`, `migrate_from_files()` | | `memory/database.py` | SQLite+FTS5 — `init_db()`, `upsert()`, `remove()`, `search()`, `list_all()`, `migrate_from_files()` |
@@ -116,7 +118,11 @@ the vault under namespaced keys (`plugin:{name}:{key}`).
| `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log | | `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log |
| `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` | | `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` |
| `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) | | `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) |
| `daemon/__init__.py` | Daemon package stub (implementation in Stage 2.4) | | `daemon/pid.py` | Atomic PID file — write, read, stale detection (POSIX + Windows), context manager |
| `daemon/ipc.py` | IPC transport — Unix socket chmod 600 + UID-check (Linux/macOS) or TCP loopback + port file (Windows); newline-delimited JSON protocol |
| `daemon/service.py` | OS service file generation + install/uninstall — launchd plist (macOS), systemd user unit (Linux), schtasks XML (Windows) |
| `daemon/core.py` | asyncio event loop entry point, `PluginSupervisor` (per-task restart, max 10×, 5s back-off, reload), IPC command dispatch, signal handling |
| `daemon/__init__.py` | Public daemon API exports |
### Runtime: `~/.pyra/` ### Runtime: `~/.pyra/`
@@ -243,7 +249,7 @@ uv pip install -e ".[all-plugins]" # Everything
## Running Tests ## Running Tests
```bash ```bash
pytest tests/ -v # all unit + security tests (161 tests) pytest tests/ -v # all unit + security tests
pytest tests/integration/test_lmstudio.py # requires LM Studio at localhost:1234 pytest tests/integration/test_lmstudio.py # requires LM Studio at localhost:1234
``` ```
@@ -492,6 +498,40 @@ Dataclass: `MemoryFile(name, path, category, size_bytes, modified)`
| `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` | | `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` |
| `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing | | `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing |
#### `daemon.core`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `run_foreground` | `() -> None` | Entry point for `pyra daemon run` — loads config + plugins, writes PID file, runs asyncio loop |
| `start_background` | `() -> None` | Spawns `pyra daemon run` as a detached subprocess (`start_new_session` on POSIX, `DETACHED_PROCESS` on Windows) |
#### `daemon.pid`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `resolve_pid_path` | `(cfg_path: str) -> Path` | Expand `~` and resolve to absolute Path |
#### `daemon.ipc`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `send_command` | `(address, msg, timeout=5.0) -> IpcResponse` | Synchronous CLI helper — `asyncio.run(IpcClient.send(...))` |
| `get_socket_path` | `(cfg: str) -> Path` | Expand `~` and return Unix socket path |
| `is_unix_socket` | `() -> bool` | True on Linux/macOS (`sys.platform != 'nt'`) |
| `get_port_file_path` | `() -> Path` | Path to `~/.pyra/daemon.port` (Windows TCP port file) |
#### `daemon.service`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `detect_platform` | `() -> Literal["macos","linux","windows"]` | Detect current OS |
| `find_pyra_executable` | `() -> str` | `shutil.which("pyra")` → sibling fallback → `sys.executable -m pyra` |
| `install_service` | `() -> None` | Generate + register OS service (reads config for log/pid paths) |
| `uninstall_service` | `() -> None` | Deregister OS service |
| `render_launchd_plist` | `(exe, log_file, pid_file) -> str` | macOS plist template |
| `render_systemd_unit` | `(exe, log_file) -> str` | Linux systemd unit template |
| `render_schtasks_xml` | `(exe) -> str` | Windows Task Scheduler XML template (write as UTF-16) |
#### `chat.renderer` — rendering functions and shared `console` #### `chat.renderer` — rendering functions and shared `console`
Import `console` from here; do not create a second `rich.Console()` in new code. Import `console` from here; do not create a second `rich.Console()` in new code.
@@ -514,7 +554,7 @@ Import `console` from here; do not create a second `rich.Console()` in new code.
| `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` | | `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` |
| `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` | | `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` |
| `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` | | `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` |
| `DaemonConfig` | `config.schema` | `daemon:` block | | `DaemonConfig` | `config.schema` | `daemon:` block — `enabled`, `socket_path`, `log_file`, `pid_file`, `ipc_port` |
| `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` | | `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` |
| `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` | | `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` |
| `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget | | `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget |
@@ -524,3 +564,6 @@ Import `console` from here; do not create a second `rich.Console()` in new code.
| `Tool` | `plugins.base` | Dataclass — `name`, `description`, `parameters` (JSON Schema), `handler`, `requires_approval` | | `Tool` | `plugins.base` | Dataclass — `name`, `description`, `parameters` (JSON Schema), `handler`, `requires_approval` |
| `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface | | `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface |
| `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this | | `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this |
| `TaskPlanner` | `chat.planner` | Multi-step plan runner; `make_tool_handler()` returns the callable wired into the chat session; presents plan for user approval, executes each step via litellm with up to 5 tool-use iterations, verifies output before proceeding |
| `PluginSupervisor` | `daemon.core` | asyncio supervisor — `add_task(name, factory)`, `start()`, `stop()`, `reload()`, `status()`; restarts crashed tasks up to 10× with 5s back-off |
| `PidFile` | `daemon.pid` | `write()` (atomic), `read()`, `is_stale()`, `remove()`, context manager; `PidFileError(OSError)` raised when live PID already exists |
+57 -10
View File
@@ -1,7 +1,7 @@
# Pyra # Pyra
A personal AI assistant CLI with vault-first security. Combines multi-provider AI chat with A personal AI assistant CLI with vault-first security. Combines multi-provider AI chat,
long-term memory and (coming) automation skills. long-term memory, and an extensible plugin system.
## Quick Start ## Quick Start
@@ -31,6 +31,17 @@ pyra chat # start talking
| `pyra memory read <name>` | Read a memory file | | `pyra memory read <name>` | Read a memory file |
| `pyra memory write <name> <content>` | Write a memory file | | `pyra memory write <name> <content>` | Write a memory file |
| `pyra memory append <name> <content>` | Append to a memory file | | `pyra memory append <name> <content>` | Append to a memory file |
| `pyra plugin list` | List installed and available plugins |
| `pyra plugin install <name>` | Install a bundled plugin |
| `pyra plugin enable <name>` | Enable an installed plugin |
| `pyra plugin disable <name>` | Disable a plugin (keeps it installed) |
| `pyra plugin setup <name>` | Run a plugin's credential setup wizard |
| `pyra daemon start` | Start the background daemon *(Stage 6, not yet implemented)* |
| `pyra daemon stop` | Stop the running daemon *(Stage 6, not yet implemented)* |
| `pyra daemon status` | Show daemon status *(Stage 6, not yet implemented)* |
| `pyra daemon restart` | Restart the daemon *(Stage 6, not yet implemented)* |
| `pyra daemon install` | Register Pyra as a system service *(Stage 6, not yet implemented)* |
| `pyra daemon uninstall` | Remove the system service *(Stage 6, not yet implemented)* |
### In-chat slash commands ### In-chat slash commands
@@ -38,6 +49,7 @@ pyra chat # start talking
|---------|-------------| |---------|-------------|
| `/help` | Show available commands | | `/help` | Show available commands |
| `/memory list` | List memory files | | `/memory list` | List memory files |
| `/config` | Open the configuration TUI |
| `/clear` | Clear conversation history | | `/clear` | Clear conversation history |
| `/quit` or `/exit` | Exit Pyra | | `/quit` or `/exit` | Exit Pyra |
@@ -48,16 +60,41 @@ pyra chat # start talking
- **Prompt injection scanner** — warns on suspicious AI output, logs to `~/.pyra/security.log` - **Prompt injection scanner** — warns on suspicious AI output, logs to `~/.pyra/security.log`
- **Path sandboxing** — the AI can only reference memory files by name; traversal is blocked - **Path sandboxing** — the AI can only reference memory files by name; traversal is blocked
## Plugins
Pyra has an extensible plugin system. Bundled plugins are shipped with Pyra and installed on
demand; third-party plugins can be dropped into `~/.pyra/plugins/` directly.
Each plugin is a directory containing a `manifest.json` and a `plugin.py`. Plugin credentials
are stored in the vault under namespaced keys (`plugin:<name>:<key>`) — never in `config.yaml`.
```bash
pyra plugin list # see what's available
pyra plugin install <name> # copy a bundled plugin to ~/.pyra/plugins/
pyra plugin setup <name> # enter credentials (stored in vault)
pyra plugin enable <name> # activate for the next chat session
```
## Multi-step Planning
When given a complex task the AI can propose a **multi-step plan** using the built-in
`plan_and_execute` tool. Pyra prints the plan and asks for approval before executing
anything. Each step runs as a separate AI call with access to enabled plugin tools; each
result is verified before moving on to the next step. You can decline the plan or
interrupt at any point.
## Memory ## Memory
Pyra reads your memory files at the start of each session and injects them as context. Pyra reads your memory files at the start of each session and injects them as context.
Files are plain Markdown stored in `~/.pyra/memory/`: Files are plain Markdown stored in `~/.pyra/memory/`, indexed by a SQLite full-text search
database (`memory.db`) for fast in-chat lookup.
``` ```
~/.pyra/memory/ ~/.pyra/memory/
├── user/profile.md ← who you are ├── user/profile.md ← who you are
├── context/ ← ongoing projects ├── context/ ← ongoing projects
── knowledge/ ← general notes ── knowledge/ ← general notes
└── memory.db ← FTS5 search index (auto-managed)
``` ```
## `~/.pyra/` Directory ## `~/.pyra/` Directory
@@ -67,15 +104,25 @@ Files are plain Markdown stored in `~/.pyra/memory/`:
├── config.yaml ← provider + model (no secrets) ├── config.yaml ← provider + model (no secrets)
├── security.log ← injection event log ├── security.log ← injection event log
├── memory/ ← AI-readable long-term memory ├── memory/ ← AI-readable long-term memory
├── skills/ ← automation scripts (Stage 2) │ └── memory.db ← SQLite FTS5 search index
├── plugins/ ← installed plugins
│ └── <name>/
│ ├── manifest.json
│ └── plugin.py
├── logs/ ← execution logs
│ ├── tool_executions.log
│ └── plugin_errors.log
└── vault/ ← secure, AI-inaccessible storage └── vault/ ← secure, AI-inaccessible storage
└── secrets/api_keys.json └── secrets/api_keys.json
``` ```
## Roadmap ## Roadmap
- **Stage 1** (now): Core CLI, multi-provider chat, memory, vault security - **Stage 1** Core CLI multi-provider chat, memory, vault security
- **Stage 2**: Skills — shell/PowerShell/Python automations with user approval gates - **Stage 2** ✅ Plugin Framework — extensible tools, slash commands, approval gates
- **Stage 3**: Vault encryption with `age` - **Stage 3** ✅ Memory Database — SQLite + FTS5 full-text search index
- **Stage 4**: Security audit sub-agent - **Stage 4** Vault Encryption — `age`-based encryption of `~/.pyra/vault/secrets/`
- **Stage 5**: Web UI, embedding-based memory search - **Stage 5** Skills System — YAML-defined multi-plugin workflows with event triggers
- **Stage 6** Daemon + Messaging Bots — always-on asyncio daemon, Matrix/Telegram/Signal bots
- **Stage 7** Security Audit Sub-agent — automated scanning for injection, CVEs, permission drift
- **Stage 8** Web UI — optional local interface, embedding-based memory search
+2 -2
View File
@@ -28,7 +28,7 @@ dev = [
] ]
nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"] nextcloud = ["caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6"]
matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"] matrix = ["matrix-nio>=0.24.0", "aiofiles>=23.0.0"]
telegram = ["python-telegram-bot>=21.0"] telegram = ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
ssh = ["paramiko>=3.4.0"] ssh = ["paramiko>=3.4.0"]
docker = ["docker>=7.0.0"] docker = ["docker>=7.0.0"]
gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"] gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
@@ -38,7 +38,7 @@ daemon = ["aiofiles>=23.0.0"]
all-plugins = [ all-plugins = [
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6", "caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
"matrix-nio>=0.24.0", "aiofiles>=23.0.0", "matrix-nio>=0.24.0", "aiofiles>=23.0.0",
"python-telegram-bot>=21.0", "python-telegram-bot>=21.0", "bcrypt>=4.0.0",
"paramiko>=3.4.0", "paramiko>=3.4.0",
"docker>=7.0.0", "docker>=7.0.0",
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0", "google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
@@ -0,0 +1,7 @@
{
"name": "telegram_bot",
"version": "1.0.0",
"description": "Remote Pyra chat over Telegram — full AI chat with tool approval via inline buttons",
"author": "pyra",
"requires": ["python-telegram-bot>=21.0", "bcrypt>=4.0.0"]
}
@@ -0,0 +1,665 @@
"""Telegram bot plugin — remote Pyra chat over Telegram.
Runs as a daemon task (long-polling). Each chat session requires passphrase
authentication and is rate-limited to 20 messages/hour. Incoming messages are
injection-scanned before reaching the AI. Tool calls are approved via Telegram
inline keyboard buttons (2-minute timeout). AI responses are streamed
progressively by editing a placeholder message.
Vault keys used:
plugin:telegram_bot:token — bot token from @BotFather
plugin:telegram_bot:allowed_users — comma-separated Telegram user IDs
plugin:telegram_bot:passphrase_hash — bcrypt hash of the session passphrase
"""
from __future__ import annotations
import asyncio
import json
import logging
import sqlite3
import time
import uuid
from collections import deque
from typing import Any, Callable
import bcrypt
import litellm
from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup, Update
from telegram.ext import (
Application,
CallbackQueryHandler,
CommandHandler,
ContextTypes,
MessageHandler,
filters,
)
from pyra.plugins.base import BasePlugin, ConfigField
from pyra.utils.paths import pyra_home
_log = logging.getLogger("pyra.plugin.telegram_bot")
_HISTORY_DB = pyra_home() / "telegram_history.db"
_MAX_HISTORY = 40 # messages kept per chat
_RATE_LIMIT = 20 # messages per hour per user
_APPROVAL_TIMEOUT = 120 # seconds to wait for inline button press
_EDIT_INTERVAL = 1.5 # minimum seconds between progressive message edits
_MAX_TOOL_ITER = 10
_MAX_MSG_LEN = 4096 # Telegram hard limit
# ── SQLite history ────────────────────────────────────────────────────────────
def _open_db() -> sqlite3.Connection:
conn = sqlite3.connect(_HISTORY_DB)
conn.execute("""
CREATE TABLE IF NOT EXISTS sessions (
chat_id INTEGER PRIMARY KEY,
history TEXT NOT NULL DEFAULT '[]',
updated REAL NOT NULL
)
""")
conn.commit()
try:
import os
os.chmod(_HISTORY_DB, 0o600)
except Exception:
pass
return conn
def _load_history(chat_id: int) -> list[dict]:
conn = _open_db()
row = conn.execute(
"SELECT history FROM sessions WHERE chat_id = ?", (chat_id,)
).fetchone()
conn.close()
return json.loads(row[0]) if row else []
def _save_history(chat_id: int, messages: list[dict]) -> None:
trimmed = messages[-_MAX_HISTORY:]
conn = _open_db()
conn.execute(
"INSERT OR REPLACE INTO sessions (chat_id, history, updated) VALUES (?,?,?)",
(chat_id, json.dumps(trimmed), time.time()),
)
conn.commit()
conn.close()
# ── Rate limiter ──────────────────────────────────────────────────────────────
class _RateLimiter:
def __init__(self, per_hour: int = _RATE_LIMIT) -> None:
self._buckets: dict[int, deque] = {}
self._limit = per_hour
def allow(self, user_id: int) -> bool:
now = time.monotonic()
bucket = self._buckets.setdefault(user_id, deque())
cutoff = now - 3600
while bucket and bucket[0] < cutoff:
bucket.popleft()
if len(bucket) >= self._limit:
return False
bucket.append(now)
return True
# ── Plugin ────────────────────────────────────────────────────────────────────
class TelegramBotPlugin(BasePlugin):
name = "telegram_bot"
description = "Remote Pyra chat over Telegram (daemon task, long-polling)"
version = "1.0.0"
def __init__(self) -> None:
self._vault_reader: Callable[[str], str | None] | None = None
self._rate_limiter = _RateLimiter()
# chat_id -> {authenticated, awaiting_passphrase, attempts}
self._sessions: dict[int, dict] = {}
# short call_id -> asyncio.Future[bool]
self._pending_approvals: dict[str, asyncio.Future] = {}
# ── Plugin lifecycle ──────────────────────────────────────────────────────
def on_load(self, vault_reader: Callable[[str], str | None]) -> None:
self._vault_reader = vault_reader
def setup(self, console: Any, vault_writer: Callable[[str, str], None]) -> None:
import questionary
from rich.panel import Panel
from rich.rule import Rule
console.print()
console.print(Panel(
"[bold]Telegram Bot Setup Wizard[/bold]\n\n"
"This wizard connects Pyra to Telegram so you can chat with your\n"
"assistant from anywhere. You will need Telegram open on your phone\n"
"or desktop to complete the next steps.",
border_style="cyan",
))
# ── Step 1: Create bot ────────────────────────────────────────────────
console.print()
console.print(Rule("[bold cyan]Step 1 / 5[/bold cyan] Create your Telegram bot"))
console.print()
console.print(
" 1. Open Telegram and search for [bold]@BotFather[/bold]\n"
" 2. Send [bold]/newbot[/bold] and follow the prompts\n"
" 3. Choose a display name (e.g. [dim]My Pyra Assistant[/dim])\n"
" 4. Choose a username ending in [bold]bot[/bold] "
"(e.g. [dim]my_pyra_bot[/dim])\n"
" 5. BotFather replies with a token that looks like:\n"
" [dim]123456789:AABBccDDeeFFggHHiiJJkkLL[/dim]"
)
console.print()
questionary.press_any_key_to_continue(
" Press any key when you have your token ready ..."
).ask()
console.print()
token = questionary.password(" Bot token:").ask()
if not token or not token.strip():
console.print("[dim]Setup cancelled.[/dim]")
return
token = token.strip()
# ── Step 2: Find user ID ──────────────────────────────────────────────
console.print()
console.print(Rule("[bold cyan]Step 2 / 5[/bold cyan] Find your Telegram user ID"))
console.print()
console.print(
" Your user ID is a permanent number that identifies your account.\n"
" It never changes, even if you change your username.\n\n"
" 1. Search for [bold]@userinfobot[/bold] in Telegram\n"
" 2. Send any message (e.g. [dim]/start[/dim])\n"
" 3. Copy the [bold]Id:[/bold] number from the reply "
"(e.g. [dim]123456789[/dim])"
)
console.print()
questionary.press_any_key_to_continue(
" Press any key when you have your user ID ready ..."
).ask()
console.print()
allowed = questionary.text(
" Allowed Telegram user IDs (comma-separated, leave blank to allow anyone):"
).ask()
if allowed is None:
console.print("[dim]Setup cancelled.[/dim]")
return
# ── Step 3: Session passphrase ────────────────────────────────────────
console.print()
console.print(Rule("[bold cyan]Step 3 / 5[/bold cyan] Set a session passphrase"))
console.print()
console.print(
" The passphrase is an extra layer of security. Every new chat\n"
" session must pass this challenge before Pyra responds — even\n"
" if someone else gains access to your Telegram account."
)
console.print()
passphrase = questionary.password(" Session passphrase:").ask()
if not passphrase:
console.print("[dim]Setup cancelled.[/dim]")
return
confirm = questionary.password(" Confirm passphrase:").ask()
if passphrase != confirm:
console.print("[red]Passphrases do not match. Run setup again to retry.[/red]")
return
# ── Step 4: Save to vault ─────────────────────────────────────────────
console.print()
console.print(Rule("[bold cyan]Step 4 / 5[/bold cyan] Saving configuration"))
console.print()
pw_hash = bcrypt.hashpw(passphrase.encode(), bcrypt.gensalt()).decode()
vault_writer("plugin:telegram_bot:token", token)
vault_writer("plugin:telegram_bot:allowed_users", (allowed or "").strip())
vault_writer("plugin:telegram_bot:passphrase_hash", pw_hash)
allowed_display = (allowed or "").strip() or "[dim](any user — consider restricting)[/dim]"
console.print(f" [green]✓[/green] Bot token stored in vault")
console.print(f" [green]✓[/green] Allowed users: {allowed_display}")
console.print(f" [green]✓[/green] Passphrase stored as bcrypt hash")
# ── Step 5: Done ──────────────────────────────────────────────────────
console.print()
console.print(Rule("[bold cyan]Step 5 / 5[/bold cyan] Configuration complete"))
console.print()
def config_fields(self) -> list[ConfigField]:
return [
ConfigField(
"rate_limit",
"Rate limit (messages/hour)",
"text",
str(_RATE_LIMIT),
description="Maximum messages per hour per Telegram user",
),
]
def daemon_tasks(self) -> list:
return [self._run_polling()]
# ── Daemon task ───────────────────────────────────────────────────────────
async def _run_polling(self) -> None:
assert self._vault_reader is not None
token = self._vault_reader("plugin:telegram_bot:token")
if not token:
_log.error(
"Telegram bot token not set. Run `pyra plugin setup telegram_bot`."
)
return
passphrase_hash = self._vault_reader("plugin:telegram_bot:passphrase_hash") or ""
allowed_str = self._vault_reader("plugin:telegram_bot:allowed_users") or ""
allowed_users: set[int] = {
int(uid.strip())
for uid in allowed_str.split(",")
if uid.strip().isdigit()
}
app = Application.builder().token(token).build()
plugin = self # closure reference
async def _on_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
if update.effective_user and update.effective_user.id in allowed_users:
await update.message.reply_text(
"Pyra is online. Send any message to authenticate."
)
async def _on_message(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
await plugin._handle_message(update, ctx, allowed_users, passphrase_hash)
async def _on_approval(update: Update, ctx: ContextTypes.DEFAULT_TYPE) -> None:
await plugin._handle_approval_callback(update)
app.add_handler(CommandHandler("start", _on_start))
app.add_handler(CallbackQueryHandler(_on_approval))
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, _on_message))
_log.info("Telegram bot starting (long-polling).")
await app.initialize()
try:
await app.start()
await app.updater.start_polling(drop_pending_updates=True)
_log.info("Telegram bot is polling for updates.")
await asyncio.Event().wait() # block until CancelledError
except asyncio.CancelledError:
_log.info("Telegram bot shutting down.")
finally:
try:
await app.updater.stop()
await app.stop()
await app.shutdown()
except Exception as exc:
_log.warning("Error during Telegram bot shutdown: %s", exc)
# ── Message handler ───────────────────────────────────────────────────────
async def _handle_message(
self,
update: Update,
ctx: ContextTypes.DEFAULT_TYPE,
allowed_users: set[int],
passphrase_hash: str,
) -> None:
if update.effective_user is None or update.message is None:
return
user_id = update.effective_user.id
chat_id = update.effective_chat.id if update.effective_chat else user_id
# Allowlist — silently ignore unknown senders
if allowed_users and user_id not in allowed_users:
return
text = (update.message.text or "").strip()
if not text:
return
session = self._sessions.setdefault(
chat_id,
{"authenticated": False, "awaiting_passphrase": False, "attempts": 0},
)
# ── Passphrase authentication ─────────────────────────────────────────
if not session["authenticated"]:
if not session["awaiting_passphrase"]:
session["awaiting_passphrase"] = True
session["attempts"] = 0
await update.message.reply_text("Enter your passphrase to continue:")
return
if passphrase_hash and bcrypt.checkpw(text.encode(), passphrase_hash.encode()):
session["authenticated"] = True
session["awaiting_passphrase"] = False
session["attempts"] = 0
await update.message.reply_text(
"Authenticated. How can I help you?\n"
"Send /start at any time to check bot status."
)
else:
session["attempts"] += 1
remaining = 3 - session["attempts"]
if remaining <= 0:
session["awaiting_passphrase"] = False
await update.message.reply_text(
"Too many failed attempts. Send any message to try again."
)
else:
await update.message.reply_text(
f"Wrong passphrase. {remaining} attempt(s) left."
)
return
# ── Rate limit ────────────────────────────────────────────────────────
if not self._rate_limiter.allow(user_id):
await update.message.reply_text(
"Rate limit reached (20 messages/hour). Try again later."
)
return
# ── Injection scan ────────────────────────────────────────────────────
from pyra.security.injection import scan_response
warnings = scan_response(text)
if warnings:
labels = ", ".join(w.pattern_label for w in warnings)
_log.warning("Injection in Telegram message (user %d): %s", user_id, labels)
await update.message.reply_text(
"Your message was blocked: injection pattern detected."
)
return
# ── Load history + system context ─────────────────────────────────────
history = _load_history(chat_id)
if not history:
try:
from pyra.memory.reader import load_context_for_session
ctx_text = load_context_for_session()
if ctx_text:
history = [{"role": "system", "content": ctx_text}]
except Exception:
pass
history.append({"role": "user", "content": text})
try:
await ctx.bot.send_chat_action(chat_id=chat_id, action="typing")
except Exception:
pass
# ── AI response ───────────────────────────────────────────────────────
try:
reply = await self._ai_chat(chat_id, history, ctx.bot)
except Exception as exc:
_log.error("AI error (chat %d): %s", chat_id, exc, exc_info=True)
await ctx.bot.send_message(chat_id=chat_id, text=f"AI error: {exc}")
return
history.append({"role": "assistant", "content": reply})
_save_history(chat_id, history)
# ── AI streaming + tool-use loop ──────────────────────────────────────────
async def _ai_chat(self, chat_id: int, messages: list[dict], bot: Bot) -> str:
from pyra.config.manager import load_config
from pyra.plugins.registry import PluginRegistry
from pyra.setup.providers import get_provider
from pyra.vault.reader import get_key
cfg = load_config()
provider = get_provider(cfg.ai.provider_id)
api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local"
call_kwargs: dict[str, Any] = {
"model": f"{provider.litellm_prefix}{cfg.ai.model}",
"api_key": api_key,
}
base_url = cfg.ai.base_url or provider.base_url
if base_url:
call_kwargs["api_base"] = base_url
litellm.suppress_debug_info = True
registry = PluginRegistry.instance()
tools_spec = [
{
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.parameters,
},
}
for t in registry.get_all_tools()
]
# Mutable state shared with helpers below
state: dict[str, Any] = {"msg_id": None, "last_edit": 0.0}
placeholder = await bot.send_message(chat_id=chat_id, text="")
state["msg_id"] = placeholder.message_id
async def _update(text: str) -> None:
if not state["msg_id"]:
return
now = time.monotonic()
if now - state["last_edit"] < _EDIT_INTERVAL:
return
try:
await bot.edit_message_text(
chat_id=chat_id,
message_id=state["msg_id"],
text=text[:_MAX_MSG_LEN],
)
state["last_edit"] = now
except Exception:
pass
async def _finalize(text: str) -> None:
if not state["msg_id"]:
return
if text:
try:
await bot.edit_message_text(
chat_id=chat_id,
message_id=state["msg_id"],
text=text[:_MAX_MSG_LEN],
)
except Exception:
pass
else:
try:
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
except Exception:
pass
state["msg_id"] = None
accumulated = ""
try:
for _iter in range(_MAX_TOOL_ITER):
tool_chunks: dict[int, dict] = {}
accumulated = ""
stream = await litellm.acompletion(
**call_kwargs,
messages=messages,
tools=tools_spec if tools_spec else None,
tool_choice="auto" if tools_spec else None,
stream=True,
)
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
accumulated += delta.content
await _update(accumulated)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_chunks:
tool_chunks[idx] = {"id": tc.id or "", "name": "", "args": ""}
if tc.function:
if tc.function.name:
tool_chunks[idx]["name"] += tc.function.name
if tc.function.arguments:
tool_chunks[idx]["args"] += tc.function.arguments
if not tool_chunks:
await _finalize(accumulated)
return accumulated
# Show any intermediate prose before tool calls
if accumulated:
await _finalize(accumulated)
else:
try:
await bot.delete_message(chat_id=chat_id, message_id=state["msg_id"])
except Exception:
pass
state["msg_id"] = None
tool_calls_list = [
{
"id": data["id"],
"type": "function",
"function": {"name": data["name"], "arguments": data["args"]},
}
for _, data in sorted(tool_chunks.items())
]
messages.append({
"role": "assistant",
"content": accumulated or None,
"tool_calls": tool_calls_list,
})
for tc in tool_calls_list:
result = await self._execute_tool_with_approval(tc, chat_id, bot)
messages.append({
"role": "tool",
"tool_call_id": tc["id"],
"content": result,
})
# New placeholder for the next AI response
new_ph = await bot.send_message(chat_id=chat_id, text="")
state["msg_id"] = new_ph.message_id
state["last_edit"] = 0.0
except litellm.BadRequestError:
# Provider doesn't support tool calls — retry without tools
accumulated = ""
state["last_edit"] = 0.0
stream = await litellm.acompletion(
**call_kwargs, messages=messages, stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
accumulated += chunk.choices[0].delta.content
await _update(accumulated)
await _finalize(accumulated)
return accumulated
return accumulated or "Error: tool-use loop exceeded maximum iterations."
# ── Tool approval via inline buttons ──────────────────────────────────────
async def _execute_tool_with_approval(
self, tool_call: dict, chat_id: int, bot: Bot
) -> str:
from pyra.plugins.registry import PluginRegistry
from pyra.security.injection import scan_response
tool_name = tool_call["function"]["name"]
args_raw = tool_call["function"]["arguments"]
try:
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
except json.JSONDecodeError:
return f"Error: invalid tool arguments for {tool_name}"
args_preview = json.dumps(args, indent=2)[:500]
call_id = uuid.uuid4().hex[:8]
keyboard = InlineKeyboardMarkup([[
InlineKeyboardButton("✅ Approve", callback_data=f"approve:{call_id}"),
InlineKeyboardButton("❌ Deny", callback_data=f"deny:{call_id}"),
]])
await bot.send_message(
chat_id=chat_id,
text=f"Tool request: {tool_name}\n\n{args_preview}",
reply_markup=keyboard,
)
loop = asyncio.get_running_loop()
future: asyncio.Future[bool] = loop.create_future()
self._pending_approvals[call_id] = future
try:
approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT)
except asyncio.TimeoutError:
self._pending_approvals.pop(call_id, None)
await bot.send_message(
chat_id=chat_id, text=f"Tool {tool_name}: approval timed out — denied."
)
return "Tool execution denied (timeout)."
if not approved:
return "Tool execution denied by user."
registry = PluginRegistry.instance()
tool = registry.find_tool(tool_name)
if tool is None:
return f"Error: tool '{tool_name}' not found in registry."
try:
result = tool.handler(**args)
if not isinstance(result, str):
result = str(result)
except Exception as exc:
return f"Tool error: {exc}"
injection_warnings = scan_response(result)
if injection_warnings:
labels = ", ".join(w.pattern_label for w in injection_warnings)
await bot.send_message(
chat_id=chat_id,
text=f"Warning: tool result contains suspicious content ({labels}).",
)
return result[:4000]
# ── Approval callback ─────────────────────────────────────────────────────
async def _handle_approval_callback(self, update: Update) -> None:
query = update.callback_query
if query is None:
return
await query.answer()
data = query.data or ""
if ":" not in data:
return
action, call_id = data.split(":", 1)
future = self._pending_approvals.pop(call_id, None)
if future and not future.done():
future.set_result(action == "approve")
try:
await query.edit_message_reply_markup(reply_markup=None)
except Exception:
pass
def get_plugin() -> TelegramBotPlugin:
return TelegramBotPlugin()
+11
View File
@@ -25,6 +25,7 @@ from pyra.plugins.executor import ToolExecutor
from pyra.plugins.registry import PluginRegistry from pyra.plugins.registry import PluginRegistry
from pyra.security.injection import scan_response from pyra.security.injection import scan_response
from pyra.setup.providers import get_provider from pyra.setup.providers import get_provider
from pyra.setup.wizard import fetch_loaded_models
from pyra.utils.paths import pyra_home from pyra.utils.paths import pyra_home
_HISTORY_FILE = pyra_home() / ".chat_history" _HISTORY_FILE = pyra_home() / ".chat_history"
@@ -176,6 +177,16 @@ def start_chat() -> None:
"[dim]Type /help for commands, /quit to exit.[/dim]" "[dim]Type /help for commands, /quit to exit.[/dim]"
) )
if provider.group == "Local":
loaded = fetch_loaded_models(provider)
if not loaded:
render_info(f"No model currently loaded in {provider.display_name}.")
elif cfg.ai.model not in loaded:
render_info(
f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
f"Loaded: {', '.join(loaded)}"
)
_flags: dict = {"use_tools": True} _flags: dict = {"use_tools": True}
while True: while True:
+158 -13
View File
@@ -165,6 +165,7 @@ def plugin_list() -> None:
@click.argument("name") @click.argument("name")
def plugin_install(name: str) -> None: def plugin_install(name: str) -> None:
"""Install a bundled plugin to ~/.pyra/plugins/.""" """Install a bundled plugin to ~/.pyra/plugins/."""
import questionary
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
from pyra.utils.paths import pyra_home from pyra.utils.paths import pyra_home
@@ -173,12 +174,55 @@ def plugin_install(name: str) -> None:
try: try:
install_bundled_plugin(name, bundled_dir, plugins_dir) install_bundled_plugin(name, bundled_dir, plugins_dir)
console.print(f"[green]Installed:[/green] {name}") console.print(f"[green]Installed:[/green] {name}")
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
console.print(f" Confirm: [dim]pyra plugin setup {name}[/dim]")
except FileNotFoundError as exc: except FileNotFoundError as exc:
console.print(f"[red]Error:[/red] {exc}") console.print(f"[red]Error:[/red] {exc}")
return
except Exception as exc: except Exception as exc:
console.print(f"[red]Install failed:[/red] {exc}") console.print(f"[red]Install failed:[/red] {exc}")
return
try:
configure_now = questionary.confirm(f"Configure {name} now?", default=True).ask()
except (KeyboardInterrupt, EOFError):
return
if not configure_now:
console.print(f" Configure later: [dim]pyra plugin setup {name}[/dim]")
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
return
from pyra.plugins.loader import load_plugin_by_name
from pyra.vault.writer import set_key as _set_key
p = load_plugin_by_name(name, plugins_dir)
if p is None:
console.print(f"[red]Could not load {name} for setup.[/red]")
return
try:
p.setup(console, _set_key)
except (KeyboardInterrupt, EOFError):
console.print("[dim]Setup cancelled.[/dim]")
return
except Exception as exc:
console.print(f"[red]Setup error:[/red] {exc}")
return
try:
enable_now = questionary.confirm(f"Enable {name} now?", default=True).ask()
except (KeyboardInterrupt, EOFError):
enable_now = False
if enable_now:
from pyra.config.manager import load_config, save_config
try:
cfg = load_config()
if name not in cfg.plugins.enabled:
cfg.plugins.enabled.append(name)
save_config(cfg)
console.print(f"[green]Enabled:[/green] {name}")
except Exception as exc:
console.print(f"[yellow]Could not enable automatically:[/yellow] {exc}")
console.print(f" Enable manually: [dim]pyra plugin enable {name}[/dim]")
console.print(f"[dim]Run [bold]pyra daemon start[/bold] to bring {name} online.[/dim]")
@plugin.command("enable") @plugin.command("enable")
@@ -266,43 +310,144 @@ def daemon() -> None:
_bootstrap_or_exit() _bootstrap_or_exit()
@daemon.command("run", hidden=True)
def daemon_run() -> None:
"""Run daemon in foreground (used by service manager)."""
from pyra.daemon.core import run_foreground
run_foreground()
@daemon.command("start") @daemon.command("start")
def daemon_start() -> None: def daemon_start() -> None:
"""Start the Pyra daemon in the background.""" """Start the Pyra daemon in the background."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.core import start_background
try:
start_background()
except FileNotFoundError:
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
@daemon.command("stop") @daemon.command("stop")
def daemon_stop() -> None: def daemon_stop() -> None:
"""Stop the running Pyra daemon.""" """Stop the running Pyra daemon."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") _daemon_ipc("stop", success_msg="Daemon stopped.")
@daemon.command("status") @daemon.command("status")
def daemon_status() -> None: def daemon_status() -> None:
"""Show daemon status.""" """Show daemon status."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") _daemon_ipc("status")
@daemon.command("restart") @daemon.command("restart")
def daemon_restart() -> None: def daemon_restart() -> None:
"""Restart the Pyra daemon.""" """Restart the Pyra daemon."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") import time
from pyra.daemon.core import start_background
_daemon_ipc("stop", success_msg=None)
time.sleep(1.5)
try:
start_background()
except FileNotFoundError:
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
@daemon.command("install") @daemon.command("install")
def daemon_install() -> None: def daemon_install() -> None:
"""Install Pyra as a system service (launchd/systemd).""" """Install Pyra as a system service (launchd/systemd/schtasks)."""
console.print("[yellow]Daemon service install (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.service import detect_platform, install_service
try:
install_service()
console.print(f"[green]Service installed[/green] ({detect_platform()}).")
except Exception as exc:
console.print(f"[red]Install failed:[/red] {exc}")
@daemon.command("uninstall") @daemon.command("uninstall")
def daemon_uninstall() -> None: def daemon_uninstall() -> None:
"""Remove the Pyra system service.""" """Remove the Pyra system service."""
console.print("[yellow]Daemon service uninstall (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.service import uninstall_service
try:
uninstall_service()
console.print("[green]Service removed.[/green]")
except Exception as exc:
console.print(f"[red]Uninstall failed:[/red] {exc}")
@daemon.command("run", hidden=True) def _daemon_ipc(cmd: str, *, success_msg: str | None = None) -> None:
def daemon_run() -> None: """Send a command to the running daemon via IPC and render the response."""
"""Run daemon in foreground (used by service manager).""" from pyra.config.manager import load_config
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.ipc import (
get_socket_path,
is_unix_socket,
get_port_file_path,
send_command,
)
try:
cfg = load_config()
except FileNotFoundError:
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
return
if is_unix_socket():
address = get_socket_path(cfg.daemon.socket_path)
else:
port = _read_windows_port()
if port is None:
console.print("[yellow]Daemon is not running.[/yellow]")
return
address = ("127.0.0.1", port)
try:
resp = send_command(address, {"cmd": cmd})
except (ConnectionRefusedError, FileNotFoundError, OSError):
console.print("[yellow]Daemon is not running.[/yellow]")
return
except ConnectionResetError:
console.print("[red]Permission denied:[/red] daemon rejected connection.")
return
except TimeoutError:
console.print("[red]Daemon did not respond in time.[/red]")
return
if not resp.get("ok"):
console.print(f"[red]Error:[/red] {resp.get('data', {}).get('error', 'unknown')}")
return
if cmd == "status":
_render_daemon_status(resp["data"])
elif success_msg:
console.print(f"[green]{success_msg}[/green]")
def _read_windows_port() -> int | None:
from pyra.daemon.ipc import get_port_file_path
try:
return int(get_port_file_path().read_text().strip())
except (FileNotFoundError, ValueError):
return None
def _render_daemon_status(data: dict) -> None:
from rich.table import Table
uptime = data.get("uptime", 0.0)
pid = data.get("pid", "?")
tasks = data.get("tasks", [])
hours, rem = divmod(int(uptime), 3600)
mins, secs = divmod(rem, 60)
uptime_str = f"{hours}h {mins}m {secs}s" if hours else f"{mins}m {secs}s"
console.print(f"[bold green]Daemon running[/bold green] — PID {pid}, uptime {uptime_str}")
if tasks:
table = Table("Task", "Alive", "Restarts", "Last error", show_header=True)
for t in tasks:
alive = "[green]yes[/green]" if t.get("alive") else "[red]no[/red]"
error = t.get("last_error") or ""
table.add_row(t.get("name", "?"), alive, str(t.get("restart_count", 0)), error)
console.print(table)
else:
console.print("[dim]No plugin tasks registered.[/dim]")
+1
View File
@@ -36,6 +36,7 @@ class DaemonConfig(BaseModel):
socket_path: str = "~/.pyra/daemon.sock" socket_path: str = "~/.pyra/daemon.sock"
log_file: str = "~/.pyra/daemon.log" log_file: str = "~/.pyra/daemon.log"
pid_file: str = "~/.pyra/daemon.pid" pid_file: str = "~/.pyra/daemon.pid"
ipc_port: int = 0 # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port
class PyraConfig(BaseModel): class PyraConfig(BaseModel):
+21
View File
@@ -0,0 +1,21 @@
"""Pyra background daemon package."""
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
from pyra.daemon.service import detect_platform, install_service, uninstall_service
__all__ = [
"run_foreground",
"start_background",
"PluginSupervisor",
"IpcClient",
"IpcServer",
"send_command",
"PidFile",
"PidFileError",
"resolve_pid_path",
"detect_platform",
"install_service",
"uninstall_service",
]
+313
View File
@@ -0,0 +1,313 @@
"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling."""
from __future__ import annotations
import asyncio
import logging
import logging.handlers
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Coroutine
from pyra.utils.paths import pyra_home, safe_chmod
_log = logging.getLogger("pyra.daemon")
_start_time: float = 0.0
# ── Plugin task supervisor ────────────────────────────────────────────────────
@dataclass
class TaskRecord:
name: str
coro_factory: Callable[[], Coroutine] # type: ignore[type-arg]
task: asyncio.Task | None = field(default=None, repr=False)
restart_count: int = 0
last_error: str | None = None
def is_alive(self) -> bool:
return self.task is not None and not self.task.done()
class PluginSupervisor:
_RESTART_DELAY: float = 5.0
_MAX_RESTARTS: int = 10
def __init__(self) -> None:
self._records: list[TaskRecord] = []
self._shutdown = asyncio.Event()
def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
self._records.append(TaskRecord(name=name, coro_factory=factory))
async def start(self) -> None:
for record in self._records:
record.task = asyncio.create_task(
self._supervise(record), name=record.name
)
_log.info("Supervisor started with %d plugin task(s).", len(self._records))
async def run_until_shutdown(self) -> None:
await self._shutdown.wait()
_log.info("Shutdown requested — stopping supervisor.")
def request_shutdown(self) -> None:
self._shutdown.set()
async def stop(self) -> None:
for record in self._records:
if record.task and not record.task.done():
record.task.cancel()
tasks = [r.task for r in self._records if r.task]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def reload(self) -> None:
"""Cancel all running tasks and restart them with fresh coroutines."""
for record in self._records:
if record.task and not record.task.done():
record.task.cancel()
try:
await record.task
except (asyncio.CancelledError, Exception):
pass
record.restart_count = 0
record.last_error = None
record.task = asyncio.create_task(
self._supervise(record), name=record.name
)
_log.info("Reloaded %d plugin task(s).", len(self._records))
def status(self) -> list[dict]:
return [
{
"name": r.name,
"alive": r.is_alive(),
"restart_count": r.restart_count,
"last_error": r.last_error,
}
for r in self._records
]
async def _supervise(self, record: TaskRecord) -> None:
while not self._shutdown.is_set():
try:
await record.coro_factory()
_log.info("Plugin task %s completed normally.", record.name)
return
except asyncio.CancelledError:
return
except Exception as exc:
record.restart_count += 1
record.last_error = f"{type(exc).__name__}: {exc}"
_log.error(
"Plugin task %s crashed (restart #%d): %s",
record.name, record.restart_count, exc,
exc_info=True,
)
if record.restart_count >= self._MAX_RESTARTS:
_log.critical(
"Plugin task %s exceeded max restarts (%d). Giving up.",
record.name, self._MAX_RESTARTS,
)
return
try:
await asyncio.wait_for(
asyncio.sleep(self._RESTART_DELAY),
timeout=self._RESTART_DELAY + 1,
)
except asyncio.CancelledError:
return
# ── IPC command dispatch ──────────────────────────────────────────────────────
def _make_ipc_handler(supervisor: PluginSupervisor):
async def handler(msg: dict) -> dict:
cmd = msg.get("cmd", "")
match cmd:
case "ping":
return {"ok": True, "data": {"pong": True}}
case "status":
return {
"ok": True,
"data": {
"uptime": time.monotonic() - _start_time,
"pid": os.getpid(),
"tasks": supervisor.status(),
},
}
case "stop":
supervisor.request_shutdown()
return {"ok": True, "data": {}}
case "reload":
await supervisor.reload()
return {"ok": True, "data": {"tasks_reloaded": len(supervisor._records)}}
case _:
return {"ok": False, "data": {"error": f"unknown command: {cmd}"}}
return handler
# ── Main async entrypoint ─────────────────────────────────────────────────────
async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None:
from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket
# Install signal handlers now that the event loop is running.
_install_signal_handlers(supervisor)
if is_unix_socket():
address = get_socket_path(cfg.daemon.socket_path)
else:
address = ("127.0.0.1", cfg.daemon.ipc_port)
server = IpcServer(address, _make_ipc_handler(supervisor))
await supervisor.start()
async with asyncio.TaskGroup() as tg:
tg.create_task(server.start(), name="ipc_server")
tg.create_task(supervisor.run_until_shutdown(), name="shutdown_waiter")
await server.stop()
await supervisor.stop()
# ── Foreground entry point (pyra daemon run) ──────────────────────────────────
def run_foreground() -> None:
"""Run the daemon in the foreground. Called by `pyra daemon run`."""
from pyra.config.manager import load_config
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
from pyra.plugins.registry import PluginRegistry
global _start_time
cfg = load_config()
_setup_logging(cfg.daemon.log_file)
pid_path = resolve_pid_path(cfg.daemon.pid_file)
pid_file = PidFile(pid_path)
existing = pid_file.read()
if existing is not None and not pid_file.is_stale():
_log.error("Daemon already running (PID %d). Exiting.", existing)
sys.exit(1)
registry = PluginRegistry()
from pyra.utils.paths import pyra_home as _pyra_home
plugins_dir = _pyra_home() / "plugins"
if plugins_dir.exists():
registry.load_all(plugins_dir, cfg.plugins.enabled)
supervisor = PluginSupervisor()
for name, factory in registry.get_daemon_task_factories():
supervisor.add_task(name, factory)
_start_time = time.monotonic()
try:
with pid_file:
_log.info("Pyra daemon starting (PID %d).", os.getpid())
try:
asyncio.run(_run_daemon(cfg, supervisor))
except KeyboardInterrupt:
pass
_log.info("Pyra daemon stopped.")
except PidFileError as exc:
_log.error("Could not acquire PID file: %s", exc)
sys.exit(1)
# ── Background spawn (pyra daemon start) ─────────────────────────────────────
def start_background() -> None:
"""Spawn `pyra daemon run` as a detached background process."""
from pyra.config.manager import load_config
from pyra.daemon.pid import PidFile, resolve_pid_path
from pyra.daemon.service import find_pyra_executable
cfg = load_config()
pid_path = resolve_pid_path(cfg.daemon.pid_file)
pid_file = PidFile(pid_path)
existing = pid_file.read()
if existing is not None and not pid_file.is_stale():
from pyra.chat.renderer import console
console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]")
return
exe = find_pyra_executable()
log_path = Path(cfg.daemon.log_file).expanduser()
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a") as log_fh:
if sys.platform == "win32":
DETACHED_PROCESS = 0x00000008
CREATE_NEW_PROCESS_GROUP = 0x00000200
subprocess.Popen(
[exe, "daemon", "run"],
creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP,
stdout=log_fh,
stderr=log_fh,
close_fds=True,
)
else:
subprocess.Popen(
[exe, "daemon", "run"],
start_new_session=True,
stdout=log_fh,
stderr=log_fh,
stdin=subprocess.DEVNULL,
close_fds=True,
)
from pyra.chat.renderer import console
# Poll the PID file for up to 3 seconds to confirm startup.
for _ in range(30):
time.sleep(0.1)
pid = pid_file.read()
if pid is not None:
console.print(f"[green]Daemon started (PID {pid}).[/green]")
return
console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]")
# ── Signal handling ───────────────────────────────────────────────────────────
def _install_signal_handlers(supervisor: PluginSupervisor) -> None:
if sys.platform == "win32":
signal.signal(signal.SIGTERM, lambda *_: supervisor.request_shutdown())
return
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGTERM, supervisor.request_shutdown)
loop.add_signal_handler(signal.SIGHUP, supervisor.request_shutdown)
# ── Logging setup ─────────────────────────────────────────────────────────────
def _setup_logging(log_file_str: str) -> None:
log_path = Path(log_file_str).expanduser()
log_path.parent.mkdir(parents=True, exist_ok=True)
handler = logging.handlers.RotatingFileHandler(
log_path, maxBytes=5 * 1024 * 1024, backupCount=3
)
handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
)
root = logging.getLogger("pyra")
root.addHandler(handler)
root.setLevel(logging.INFO)
safe_chmod(log_path, 0o600)
+241
View File
@@ -0,0 +1,241 @@
"""IPC transport for the Pyra daemon.
Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked).
Windows: TCP loopback on an OS-assigned port; actual port written to
~/.pyra/daemon.port so clients can connect without knowing it ahead
of time.
"""
from __future__ import annotations
import asyncio
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any, Awaitable, Callable
# ── Protocol types ────────────────────────────────────────────────────────────
IpcMessage = dict[str, Any] # must have "cmd" key
IpcResponse = dict[str, Any] # must have "ok" and "data" keys
# ── Encode / decode ───────────────────────────────────────────────────────────
def encode_message(msg: IpcMessage) -> bytes:
return (json.dumps(msg) + "\n").encode()
def decode_message(line: bytes) -> IpcMessage:
try:
return json.loads(line.rstrip(b"\n"))
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid IPC message: {exc}") from exc
# ── Address helpers ───────────────────────────────────────────────────────────
def is_unix_socket() -> bool:
return sys.platform != "win32"
def get_socket_path(cfg_socket_path: str) -> Path:
"""Expand ~ and return the Unix socket path."""
return Path(cfg_socket_path).expanduser()
def get_port_file_path() -> Path:
from pyra.utils.paths import pyra_home
return pyra_home() / "daemon.port"
def _read_windows_port() -> int | None:
port_file = get_port_file_path()
try:
return int(port_file.read_text().strip())
except (FileNotFoundError, ValueError):
return None
# ── Server ────────────────────────────────────────────────────────────────────
class IpcServer:
def __init__(
self,
address: Path | tuple[str, int],
handler: Callable[[IpcMessage], Awaitable[IpcResponse]],
) -> None:
self._address = address
self._handler = handler
self._server: asyncio.Server | None = None
async def start(self) -> None:
if is_unix_socket():
assert isinstance(self._address, Path)
sock_path = self._address
if sock_path.exists():
sock_path.unlink()
self._server = await asyncio.start_unix_server(
self._handle_client, path=str(sock_path)
)
os.chmod(sock_path, 0o600)
else:
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
self._server = await asyncio.start_server(
self._handle_client, host=host, port=port
)
actual_port = self._server.sockets[0].getsockname()[1]
port_file = get_port_file_path()
port_file.write_text(str(actual_port))
await self._server.start_serving()
async def stop(self) -> None:
if self._server is not None:
self._server.close()
try:
await asyncio.wait_for(self._server.wait_closed(), timeout=5.0)
except asyncio.TimeoutError:
pass
if is_unix_socket() and isinstance(self._address, Path):
try:
self._address.unlink()
except FileNotFoundError:
pass
else:
port_file = get_port_file_path()
try:
port_file.unlink()
except FileNotFoundError:
pass
async def _handle_client(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
try:
if is_unix_socket() and not self._check_peer_uid(writer):
writer.close()
return
line = await asyncio.wait_for(reader.readline(), timeout=5.0)
if not line:
return
try:
msg = decode_message(line)
except ValueError:
resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}}
else:
resp = await self._handler(msg)
writer.write(encode_message(resp))
await writer.drain()
except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError):
pass
finally:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool:
"""Return True if the peer's UID matches ours. Falls back to True on error."""
try:
peer_uid = _get_peer_uid(writer)
if peer_uid is None:
return True # can't determine — allow (socket perms are the guard)
return peer_uid == os.getuid()
except Exception:
return True # don't crash the server on unexpected errors
# ── Client ────────────────────────────────────────────────────────────────────
class IpcClient:
def __init__(self, address: Path | tuple[str, int]) -> None:
self._address = address
async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse:
if is_unix_socket():
assert isinstance(self._address, Path)
reader, writer = await asyncio.wait_for(
asyncio.open_unix_connection(str(self._address)), timeout=timeout
)
else:
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(host, port), timeout=timeout
)
try:
writer.write(encode_message(msg))
await writer.drain()
line = await asyncio.wait_for(reader.readline(), timeout=timeout)
return decode_message(line)
finally:
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
def send_command(
address: Path | tuple[str, int],
msg: IpcMessage,
timeout: float = 5.0,
) -> IpcResponse:
"""Synchronous wrapper around IpcClient.send() for CLI callers."""
return asyncio.run(IpcClient(address).send(msg, timeout=timeout))
# ── Peer UID detection ────────────────────────────────────────────────────────
def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None:
"""Return the connecting peer's UID, or None if unavailable."""
try:
sock = writer.get_extra_info("socket")
if sock is None:
return None
if sys.platform == "linux":
# SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; }
SO_PEERCRED = 17
cred = sock.getsockopt(
socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i")
)
_pid, uid, _gid = struct.unpack("3i", cred)
return uid
if sys.platform == "darwin":
return _macos_peer_uid(sock.fileno())
except Exception:
pass
return None
def socket_module(): # lazy import to avoid top-level import on non-Unix
import socket
return socket
def _macos_peer_uid(fd: int) -> int | None:
"""Use getpeereid(2) via ctypes to retrieve the peer UID on macOS."""
import ctypes
import ctypes.util
libc_name = ctypes.util.find_library("c")
if not libc_name:
return None
libc = ctypes.CDLL(libc_name)
euid = ctypes.c_uint32(0)
egid = ctypes.c_uint32(0)
if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0:
return None
return euid.value
+94
View File
@@ -0,0 +1,94 @@
"""PID file management for the Pyra daemon."""
from __future__ import annotations
import os
import sys
from pathlib import Path
class PidFileError(OSError):
"""Raised when a PID file operation fails due to a live conflicting process."""
class PidFile:
def __init__(self, path: Path) -> None:
self._path = path
def write(self) -> None:
"""Write the current PID atomically.
Raises PidFileError if a non-stale PID file already exists.
"""
existing = self.read()
if existing is not None and not self.is_stale():
raise PidFileError(
f"Daemon already running with PID {existing} "
f"(PID file: {self._path})"
)
tmp = self._path.with_suffix(".pid.tmp")
tmp.write_text(str(os.getpid()))
tmp.replace(self._path)
def read(self) -> int | None:
"""Return the PID from the file, or None if the file is absent or unreadable."""
try:
return int(self._path.read_text().strip())
except (FileNotFoundError, ValueError):
return None
def is_stale(self) -> bool:
"""True when the PID file exists but the process no longer runs."""
pid = self.read()
if pid is None:
return False
return not _process_is_alive(pid)
def remove(self) -> None:
"""Delete the PID file, ignoring FileNotFoundError."""
try:
self._path.unlink()
except FileNotFoundError:
pass
def __enter__(self) -> "PidFile":
self.write()
return self
def __exit__(self, *_: object) -> None:
self.remove()
def resolve_pid_path(cfg_path: str) -> Path:
"""Expand ~ and return an absolute Path."""
return Path(cfg_path).expanduser().resolve()
# ── Platform-specific process liveness check ─────────────────────────────────
def _process_is_alive(pid: int) -> bool:
if sys.platform == "win32":
return _win_process_is_alive(pid)
return _posix_process_is_alive(pid)
def _posix_process_is_alive(pid: int) -> bool:
try:
os.kill(pid, 0)
return True
except ProcessLookupError:
return False
except PermissionError:
# Process exists but is owned by another user — still alive.
return True
def _win_process_is_alive(pid: int) -> bool:
import ctypes
SYNCHRONIZE = 0x00100000
handle = ctypes.windll.kernel32.OpenProcess(SYNCHRONIZE, False, pid) # type: ignore[attr-defined]
if handle == 0:
return False
ctypes.windll.kernel32.CloseHandle(handle) # type: ignore[attr-defined]
return True
+212
View File
@@ -0,0 +1,212 @@
"""OS-specific service file generation and install/uninstall for the Pyra daemon."""
from __future__ import annotations
import platform
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Literal
from pyra.utils.paths import safe_chmod
def detect_platform() -> Literal["macos", "linux", "windows"]:
s = platform.system()
if s == "Darwin":
return "macos"
if s == "Linux":
return "linux"
if s == "Windows":
return "windows"
raise RuntimeError(f"Unsupported platform: {s}")
def find_pyra_executable() -> str:
"""Return the full path to the active pyra executable.
Tries, in order:
1. shutil.which("pyra") — works when pyra is on PATH (activated venv)
2. sys.executable's sibling "pyra" script — covers editable installs
3. Fallback: sys.executable -m pyra
"""
found = shutil.which("pyra")
if found:
return found
sibling = Path(sys.executable).parent / "pyra"
if sibling.exists():
return str(sibling)
return f"{sys.executable} -m pyra"
# ── Template generators ───────────────────────────────────────────────────────
def render_launchd_plist(exe: str, log_file: str, pid_file: str) -> str:
log = str(Path(log_file).expanduser())
return f"""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.pyra.daemon</string>
<key>ProgramArguments</key>
<array>
<string>{exe}</string>
<string>daemon</string>
<string>run</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<true/>
<key>StandardOutPath</key>
<string>{log}</string>
<key>StandardErrorPath</key>
<string>{log}</string>
<key>ProcessType</key>
<string>Background</string>
</dict>
</plist>
"""
def render_systemd_unit(exe: str, log_file: str) -> str:
log = str(Path(log_file).expanduser())
return f"""[Unit]
Description=Pyra Personal AI Assistant Daemon
After=default.target
[Service]
Type=simple
ExecStart={exe} daemon run
Restart=on-failure
RestartSec=5s
StandardOutput=append:{log}
StandardError=append:{log}
[Install]
WantedBy=default.target
"""
def render_schtasks_xml(exe: str) -> str:
return f"""<?xml version="1.0" encoding="UTF-16"?>
<Task version="1.2" xmlns="http://schemas.microsoft.com/windows/2004/02/mit/task">
<RegistrationInfo>
<Description>Pyra Personal AI Assistant Daemon</Description>
</RegistrationInfo>
<Triggers>
<LogonTrigger>
<Enabled>true</Enabled>
</LogonTrigger>
</Triggers>
<Settings>
<MultipleInstancesPolicy>IgnoreNew</MultipleInstancesPolicy>
<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>
<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>
<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>
<RestartOnFailure>
<Interval>PT1M</Interval>
<Count>999</Count>
</RestartOnFailure>
</Settings>
<Actions Context="Author">
<Exec>
<Command>{exe}</Command>
<Arguments>daemon run</Arguments>
</Exec>
</Actions>
</Task>
"""
# ── Install / uninstall ───────────────────────────────────────────────────────
def install_service() -> None:
"""Generate and register the OS service for the current platform."""
from pyra.config.manager import load_config
cfg = load_config()
exe = find_pyra_executable()
plat = detect_platform()
if plat == "macos":
_install_launchd(exe, cfg.daemon.log_file, cfg.daemon.pid_file)
elif plat == "linux":
_install_systemd(exe, cfg.daemon.log_file)
else:
_install_windows(exe)
def uninstall_service() -> None:
"""Deregister the OS service for the current platform."""
plat = detect_platform()
if plat == "macos":
_uninstall_launchd()
elif plat == "linux":
_uninstall_systemd()
else:
_uninstall_windows()
# ── macOS launchd ─────────────────────────────────────────────────────────────
_PLIST_PATH = Path.home() / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
def _install_launchd(exe: str, log_file: str, pid_file: str) -> None:
_PLIST_PATH.parent.mkdir(parents=True, exist_ok=True)
_PLIST_PATH.write_text(render_launchd_plist(exe, log_file, pid_file))
safe_chmod(_PLIST_PATH, 0o644) # launchd requires 644, not 600
subprocess.run(["launchctl", "load", str(_PLIST_PATH)], check=True)
def _uninstall_launchd() -> None:
if _PLIST_PATH.exists():
subprocess.run(["launchctl", "unload", str(_PLIST_PATH)], check=False)
_PLIST_PATH.unlink()
# ── Linux systemd ─────────────────────────────────────────────────────────────
_SYSTEMD_UNIT = Path.home() / ".config" / "systemd" / "user" / "pyra.service"
def _install_systemd(exe: str, log_file: str) -> None:
_SYSTEMD_UNIT.parent.mkdir(parents=True, exist_ok=True)
_SYSTEMD_UNIT.write_text(render_systemd_unit(exe, log_file))
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
subprocess.run(["systemctl", "--user", "enable", "pyra"], check=True)
def _uninstall_systemd() -> None:
subprocess.run(
["systemctl", "--user", "disable", "--now", "pyra"], check=False
)
if _SYSTEMD_UNIT.exists():
_SYSTEMD_UNIT.unlink()
subprocess.run(["systemctl", "--user", "daemon-reload"], check=False)
# ── Windows Task Scheduler ────────────────────────────────────────────────────
def _install_windows(exe: str) -> None:
from pyra.utils.paths import pyra_home
xml_path = pyra_home() / "daemon_task.xml"
# schtasks /Create /XML requires UTF-16 encoding
xml_path.write_text(render_schtasks_xml(exe), encoding="utf-16")
subprocess.run(
["schtasks", "/Create", "/TN", "PyraAssistant", "/XML", str(xml_path), "/F"],
check=True,
)
def _uninstall_windows() -> None:
subprocess.run(
["schtasks", "/Delete", "/TN", "PyraAssistant", "/F"], check=False
)
+26
View File
@@ -75,6 +75,32 @@ class PluginRegistry:
pass pass
return tasks return tasks
def get_daemon_task_factories(
self,
) -> list[tuple[str, Callable[[], Coroutine]]]: # type: ignore[type-arg]
"""Return (name, factory) pairs for all plugin daemon tasks.
Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine,
enabling the supervisor to restart crashed tasks without changing the plugin
protocol.
"""
factories: list[tuple[str, Callable[[], Coroutine]]] = [] # type: ignore[type-arg]
for plugin in self._plugins.values():
try:
initial = plugin.daemon_tasks()
n_tasks = len(initial)
for c in initial:
c.close() # prevent "coroutine never awaited" RuntimeWarning
except Exception:
continue
for i in range(n_tasks):
name = f"{plugin.name}.task_{i}"
# Capture plugin and index by value so each closure is independent.
def _factory(p=plugin, idx=i) -> Coroutine: # type: ignore[type-arg]
return p.daemon_tasks()[idx]
factories.append((name, _factory))
return factories
def find_tool(self, name: str) -> Tool | None: def find_tool(self, name: str) -> Tool | None:
return self._tools.get(name) return self._tools.get(name)
+426 -54
View File
@@ -1,12 +1,16 @@
import contextlib
import json
import httpx import httpx
import questionary import questionary
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.text import Text from rich.text import Text
from pyra.config.manager import save_config from pyra.config.manager import load_config, save_config
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
from pyra.setup.providers import PROVIDERS, Provider, get_provider from pyra.setup.providers import PROVIDERS, Provider, get_provider
from pyra.utils.paths import pyra_home, safe_chmod
console = Console() console = Console()
@@ -20,6 +24,72 @@ _USE_CASE_PLUGINS: dict[str, list[str]] = {
} }
_DRAFT_FILE = "setup.draft.json"
def _draft_path():
return pyra_home() / _DRAFT_FILE
def _save_draft(state: dict) -> None:
path = _draft_path()
path.write_text(json.dumps(state, indent=2))
safe_chmod(path, 0o600)
def _load_draft() -> dict | None:
path = _draft_path()
if not path.exists():
return None
try:
return json.loads(path.read_text())
except Exception:
return None
def _delete_draft() -> None:
with contextlib.suppress(FileNotFoundError):
_draft_path().unlink()
def _mark_step_done(state: dict, step: str) -> None:
state.setdefault("completed_steps", [])
if step not in state["completed_steps"]:
state["completed_steps"].append(step)
def _offer_resume(draft: dict) -> bool:
"""Show a summary of the incomplete setup and ask Resume / Start fresh."""
completed = draft.get("completed_steps", [])
step_display = {
"profile": f"Profile: {draft.get('user_name', '?')}",
"provider": f"Provider: {draft.get('provider_id', '?')}",
"model": f"Model: {draft.get('model', '?')}",
"api_key": "API key: stored in vault",
"connection": "Connection: test passed",
}
lines = ["[bold]An incomplete setup was found.[/bold]\n"]
for step, label in step_display.items():
if step in completed:
lines.append(f" [green]✓[/green] {label}")
else:
lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]")
console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow"))
console.print()
action = questionary.select(
"What would you like to do?",
choices=[
questionary.Choice("Resume from where you left off", value="resume"),
questionary.Choice("Start fresh", value="fresh"),
],
).ask()
if action is None:
raise SystemExit(0)
return action == "resume"
def run_setup() -> None: def run_setup() -> None:
console.print(Panel( console.print(Panel(
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"), Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
@@ -28,36 +98,93 @@ def run_setup() -> None:
)) ))
console.print() console.print()
user_name, purpose, use_cases = _collect_user_profile() state: dict = {}
draft = _load_draft()
if draft:
if _offer_resume(draft):
state = draft
else:
_delete_draft()
provider = _choose_provider() try:
model = _choose_model(provider) # ── Step 1: profile ────────────────────────────────────────────────
if "profile" in state.get("completed_steps", []):
user_name = state["user_name"]
purpose = state["purpose"]
use_cases = state["use_cases"]
console.print(f" [dim]✓ Profile: {user_name}[/dim]")
else:
user_name, purpose, use_cases = _collect_user_profile()
state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
_mark_step_done(state, "profile")
_save_draft(state)
if provider.requires_key: # ── Step 2: provider ───────────────────────────────────────────────
_collect_api_key(provider) if "provider" in state.get("completed_steps", []):
provider = get_provider(state["provider_id"])
console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
else:
provider = _choose_provider()
state.update(provider_id=provider.id)
_mark_step_done(state, "provider")
_save_draft(state)
_test_connection(provider, model) # ── Step 3: model ──────────────────────────────────────────────────
if "model" in state.get("completed_steps", []):
model = state["model"]
console.print(f" [dim]✓ Model: {model}[/dim]")
else:
model = _choose_model(provider)
state.update(model=model)
_mark_step_done(state, "model")
_save_draft(state)
cfg = PyraConfig( # ── Step 4: API key ────────────────────────────────────────────────
ai=ProviderConfig( if "api_key" not in state.get("completed_steps", []) and provider.requires_key:
provider_id=provider.id, from pyra.vault.reader import get_key as _get_key
model=model, if not _get_key(provider.id):
base_url=provider.base_url, _collect_api_key(provider)
), _mark_step_done(state, "api_key")
general=GeneralConfig(user_name=user_name, purpose=purpose), _save_draft(state)
)
save_config(cfg)
_suggest_plugins(use_cases) # ── Step 5: connection test ────────────────────────────────────────
if "connection" not in state.get("completed_steps", []):
model = _test_connection(provider, model)
state["model"] = model
_mark_step_done(state, "connection")
_save_draft(state)
console.print() # ── Finalise ───────────────────────────────────────────────────────
console.print(Panel( cfg = PyraConfig(
f"[green]Setup complete![/green]\n\n" ai=ProviderConfig(
f"Provider: [bold]{provider.display_name}[/bold]\n" provider_id=provider.id,
f"Model: [bold]{model}[/bold]\n\n" model=model,
"Run [bold cyan]pyra chat[/bold cyan] to start talking.", base_url=provider.base_url,
border_style="green", ),
)) general=GeneralConfig(user_name=user_name, purpose=purpose),
)
save_config(cfg)
_delete_draft()
_suggest_plugins(use_cases)
_offer_telegram_setup_if_selected(use_cases)
console.print()
console.print(Panel(
f"[green]Setup complete![/green]\n\n"
f"Provider: [bold]{provider.display_name}[/bold]\n"
f"Model: [bold]{model}[/bold]\n\n"
"Run [bold cyan]pyra chat[/bold cyan] to start talking.",
border_style="green",
))
except SystemExit:
if state.get("completed_steps"):
console.print()
console.print(
" [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
)
raise
def _collect_user_profile() -> tuple[str, str, list[str]]: def _collect_user_profile() -> tuple[str, str, list[str]]:
@@ -112,6 +239,68 @@ def _suggest_plugins(use_cases: list[str]) -> None:
)) ))
def _offer_telegram_setup_if_selected(use_cases: list[str]) -> None:
"""If the user chose 'Communication bots', offer to install and configure telegram_bot."""
relevant = any(
"telegram_bot" in _USE_CASE_PLUGINS.get(uc, [])
for uc in use_cases
)
if not relevant:
return
console.print()
try:
answer = questionary.confirm(
"Set up the Telegram bot for remote access to Pyra?", default=True
).ask()
except (KeyboardInterrupt, EOFError):
return
if not answer:
console.print(
" [dim]You can do this later: pyra plugin install telegram_bot[/dim]"
)
return
from pyra.plugins.install import get_bundled_plugins_dir, install_bundled_plugin
from pyra.plugins.loader import load_plugin_by_name
from pyra.utils.paths import pyra_home
from pyra.vault.writer import set_key
bundled_dir = get_bundled_plugins_dir()
plugins_dir = pyra_home() / "plugins"
try:
install_bundled_plugin("telegram_bot", bundled_dir, plugins_dir)
except Exception as exc:
console.print(f"[red]Could not install telegram_bot:[/red] {exc}")
return
p = load_plugin_by_name("telegram_bot", plugins_dir)
if p is None:
console.print("[red]Could not load telegram_bot for setup.[/red]")
return
try:
p.setup(console, set_key)
except (KeyboardInterrupt, EOFError):
console.print("[dim]Telegram setup skipped.[/dim]")
return
except Exception as exc:
console.print(f"[red]Telegram setup error:[/red] {exc}")
return
try:
cfg = load_config()
if "telegram_bot" not in cfg.plugins.enabled:
cfg.plugins.enabled.append("telegram_bot")
save_config(cfg)
console.print("[green]Telegram bot enabled.[/green]")
except Exception:
console.print(
" [dim]Enable manually: pyra plugin enable telegram_bot[/dim]"
)
def _choose_provider() -> Provider: def _choose_provider() -> Provider:
local = [p for p in PROVIDERS if p.group == "Local"] local = [p for p in PROVIDERS if p.group == "Local"]
cloud = [p for p in PROVIDERS if p.group == "Cloud"] cloud = [p for p in PROVIDERS if p.group == "Cloud"]
@@ -135,26 +324,160 @@ def _choose_provider() -> Provider:
if provider.connectivity_check: if provider.connectivity_check:
_check_local_server(provider) _check_local_server(provider)
_show_local_model_status(provider)
return provider return provider
def _classify_error(exc: Exception) -> tuple[str, str]:
"""Return (short_label, resolution_hint) for a provider or network error."""
name = type(exc).__name__
module = type(exc).__module__ or ""
is_llm = "litellm" in module or "openai" in module
if is_llm:
if "AuthenticationError" in name:
return (
"Invalid API key",
"The provider rejected your API key.\n"
"Double-check it on the provider dashboard and re-enter it.",
)
if "NotFoundError" in name:
return (
"Model not found",
"The model name doesn't exist for this provider.\n"
"Check the exact model identifier on the provider's model list.",
)
if "RateLimitError" in name:
return (
"Rate limit reached",
"You've exceeded the provider's request rate.\n"
"Wait a few seconds and retry.",
)
if "ServiceUnavailable" in name:
return (
"Service temporarily unavailable",
"The provider's servers returned a 5xx error.\n"
"This is usually transient — wait a minute and retry.",
)
if "APIConnectionError" in name or "ConnectError" in name:
return (
"Cannot reach provider",
"A network error prevented the connection.\n"
"Check your internet connection and firewall settings.",
)
if "Timeout" in name:
return (
"Request timed out",
"The provider did not respond in time.\n"
"The service may be overloaded — retry in a moment.",
)
if "BadRequestError" in name or "InvalidRequest" in name:
return (
"Bad request",
"The request was rejected — the model name or parameters may be wrong.\n"
"Verify the exact model identifier.",
)
return ("Provider error", str(exc)[:300])
if "httpx" in module:
if "ConnectError" in name or "ConnectTimeout" in name:
return (
"Server not reachable",
"Could not connect to the local server.\n"
"Make sure it is running and listening on the expected address.",
)
if "Timeout" in name:
return (
"Connection timed out",
"The local server did not respond in time.\n"
"It may still be starting up — wait a moment and retry.",
)
if "HTTPStatusError" in name:
code = getattr(getattr(exc, "response", None), "status_code", 0)
if code == 401:
return ("Unauthorized (401)", "The server requires credentials that were not provided.")
if code == 404:
return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.")
return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).")
return ("Unexpected error", str(exc)[:300])
def _check_local_server(provider: Provider) -> None: def _check_local_server(provider: Provider) -> None:
console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ") while True:
try:
resp = httpx.get(provider.connectivity_check, timeout=3.0)
resp.raise_for_status()
console.print("[green]✓[/green]")
except Exception:
console.print("[yellow]✗ (server not reachable)[/yellow]")
console.print( console.print(
f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n" f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" "
f" Make sure {provider.display_name} is running before using Pyra."
) )
try:
resp = httpx.get(provider.connectivity_check, timeout=3.0)
resp.raise_for_status()
console.print("[green]✓[/green]")
return
except Exception as exc:
label, hint = _classify_error(exc)
console.print("[yellow]✗[/yellow]")
console.print()
console.print(Panel(
f"[bold yellow]{label}[/bold yellow]\n\n{hint}",
title="Connection problem",
border_style="yellow",
))
action = questionary.select(
"How would you like to proceed?",
choices=[
questionary.Choice("Retry", value="retry"),
questionary.Choice(
"Continue anyway (model list may be unavailable)", value="continue"
),
questionary.Choice("Abort setup", value="abort"),
],
).ask()
if action is None or action == "abort":
raise SystemExit(0)
if action == "continue":
return
# "retry" → loop
def fetch_loaded_models(provider: Provider) -> list[str]:
"""Return models currently loaded in RAM from a local provider's API."""
if not provider.base_url:
return []
try:
if provider.id == "ollama":
resp = httpx.get(f"{provider.base_url}/api/ps", timeout=3.0)
resp.raise_for_status()
return [m["name"] for m in resp.json().get("models", [])]
elif provider.id == "lmstudio":
resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
resp.raise_for_status()
return [
m["id"] for m in resp.json().get("data", [])
if m.get("state") == "loaded"
]
else: # llamacpp — /models returns only the active loaded model
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
resp.raise_for_status()
return [m["id"] for m in resp.json().get("data", [])]
except Exception:
return []
def _show_local_model_status(provider: Provider) -> None:
"""Print a one-line status showing which models are currently loaded."""
models = fetch_loaded_models(provider)
if not models:
console.print(" [yellow]No model currently loaded[/yellow]")
elif len(models) == 1:
console.print(f" [green]Loaded model:[/green] {models[0]}")
else:
names = ", ".join(models)
console.print(f" [green]{len(models)} models loaded:[/green] {names}")
def _fetch_local_models(provider: Provider) -> list[str]: def _fetch_local_models(provider: Provider) -> list[str]:
"""Return currently loaded/available models from a local provider's API.""" """Return currently loaded models from a local provider's API."""
if not provider.base_url: if not provider.base_url:
return [] return []
try: try:
@@ -162,6 +485,13 @@ def _fetch_local_models(provider: Provider) -> list[str]:
resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0) resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0)
resp.raise_for_status() resp.raise_for_status()
return [m["name"] for m in resp.json().get("models", [])] return [m["name"] for m in resp.json().get("models", [])]
elif provider.id == "lmstudio":
resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
resp.raise_for_status()
return [
m["id"] for m in resp.json().get("data", [])
if m.get("state") == "loaded"
]
else: else:
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0) resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
resp.raise_for_status() resp.raise_for_status()
@@ -264,26 +594,68 @@ def _collect_api_key(provider: Provider) -> None:
console.print(" [green]✓ Key stored in vault[/green]") console.print(" [green]✓ Key stored in vault[/green]")
def _test_connection(provider: Provider, model: str) -> None: def _test_connection(provider: Provider, model: str) -> str:
from pyra.vault.reader import get_key from pyra.vault.reader import get_key
console.print("\n Running test call...", end=" ") while True:
try: console.print("\n Running connection test...", end=" ")
import litellm try:
import litellm
# Local providers don't need a real key but litellm still requires the field api_key = get_key(provider.id) if provider.requires_key else "local"
api_key = get_key(provider.id) if provider.requires_key else "local" kwargs: dict = {
kwargs: dict = { "model": f"{provider.litellm_prefix}{model}",
"model": f"{provider.litellm_prefix}{model}", "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
"messages": [{"role": "user", "content": "Reply with exactly: OK"}], "max_tokens": 10,
"max_tokens": 10, "api_key": api_key,
"api_key": api_key, }
} if provider.base_url:
if provider.base_url: kwargs["api_base"] = provider.base_url
kwargs["api_base"] = provider.base_url
litellm.completion(**kwargs) litellm.completion(**kwargs)
console.print("[green]✓ Connection OK[/green]") console.print("[green]✓ Connection OK[/green]")
except Exception as exc: return model
console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]")
console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]") except Exception as exc:
label, hint = _classify_error(exc)
console.print("[red]✗[/red]")
console.print()
console.print(Panel(
f"[bold red]{label}[/bold red]\n\n{hint}",
title="Test call failed",
border_style="red",
))
exc_name = type(exc).__name__
is_auth_error = "AuthenticationError" in exc_name
is_model_error = any(
kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest")
)
choices = [questionary.Choice("Retry", value="retry")]
if is_model_error:
choices.append(questionary.Choice("Change model", value="change_model"))
if provider.requires_key and is_auth_error:
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
choices += [
questionary.Choice("Skip test and continue setup", value="skip"),
questionary.Choice("Abort setup", value="abort"),
]
action = questionary.select(
"How would you like to proceed?",
choices=choices,
).ask()
if action is None or action == "abort":
raise SystemExit(0)
if action == "skip":
console.print(
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
)
return model
if action == "change_model":
model = _choose_model(provider)
elif action == "rekey":
_collect_api_key(provider)
# loop → retry (with possibly new model or key)
+44
View File
@@ -137,3 +137,47 @@ def test_main_skips_setup_when_config_exists(tmp_pyra_home, monkeypatch):
def test_config_slash_command_registered(): def test_config_slash_command_registered():
from pyra.chat.session import _STATIC_COMMANDS from pyra.chat.session import _STATIC_COMMANDS
assert "/config" in _STATIC_COMMANDS assert "/config" in _STATIC_COMMANDS
# ── plugin install ────────────────────────────────────────────────────────────
def test_plugin_install_decline_setup(tmp_pyra_home, monkeypatch):
"""Declining 'Configure now?' shows manual instructions and exits cleanly."""
from unittest.mock import MagicMock
import questionary
monkeypatch.setattr(
"pyra.plugins.install.install_bundled_plugin", lambda *a, **kw: None
)
monkeypatch.setattr(
questionary, "confirm",
lambda *a, **kw: MagicMock(ask=lambda: False),
)
runner = CliRunner()
result = runner.invoke(main, ["plugin", "install", "telegram_bot"])
assert result.exit_code == 0
assert "Installed" in result.output
assert "Configure later" in result.output
def test_plugin_install_error_does_not_prompt(tmp_pyra_home, monkeypatch):
"""If install fails, the configure prompt is never shown."""
from unittest.mock import MagicMock
import questionary
monkeypatch.setattr(
"pyra.plugins.install.install_bundled_plugin",
lambda *a, **kw: (_ for _ in ()).throw(FileNotFoundError("not found")),
)
confirm_calls = []
monkeypatch.setattr(
questionary, "confirm",
lambda *a, **kw: confirm_calls.append(1) or MagicMock(ask=lambda: False),
)
runner = CliRunner()
result = runner.invoke(main, ["plugin", "install", "telegram_bot"])
assert result.exit_code == 0
assert "Error" in result.output
assert len(confirm_calls) == 0 # prompt never reached
+226
View File
@@ -0,0 +1,226 @@
"""Unit tests for the daemon core — PluginSupervisor and IPC handler dispatch."""
from __future__ import annotations
import asyncio
import pytest
from pyra.daemon.core import PluginSupervisor, _make_ipc_handler
# ── Helpers ───────────────────────────────────────────────────────────────────
async def _drain(n: int = 20) -> None:
"""Yield to the event loop n times to let scheduled tasks run."""
for _ in range(n):
await asyncio.sleep(0)
# ── PluginSupervisor — lifecycle ──────────────────────────────────────────────
async def test_supervisor_empty_starts_and_stops_cleanly() -> None:
sup = PluginSupervisor()
await sup.start()
await sup.stop()
assert sup.status() == []
async def test_supervisor_runs_task_to_completion() -> None:
done = asyncio.Event()
async def task():
done.set()
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("t", task)
await sup.start()
await asyncio.wait_for(done.wait(), timeout=1.0)
await sup.stop()
assert sup._records[0].restart_count == 0
assert sup._records[0].last_error is None
async def test_supervisor_restarts_crashed_task() -> None:
call_count = 0
completed = asyncio.Event()
async def flaky():
nonlocal call_count
call_count += 1
if call_count == 1:
raise RuntimeError("first call fails")
completed.set()
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("flaky", flaky)
await sup.start()
await asyncio.wait_for(completed.wait(), timeout=1.0)
await sup.stop()
assert sup._records[0].restart_count == 1
assert "RuntimeError" in (sup._records[0].last_error or "")
async def test_supervisor_gives_up_after_max_restarts() -> None:
async def always_fails():
raise ValueError("always")
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup._MAX_RESTARTS = 3
sup.add_task("failing", always_fails)
await sup.start()
# Allow enough iterations for 3 restarts + give-up.
for _ in range(200):
await asyncio.sleep(0)
if sup._records[0].task and sup._records[0].task.done():
break
await sup.stop()
assert sup._records[0].restart_count == 3
assert sup._records[0].last_error is not None
# ── PluginSupervisor — status ─────────────────────────────────────────────────
async def test_supervisor_status_returns_correct_shape() -> None:
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
async def noop():
pass
sup.add_task("noop", noop)
await sup.start()
await _drain()
statuses = sup.status()
assert len(statuses) == 1
s = statuses[0]
assert set(s.keys()) == {"name", "alive", "restart_count", "last_error"}
assert s["name"] == "noop"
assert isinstance(s["alive"], bool)
assert isinstance(s["restart_count"], int)
await sup.stop()
async def test_supervisor_status_empty_when_no_tasks() -> None:
sup = PluginSupervisor()
await sup.start()
assert sup.status() == []
await sup.stop()
# ── PluginSupervisor — reload ─────────────────────────────────────────────────
async def test_supervisor_reload_restarts_tasks() -> None:
call_count = 0
async def counting():
nonlocal call_count
call_count += 1
# Hang until cancelled so reload can cancel it.
await asyncio.sleep(10)
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("c", counting)
await sup.start()
await _drain()
assert call_count == 1
await sup.reload()
await _drain()
# After reload, the task should have been restarted (called a second time).
assert call_count == 2
assert sup._records[0].restart_count == 0 # reset by reload
await sup.stop()
async def test_supervisor_reload_resets_restart_count() -> None:
call_count = 0
async def flaky():
nonlocal call_count
call_count += 1
if call_count <= 2:
raise RuntimeError("crash")
await asyncio.sleep(10)
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("f", flaky)
await sup.start()
# Wait for 2 crashes to accumulate.
for _ in range(200):
await asyncio.sleep(0)
if sup._records[0].restart_count >= 2:
break
assert sup._records[0].restart_count == 2
await sup.reload()
# Reload must reset the counter.
assert sup._records[0].restart_count == 0
await sup.stop()
# ── IPC command handler ───────────────────────────────────────────────────────
async def test_ipc_handler_ping() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "ping"})
assert resp["ok"] is True
assert resp["data"]["pong"] is True
async def test_ipc_handler_status_shape() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "status"})
assert resp["ok"] is True
assert "uptime" in resp["data"]
assert "pid" in resp["data"]
assert "tasks" in resp["data"]
assert isinstance(resp["data"]["tasks"], list)
async def test_ipc_handler_stop_signals_shutdown() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
assert not sup._shutdown.is_set()
resp = await handler({"cmd": "stop"})
assert resp["ok"] is True
assert sup._shutdown.is_set()
async def test_ipc_handler_reload_returns_task_count() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "reload"})
assert resp["ok"] is True
assert resp["data"]["tasks_reloaded"] == 0
async def test_ipc_handler_unknown_command() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "bogus"})
assert resp["ok"] is False
assert "error" in resp["data"]
assert "bogus" in resp["data"]["error"]
+162
View File
@@ -0,0 +1,162 @@
"""Unit tests for the IPC layer."""
from __future__ import annotations
import asyncio
import os
import sys
import tempfile
from pathlib import Path
import pytest
from pyra.daemon.ipc import (
IpcClient,
IpcMessage,
IpcResponse,
IpcServer,
decode_message,
encode_message,
is_unix_socket,
)
@pytest.fixture
def sock_path():
"""Short socket path that fits within macOS's 104-char AF_UNIX limit."""
with tempfile.TemporaryDirectory(dir="/tmp") as d:
yield Path(d) / "t.sock"
# ── Protocol encode / decode ──────────────────────────────────────────────────
def test_encode_appends_newline() -> None:
data = encode_message({"cmd": "ping"})
assert data.endswith(b"\n")
def test_encode_is_valid_json() -> None:
import json
data = encode_message({"cmd": "status", "extra": 42})
assert json.loads(data) == {"cmd": "status", "extra": 42}
def test_decode_roundtrip() -> None:
msg: IpcMessage = {"cmd": "stop"}
assert decode_message(encode_message(msg)) == msg
def test_decode_strips_newline() -> None:
assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop"
def test_decode_raises_on_bad_json() -> None:
with pytest.raises(ValueError, match="Invalid IPC message"):
decode_message(b"not json\n")
def test_decode_raises_on_empty_line() -> None:
with pytest.raises(ValueError):
decode_message(b"\n")
# ── is_unix_socket ────────────────────────────────────────────────────────────
def test_is_unix_socket_matches_platform() -> None:
if sys.platform == "win32":
assert not is_unix_socket()
else:
assert is_unix_socket()
# ── Server + client roundtrip (Unix only) ─────────────────────────────────────
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_server_client_ping(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {"pong": True}}
server = IpcServer(sock_path, handler)
await server.start()
try:
resp = await IpcClient(sock_path).send({"cmd": "ping"})
assert resp["ok"] is True
assert resp["data"]["pong"] is True
finally:
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {}}
server = IpcServer(sock_path, handler)
await server.start()
try:
reader, writer = await asyncio.open_unix_connection(str(sock_path))
writer.write(b"not valid json\n")
await writer.drain()
line = await asyncio.wait_for(reader.readline(), timeout=3.0)
resp = decode_message(line)
assert resp["ok"] is False
assert "error" in resp["data"]
finally:
try:
writer.close()
except Exception:
pass
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_handler_response_returned_to_client(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
if msg.get("cmd") == "status":
return {"ok": True, "data": {"uptime": 99.0}}
return {"ok": False, "data": {"error": "unknown"}}
server = IpcServer(sock_path, handler)
await server.start()
try:
resp = await IpcClient(sock_path).send({"cmd": "status"})
assert resp["ok"] is True
assert resp["data"]["uptime"] == 99.0
resp2 = await IpcClient(sock_path).send({"cmd": "bogus"})
assert resp2["ok"] is False
finally:
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_client_raises_when_no_server(sock_path: Path) -> None:
client = IpcClient(sock_path)
with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)):
await client.send({"cmd": "ping"})
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_socket_file_chmod_600(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {}}
server = IpcServer(sock_path, handler)
await server.start()
try:
mode = oct(sock_path.stat().st_mode & 0o777)
assert mode == oct(0o600), f"Expected 0o600, got {mode}"
finally:
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_stop_removes_socket_file(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {}}
server = IpcServer(sock_path, handler)
await server.start()
assert sock_path.exists()
await server.stop()
assert not sock_path.exists()
+103
View File
@@ -0,0 +1,103 @@
"""Unit tests for daemon PID file management."""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
def test_write_creates_file(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write()
assert (tmp_path / "daemon.pid").exists()
assert int((tmp_path / "daemon.pid").read_text().strip()) == os.getpid()
def test_read_returns_none_when_absent(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
assert p.read() is None
def test_read_returns_pid_when_present(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("12345")
p = PidFile(pid_file)
assert p.read() == 12345
def test_read_returns_none_on_bad_content(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("not-a-number")
p = PidFile(pid_file)
assert p.read() is None
def test_is_stale_false_for_self(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write()
assert not p.is_stale()
def test_is_stale_true_for_dead_pid(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("999999999") # unrealistically large PID
p = PidFile(pid_file)
assert p.is_stale()
def test_is_stale_false_when_file_absent(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
assert not p.is_stale()
def test_remove_deletes_file(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write()
p.remove()
assert not (tmp_path / "daemon.pid").exists()
def test_remove_is_idempotent(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.remove() # must not raise
def test_context_manager_writes_and_removes(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
p = PidFile(pid_file)
with p:
assert pid_file.exists()
assert int(pid_file.read_text().strip()) == os.getpid()
assert not pid_file.exists()
def test_write_raises_when_live_pid_exists(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write() # writes self PID (which is alive)
p2 = PidFile(tmp_path / "daemon.pid")
with pytest.raises(PidFileError):
p2.write()
def test_write_succeeds_over_stale_pid(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("999999999") # stale
p = PidFile(pid_file)
p.write() # should not raise
assert int(pid_file.read_text().strip()) == os.getpid()
def test_resolve_pid_path_expands_tilde() -> None:
result = resolve_pid_path("~/.pyra/daemon.pid")
assert not str(result).startswith("~")
assert result.is_absolute()
def test_resolve_pid_path_absolute_unchanged(tmp_path: Path) -> None:
path = tmp_path / "daemon.pid"
result = resolve_pid_path(str(path))
assert result == path
+189
View File
@@ -0,0 +1,189 @@
"""Unit tests for daemon service file generation and platform detection."""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from pyra.daemon.service import (
detect_platform,
find_pyra_executable,
render_launchd_plist,
render_systemd_unit,
render_schtasks_xml,
)
# ── Template rendering ────────────────────────────────────────────────────────
def test_render_launchd_plist_contains_exe() -> None:
xml = render_launchd_plist("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
assert "/usr/local/bin/pyra" in xml
assert "<string>daemon</string>" in xml
assert "<string>run</string>" in xml
assert "com.pyra.daemon" in xml
assert "<true/>" in xml # KeepAlive and RunAtLoad
def test_render_launchd_plist_expands_log_tilde() -> None:
xml = render_launchd_plist("/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
assert "~" not in xml
def test_render_systemd_unit_contains_exe() -> None:
unit = render_systemd_unit("/usr/local/bin/pyra", "~/.pyra/daemon.log")
assert "ExecStart=/usr/local/bin/pyra daemon run" in unit
assert "Restart=on-failure" in unit
assert "Type=simple" in unit
assert "WantedBy=default.target" in unit
def test_render_systemd_unit_expands_log_tilde() -> None:
unit = render_systemd_unit("/bin/pyra", "~/.pyra/daemon.log")
assert "~" not in unit
def test_render_schtasks_xml_contains_exe() -> None:
xml = render_schtasks_xml("C:\\Users\\test\\pyra.exe")
assert "C:\\Users\\test\\pyra.exe" in xml
assert "LogonTrigger" in xml
assert "daemon run" in xml
assert "IgnoreNew" in xml
def test_render_schtasks_xml_no_time_limit() -> None:
xml = render_schtasks_xml("pyra.exe")
assert "PT0S" in xml # ExecutionTimeLimit=PT0S means unlimited
# ── Platform detection ────────────────────────────────────────────────────────
def test_detect_platform_returns_known_value() -> None:
result = detect_platform()
assert result in ("macos", "linux", "windows")
@pytest.mark.parametrize("system,expected", [
("Darwin", "macos"),
("Linux", "linux"),
("Windows", "windows"),
])
def test_detect_platform_maps_correctly(system: str, expected: str) -> None:
with patch("platform.system", return_value=system):
assert detect_platform() == expected
def test_detect_platform_raises_on_unknown() -> None:
with patch("platform.system", return_value="FreeBSD"):
with pytest.raises(RuntimeError, match="Unsupported platform"):
detect_platform()
# ── Executable detection ──────────────────────────────────────────────────────
def test_find_pyra_executable_returns_string() -> None:
result = find_pyra_executable()
assert isinstance(result, str)
assert len(result) > 0
def test_find_pyra_executable_uses_which_when_available(tmp_path: Path) -> None:
fake_pyra = tmp_path / "pyra"
fake_pyra.touch()
with patch("shutil.which", return_value=str(fake_pyra)):
assert find_pyra_executable() == str(fake_pyra)
def test_find_pyra_executable_falls_back_to_sibling(tmp_path: Path) -> None:
fake_python = tmp_path / "python3"
fake_pyra = tmp_path / "pyra"
fake_pyra.touch()
with patch("shutil.which", return_value=None):
with patch("sys.executable", str(fake_python)):
assert find_pyra_executable() == str(fake_pyra)
def test_find_pyra_executable_falls_back_to_module(tmp_path: Path) -> None:
fake_python = tmp_path / "python3"
with patch("shutil.which", return_value=None):
with patch("sys.executable", str(fake_python)):
result = find_pyra_executable()
assert result == f"{fake_python} -m pyra"
# ── Install / uninstall (subprocess mocked) ───────────────────────────────────
@pytest.mark.skipif(sys.platform == "win32", reason="launchd install is macOS-only")
def test_install_launchd_writes_plist_and_calls_launchctl(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
calls: list[list[str]] = []
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
svc._install_launchd("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
assert plist_path.exists()
assert "com.pyra.daemon" in plist_path.read_text()
assert any("launchctl" in c[0] for c in calls)
@pytest.mark.skipif(sys.platform == "win32", reason="systemd install is Linux-only")
def test_install_systemd_writes_unit_and_calls_systemctl(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
calls: list[list[str]] = []
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
svc._install_systemd("/usr/local/bin/pyra", "~/.pyra/daemon.log")
assert unit_path.exists()
assert "ExecStart" in unit_path.read_text()
assert any("systemctl" in c[0] for c in calls)
@pytest.mark.skipif(sys.platform == "win32", reason="launchd uninstall is macOS-only")
def test_uninstall_launchd_removes_plist(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
plist_path.parent.mkdir(parents=True)
plist_path.write_text("<plist/>")
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
svc._uninstall_launchd()
assert not plist_path.exists()
@pytest.mark.skipif(sys.platform == "win32", reason="systemd uninstall is Linux-only")
def test_uninstall_systemd_removes_unit(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
unit_path.parent.mkdir(parents=True)
unit_path.write_text("[Service]")
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
svc._uninstall_systemd()
assert not unit_path.exists()
+349 -3
View File
@@ -96,14 +96,34 @@ def test_suggest_plugins_multiple_categories(monkeypatch):
# ── _fetch_local_models ──────────────────────────────────────────────────────── # ── _fetch_local_models ────────────────────────────────────────────────────────
def test_fetch_local_models_lmstudio_returns_model_ids(monkeypatch): def test_fetch_local_models_lmstudio_returns_loaded_model_ids(monkeypatch):
import pyra.setup.wizard as wiz import pyra.setup.wizard as wiz
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.json.return_value = {"data": [{"id": "gemma-4b"}, {"id": "llama3"}]} mock_resp.json.return_value = {
"data": [
{"id": "gemma-4b", "state": "loaded"},
{"id": "llama3", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp) monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
from pyra.setup.providers import get_provider from pyra.setup.providers import get_provider
assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b", "llama3"] assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b"]
def test_fetch_local_models_lmstudio_filters_unloaded(monkeypatch):
import pyra.setup.wizard as wiz
mock_resp = MagicMock()
mock_resp.json.return_value = {
"data": [
{"id": "model-a", "state": "not_loaded"},
{"id": "model-b", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
from pyra.setup.providers import get_provider
assert wiz._fetch_local_models(get_provider("lmstudio")) == []
def test_fetch_local_models_ollama_returns_model_names(monkeypatch): def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
@@ -181,3 +201,329 @@ def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
import pyra.setup.wizard as wiz import pyra.setup.wizard as wiz
monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout"))) monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout")))
assert wiz._load_lmstudio_model("gemma-4b") is False assert wiz._load_lmstudio_model("gemma-4b") is False
# ── _classify_error ───────────────────────────────────────────────────────────
def _fake_llm_exc(name: str) -> Exception:
"""Create a fake litellm exception with the given class name."""
cls = type(name, (Exception,), {"__module__": "litellm.exceptions"})
return cls("test error")
def test_classify_auth_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("AuthenticationError"))
assert "key" in label.lower()
assert "provider" in hint.lower() or "dashboard" in hint.lower()
def test_classify_not_found_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("NotFoundError"))
assert "model" in label.lower()
assert "model" in hint.lower()
def test_classify_rate_limit_error():
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(_fake_llm_exc("RateLimitError"))
assert "rate" in label.lower()
def test_classify_service_unavailable():
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(_fake_llm_exc("ServiceUnavailableError"))
assert "unavailable" in label.lower()
def test_classify_api_connection_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("APIConnectionError"))
assert "reach" in label.lower() or "network" in label.lower() or "connect" in label.lower()
assert "internet" in hint.lower() or "network" in hint.lower()
def test_classify_timeout_error():
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(_fake_llm_exc("TimeoutError"))
assert "time" in label.lower()
def test_classify_bad_request_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("BadRequestError"))
assert "request" in label.lower() or "bad" in label.lower()
assert "model" in hint.lower()
def test_classify_httpx_connect_error():
import httpx
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(httpx.ConnectError("refused"))
assert "reach" in label.lower() or "reachable" in label.lower()
assert "running" in hint.lower() or "listening" in hint.lower()
def test_classify_httpx_timeout():
import httpx
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(httpx.TimeoutException("timeout"))
assert "time" in label.lower()
def test_classify_generic_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(ValueError("something went wrong"))
assert len(label) > 0
assert "something went wrong" in hint
def test_classify_error_always_returns_two_strings():
from pyra.setup.wizard import _classify_error
for exc in (RuntimeError("boom"), OSError("disk"), KeyError("k")):
label, hint = _classify_error(exc)
assert isinstance(label, str) and len(label) > 0
assert isinstance(hint, str) and len(hint) > 0
# ── _check_local_server retry behaviour ──────────────────────────────────────
def _silent_print(*a, **kw):
pass
def test_check_local_server_success(monkeypatch):
import pyra.setup.wizard as wiz
mock_resp = MagicMock()
mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
monkeypatch.setattr(wiz.console, "print", _silent_print)
from pyra.setup.providers import get_provider
wiz._check_local_server(get_provider("lmstudio")) # must not raise
def test_check_local_server_retry_then_success(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
call_count = {"n": 0}
def flaky_get(*a, **kw):
call_count["n"] += 1
if call_count["n"] == 1:
raise ConnectionError("refused")
m = MagicMock()
m.raise_for_status = lambda: None
return m
monkeypatch.setattr(wiz.httpx, "get", flaky_get)
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "retry"))
wiz._check_local_server(get_provider("lmstudio"))
assert call_count["n"] == 2
def test_check_local_server_abort_raises_system_exit(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "abort"))
with pytest.raises(SystemExit):
wiz._check_local_server(get_provider("lmstudio"))
def test_check_local_server_continue_returns(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
# ── draft persistence ─────────────────────────────────────────────────────────
def test_save_and_load_draft(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {"completed_steps": ["profile"], "user_name": "Alice"}
wiz._save_draft(state)
loaded = wiz._load_draft()
assert loaded == state
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
assert wiz._load_draft() is None
def test_delete_draft_removes_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
wiz._delete_draft()
assert wiz._load_draft() is None
def test_delete_draft_is_idempotent(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._delete_draft()
wiz._delete_draft()
def test_mark_step_done_appends_once(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {}
wiz._mark_step_done(state, "profile")
wiz._mark_step_done(state, "profile")
assert state["completed_steps"].count("profile") == 1
def test_draft_file_has_correct_permissions(tmp_pyra_home):
import os
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
path = wiz._draft_path()
if os.name != "nt":
mode = oct(path.stat().st_mode)[-3:]
assert mode == "600"
# ── fetch_loaded_models ───────────────────────────────────────────────────────
def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
mock_resp = MagicMock()
mock_resp.json.return_value = {"models": [{"name": "llama3:latest"}, {"name": "mistral"}]}
mock_resp.raise_for_status = lambda: None
calls = []
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
result = wiz.fetch_loaded_models(get_provider("ollama"))
assert result == ["llama3:latest", "mistral"]
assert any("/api/ps" in u for u in calls)
def test_fetch_loaded_models_lmstudio_uses_beta_api_and_filters(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
mock_resp = MagicMock()
mock_resp.json.return_value = {
"data": [
{"id": "gemma-4b", "state": "loaded"},
{"id": "llama3", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None
calls = []
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
result = wiz.fetch_loaded_models(get_provider("lmstudio"))
assert result == ["gemma-4b"]
assert any("/api/v0/models" in u for u in calls)
def test_fetch_loaded_models_lmstudio_filters_unloaded(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
mock_resp = MagicMock()
mock_resp.json.return_value = {
"data": [
{"id": "model-a", "state": "not_loaded"},
{"id": "model-b", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: mock_resp)
assert wiz.fetch_loaded_models(get_provider("lmstudio")) == []
def test_fetch_loaded_models_returns_empty_on_error(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused")))
assert wiz.fetch_loaded_models(get_provider("ollama")) == []
def test_fetch_loaded_models_returns_empty_when_no_base_url():
import pyra.setup.wizard as wiz
from pyra.setup.providers import Provider
provider = Provider(
id="test", display_name="Test", requires_key=False,
default_model="x", litellm_prefix="openai/", group="Local",
)
assert wiz.fetch_loaded_models(provider) == []
# ── _show_local_model_status ──────────────────────────────────────────────────
def test_show_local_model_status_one_model(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("gemma-4b" in s for s in printed)
def test_show_local_model_status_none(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: [])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("No model" in s for s in printed)
def test_show_local_model_status_multiple(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["a", "b", "c"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
combined = " ".join(printed)
assert "3" in combined or ("a" in combined and "b" in combined)
# ── _test_connection model re-entry ───────────────────────────────────────────
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
call_count = {"n": 0}
class FakeNotFound(Exception):
pass
FakeNotFound.__module__ = "litellm.exceptions"
FakeNotFound.__name__ = "NotFoundError"
def fake_completion(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
raise FakeNotFound("model not found")
import litellm
monkeypatch.setattr(litellm, "completion", fake_completion)
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "change_model"))
monkeypatch.setattr(wiz.questionary, "text",
lambda *a, **kw: MagicMock(ask=lambda: "new-model"))
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")
provider = get_provider("anthropic")
result = wiz._test_connection(provider, "old-model")
assert result == "new-model"
assert call_count["n"] == 2
+281
View File
@@ -0,0 +1,281 @@
"""Unit tests for the telegram_bot bundled plugin.
Tests cover pure-logic helpers: rate limiter, SQLite history, auth flow,
and tool argument parsing. Handler integration (live Telegram) is not tested.
"""
from __future__ import annotations
import importlib.util
import json
import sys
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# ── Import helpers straight from the bundled source ──────────────────────────
_PLUGIN_PATH = (
Path(__file__).parent.parent.parent
/ "src/pyra/bundled_plugins/telegram_bot/plugin.py"
)
def _load_plugin_module():
"""Load plugin.py without requiring python-telegram-bot to be installed."""
# Provide stub modules so the top-level imports don't fail in unit tests
for mod_name in [
"bcrypt",
"telegram",
"telegram.ext",
"litellm",
]:
if mod_name not in sys.modules:
stub = MagicMock()
sys.modules[mod_name] = stub
# telegram.ext sub-attributes needed at import time
if mod_name == "telegram.ext":
stub.Application = MagicMock()
stub.CallbackQueryHandler = MagicMock()
stub.CommandHandler = MagicMock()
stub.ContextTypes = MagicMock()
stub.MessageHandler = MagicMock()
stub.filters = MagicMock()
spec = importlib.util.spec_from_file_location("pyra_plugin_telegram_bot", _PLUGIN_PATH)
assert spec and spec.loader
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[union-attr]
return mod
_mod = _load_plugin_module()
_RateLimiter = _mod._RateLimiter
_load_history = _mod._load_history
_save_history = _mod._save_history
TelegramBotPlugin = _mod.TelegramBotPlugin
# ── Rate limiter ──────────────────────────────────────────────────────────────
class TestRateLimiter:
def test_allows_up_to_limit(self):
rl = _RateLimiter(per_hour=5)
for _ in range(5):
assert rl.allow(user_id=1)
def test_blocks_over_limit(self):
rl = _RateLimiter(per_hour=3)
for _ in range(3):
rl.allow(user_id=1)
assert not rl.allow(user_id=1)
def test_independent_per_user(self):
rl = _RateLimiter(per_hour=2)
rl.allow(1)
rl.allow(1)
assert not rl.allow(1)
assert rl.allow(2)
def test_old_entries_expire(self):
rl = _RateLimiter(per_hour=1)
# Manually populate bucket with an old timestamp
from collections import deque
rl._buckets[99] = deque([time.monotonic() - 3601])
assert rl.allow(99) # old entry removed, new one fits
def test_multiple_users_independent(self):
rl = _RateLimiter(per_hour=1)
rl.allow(10)
assert not rl.allow(10)
assert rl.allow(20)
assert rl.allow(30)
# ── History DB ────────────────────────────────────────────────────────────────
class TestHistoryDB:
def test_empty_history(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
assert _load_history(42) == []
def test_save_and_load(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
msgs = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
_save_history(42, msgs)
assert _load_history(42) == msgs
def test_save_trims_to_max(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
monkeypatch.setattr(_mod, "_MAX_HISTORY", 3)
msgs = [{"role": "user", "content": str(i)} for i in range(10)]
_save_history(1, msgs)
loaded = _load_history(1)
assert len(loaded) == 3
# Should keep the most recent messages
assert loaded[-1]["content"] == "9"
def test_overwrite(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
_save_history(1, [{"role": "user", "content": "first"}])
_save_history(1, [{"role": "user", "content": "second"}])
assert _load_history(1)[0]["content"] == "second"
def test_separate_chats(self, tmp_path, monkeypatch):
db_path = tmp_path / "tg.db"
monkeypatch.setattr(_mod, "_HISTORY_DB", db_path)
_save_history(1, [{"role": "user", "content": "chat1"}])
_save_history(2, [{"role": "user", "content": "chat2"}])
assert _load_history(1)[0]["content"] == "chat1"
assert _load_history(2)[0]["content"] == "chat2"
# ── Plugin: on_load and setup ─────────────────────────────────────────────────
class TestPluginLifecycle:
def test_on_load_stores_vault_reader(self):
plugin = TelegramBotPlugin()
reader = MagicMock(return_value=None)
plugin.on_load(reader)
assert plugin._vault_reader is reader
def test_daemon_tasks_returns_coroutine(self):
import inspect
plugin = TelegramBotPlugin()
plugin.on_load(MagicMock(return_value=None))
tasks = plugin.daemon_tasks()
assert len(tasks) == 1
assert inspect.iscoroutine(tasks[0])
tasks[0].close() # prevent "coroutine never awaited" warning
def test_get_plugin_factory(self):
plugin = _mod.get_plugin()
assert isinstance(plugin, TelegramBotPlugin)
assert plugin.name == "telegram_bot"
def test_config_fields(self):
plugin = TelegramBotPlugin()
fields = plugin.config_fields()
assert any(f.key == "rate_limit" for f in fields)
def _patch_setup(self, token, allowed, pass1, pass2):
"""Return a context manager that patches all questionary calls used by setup()."""
pw_answers = iter([token, pass1, pass2])
return (
patch("questionary.password",
side_effect=lambda *a, **kw: MagicMock(ask=lambda: next(pw_answers))),
patch("questionary.text",
return_value=MagicMock(ask=lambda: allowed)),
patch("questionary.press_any_key_to_continue",
return_value=MagicMock(ask=lambda: None)),
)
def test_setup_mismatched_passphrase(self):
"""setup() writes nothing to the vault when passphrases don't match."""
plugin = TelegramBotPlugin()
console = MagicMock()
vault_writer = MagicMock()
pw_patch, text_patch, pakc_patch = self._patch_setup(
"fake-token", "123456789", "pass1", "pass2"
)
with pw_patch, text_patch, pakc_patch:
plugin.setup(console, vault_writer)
vault_writer.assert_not_called()
console.print.assert_called_with(
"[red]Passphrases do not match. Run setup again to retry.[/red]"
)
def test_setup_happy_path(self):
"""setup() writes all three vault keys when credentials are valid."""
plugin = TelegramBotPlugin()
console = MagicMock()
vault_writer = MagicMock()
pw_patch, text_patch, pakc_patch = self._patch_setup(
"real-token", "111222333", "s3cr3t", "s3cr3t"
)
with pw_patch, text_patch, pakc_patch:
plugin.setup(console, vault_writer)
calls = {call[0][0]: call[0][1] for call in vault_writer.call_args_list}
assert calls.get("plugin:telegram_bot:token") == "real-token"
assert calls.get("plugin:telegram_bot:allowed_users") == "111222333"
assert "plugin:telegram_bot:passphrase_hash" in calls
def test_setup_cancelled_on_empty_token(self):
"""setup() exits without writing if the token prompt is cancelled."""
plugin = TelegramBotPlugin()
console = MagicMock()
vault_writer = MagicMock()
with patch("questionary.password", return_value=MagicMock(ask=lambda: None)), \
patch("questionary.press_any_key_to_continue",
return_value=MagicMock(ask=lambda: None)):
plugin.setup(console, vault_writer)
vault_writer.assert_not_called()
# ── Auth session state ────────────────────────────────────────────────────────
class TestAuthState:
def test_session_defaults(self):
plugin = TelegramBotPlugin()
session = plugin._sessions.setdefault(
99, {"authenticated": False, "awaiting_passphrase": False, "attempts": 0}
)
assert not session["authenticated"]
assert not session["awaiting_passphrase"]
assert session["attempts"] == 0
def test_session_per_chat(self):
plugin = TelegramBotPlugin()
plugin._sessions[1] = {"authenticated": True, "awaiting_passphrase": False, "attempts": 0}
plugin._sessions[2] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 1}
assert plugin._sessions[1]["authenticated"]
assert not plugin._sessions[2]["authenticated"]
assert plugin._sessions[2]["attempts"] == 1
def test_session_auth_flag(self):
plugin = TelegramBotPlugin()
plugin._sessions[5] = {"authenticated": False, "awaiting_passphrase": True, "attempts": 0}
# Simulate successful auth
plugin._sessions[5]["authenticated"] = True
plugin._sessions[5]["awaiting_passphrase"] = False
assert plugin._sessions[5]["authenticated"]
# ── Tool argument parsing ─────────────────────────────────────────────────────
class TestToolArgParsing:
def test_valid_json_string(self):
args_raw = '{"query": "hello"}'
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
assert args == {"query": "hello"}
def test_dict_passthrough(self):
args_raw = {"query": "hello"}
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
assert args == {"query": "hello"}
def test_invalid_json_caught(self):
args_raw = "not json"
try:
json.loads(args_raw)
assert False, "Should have raised"
except json.JSONDecodeError:
pass
def test_result_truncated_to_4000(self):
long_result = "x" * 5000
assert len(long_result[:4000]) == 4000