23 Commits

Author SHA1 Message Date
curo1305 fac6e7e77e chore(email): sharpen tool descriptions and rename email_list_inbox
- Rename email_list_inbox → email_list_folder (works on any folder, not just inbox)
- email_list_folder / email_search: distinguish browse-by-recency vs filter-by-query
- email_move / email_delete: clarify single-email scope; point to email_bulk_action for multiple
- email_bulk_action: clarify it handles multiple emails; point to move/delete for single
- email_create_folder: remove redundant "ask user" instruction (requires_approval handles it)
- Update tests to reflect renamed tool

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:50:24 +02:00
curo1305 d3fc4e2d42 feat(email): implement full email plugin
Supports Gmail (OAuth + filter API), Microsoft 365 (OAuth + full rule CRUD via
O365), ProtonMail (Bridge, paid only), and generic IMAP/SMTP providers.

16 tools: list_inbox, read, send, reply, forward, move, delete, mark_read,
search, list_folders, create_folder, inbox_summary, list_rules, create_rule,
delete_rule, bulk_action.

Background daemon task monitors inbox via IMAP IDLE and publishes new-email
events to daemon/events.py for messaging bot pickup. Setup wizard warns
explicitly about ProtonMail Bridge + paid plan requirement and recommends
a messaging bot if none is configured.

36 unit tests covering: HTML stripping, header decoding, raw message parsing,
IMAP search builder, Gmail/Outlook rule normalisation, folder-not-found path,
events bus, Bridge connectivity guard, and tool approval flags.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:26:39 +02:00
curo1305 8d4917f7ca feat(daemon): add async event bus for inter-plugin notifications
Adds daemon/events.py — a lightweight asyncio.Queue-based publish/subscribe
bus that lets daemon tasks communicate without direct imports between plugins.
Email plugin publishes new_email events; messaging bots consume via
subscribe_forever(). Also adds email optional-dependency group to pyproject.toml
(imap-tools, google-api-python-client, google-auth-oauthlib, O365).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:12:17 +02:00
curo1305 bde0856979 feat(daemon): Stage 6 daemon infrastructure
Always-on asyncio daemon with IPC socket, OS service install/uninstall
(launchd/systemd/schtasks), and plugin task supervisor.

- daemon/pid.py: atomic PID file, stale detection (POSIX + Windows)
- daemon/ipc.py: Unix socket (chmod 600, UID-checked) on Linux/macOS;
  TCP loopback + port file on Windows; newline-delimited JSON protocol
- daemon/service.py: launchd plist, systemd user unit, schtasks XML;
  auto-detects platform; finds pyra executable via shutil.which
- daemon/core.py: asyncio event loop, PluginSupervisor (per-task
  restart up to 10x with 5s back-off, reload), IPC command dispatch,
  SIGTERM/SIGHUP signal handling via get_running_loop()
- cli.py: all 7 daemon stubs replaced with real commands
- 376 tests passing (13 new supervisor + IPC handler tests)
2026-05-19 16:14:51 +02:00
curo1305 4744cf819b docs: update CLAUDE.md for Stage 6 daemon infrastructure
- Current Status: add Stage 6 daemon infrastructure in progress
- Architecture table: expand daemon/__init__.py stub to all 5 daemon modules
- Code Inventory: add daemon.core, daemon.pid, daemon.ipc, daemon.service
  sections with function signatures and purposes
- Internal classes: add PluginSupervisor and PidFile; expand DaemonConfig

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 16:10:49 +02:00
curo1305 1d5d0387d9 test(daemon): add supervisor and IPC handler tests
13 async tests covering: supervisor lifecycle (start/stop), task
completion, crash-and-restart, max-restart enforcement, status shape,
reload (task restart + counter reset), and IPC handler dispatch for all
4 commands plus unknown commands.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:57:49 +02:00
curo1305 db6ca6ee57 feat(daemon): implement reload, fix PID race condition
- PluginSupervisor.reload(): cancels all running plugin tasks, resets
  restart counters, and re-creates them with fresh coroutines
- IPC reload command now calls supervisor.reload() instead of being a stub
- run_foreground(): wrap PID file acquisition in try/except PidFileError
  to produce a clean error if two daemon starts race on the PID file

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:54:56 +02:00
curo1305 cc24257ab0 fix(daemon): close discarded coroutines in get_daemon_task_factories
The initial daemon_tasks() call to count tasks created coroutines that
were immediately discarded, triggering RuntimeWarning "coroutine never
awaited". Explicitly close them after counting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:53:06 +02:00
curo1305 68f9007ef0 fix(daemon): install signal handlers inside running event loop
_install_signal_handlers() was called before asyncio.run(), registering
handlers on a throwaway loop that asyncio.get_event_loop() created — so
SIGTERM would never reach the supervisor. Move the call into _run_daemon()
and switch to asyncio.get_running_loop() so handlers are registered on the
actual running loop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:52:11 +02:00
curo1305 c41ad0afc6 feat(daemon): wire up all 7 daemon CLI commands
start/run/stop/status/restart/install/uninstall now call the real daemon
modules instead of printing stub messages. Includes a Rich status table
for `pyra daemon status` and friendly error messages when config is missing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:38:50 +02:00
curo1305 3d3ce694b9 feat(daemon): add core asyncio daemon, supervisor, and registry factories
- core.py: asyncio event loop entry point, PluginSupervisor with per-task
  restart (up to 10 times, 5s back-off), IPC dispatch, signal handling
  (SIGTERM/SIGHUP on POSIX), RotatingFileHandler, start_background() helper
- daemon/__init__.py: export public API
- plugins/registry.py: add get_daemon_task_factories() so supervisor can
  restart crashed tasks by re-calling plugin.daemon_tasks()[i]
- config/schema.py: add DaemonConfig.ipc_port for Windows TCP loopback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:33:57 +02:00
curo1305 d42b8b4a47 feat(daemon): add IPC transport module
Newline-delimited JSON over Unix socket (macOS/Linux, chmod 600, UID-checked
via SO_PEERCRED/getpeereid) with TCP loopback fallback on Windows. Port written
to ~/.pyra/daemon.port for Windows clients. Sync send_command() wrapper for CLI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:26:25 +02:00
curo1305 513871ef96 feat(daemon): add OS service install/uninstall module
Generates launchd plist (macOS), systemd user unit (Linux), and Task
Scheduler XML (Windows). Auto-detects platform; finds pyra executable
via shutil.which with venv-sibling fallback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:23:11 +02:00
curo1305 eaed52006f feat(daemon): add PID file management module
Atomic write-then-rename, stale-PID detection via os.kill on POSIX and
ctypes.OpenProcess on Windows, context manager for cleanup on exit.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:22:04 +02:00
curo1305 0e052c4992 fix(setup): correct LM Studio loaded state value to "loaded" not "loaded_instance"
Querying the live /api/v0/models endpoint shows LM Studio uses state="loaded"
for in-memory models (not "loaded_instance"), so the filter never matched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 15:00:01 +02:00
curo1305 40aa934431 fix(setup): filter LM Studio models by state == "loaded_instance"
LM Studio's /v1/models returns all downloaded models, not just loaded
ones. Use /api/v0/models with state filtering in both fetch_loaded_models()
and _fetch_local_models() so only RAM-resident models are shown as loaded.
This also restores the _choose_model() fallback that offers downloaded-but-
unloaded models when nothing is active in LM Studio.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 14:54:06 +02:00
curo1305 833d1445f0 feat(setup,chat): detect actually-loaded local model via provider-specific API
For Ollama, /api/tags returns all installed models, not running ones.
Add fetch_loaded_models() using /api/ps for Ollama (and /v1/models for
LM Studio/llama.cpp, which already return only loaded models).

_show_local_model_status() now calls fetch_loaded_models() so the
setup wizard correctly shows only in-memory models for Ollama.

At chat session startup, local providers warn when the configured model
is not currently loaded, or when nothing is loaded at all.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 14:46:54 +02:00
curo1305 cb390ad6af docs: update README and CLAUDE.md to reflect current state
Add daemon subcommands to README command table (Stage 6 stubs), add
Multi-step Planning section, add chat/planner.py to CLAUDE.md
architecture table, add TaskPlanner to internal classes inventory,
and remove stale test count.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 14:35:44 +02:00
curo1305 01655124b5 Update description 2026-05-19 14:28:37 +02:00
curo1305 b3851a2715 test: add tests for draft persistence, model status, and model re-entry
10 new tests covering:
- _save_draft / _load_draft / _delete_draft / _mark_step_done helpers
- draft file permissions (chmod 600)
- _show_local_model_status with zero, one, and multiple loaded models
- _test_connection change_model path (model error → change → retry succeeds)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:43:53 +02:00
curo1305 019e8044a9 feat(setup): model re-entry, status indicator, and resumable setup wizard
- _test_connection() now returns the (possibly changed) model name and
  offers a "Change model" option when the error is model-related
- _show_local_model_status() prints which models are currently loaded
  immediately after selecting a local provider
- Draft persistence: each completed wizard step is saved to
  ~/.pyra/setup.draft.json (chmod 600); on the next run a yellow panel
  summarises progress and offers [Resume / Start fresh]; draft is
  deleted on successful completion or Ctrl-C with no completed steps

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:43:49 +02:00
curo1305 efc589cc56 test: add tests for _classify_error and _check_local_server retry behaviour
_classify_error: covers all litellm error types (auth, not-found, rate-limit,
service-unavailable, connection, timeout, bad-request), httpx connect and
timeout errors, and generic fallback — using dynamically constructed fake
exception classes to avoid importing litellm in tests.

_check_local_server: covers success, retry-then-success, abort (SystemExit),
and continue-anyway paths via monkeypatched httpx and questionary.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:27:04 +02:00
curo1305 9a392410e7 feat(setup): error controller with retry loops in setup wizard
Add _classify_error() that maps litellm/httpx exceptions to human-readable
labels and resolution hints without requiring a top-level litellm import.

_check_local_server() now loops with Retry / Continue anyway / Abort instead
of printing a one-shot warning and silently continuing.

_test_connection() now loops with Retry / Re-enter API key (auth errors only) /
Skip test / Abort instead of printing the raw exception string and falling through.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 13:26:58 +02:00
23 changed files with 4833 additions and 84 deletions
+49 -6
View File
@@ -8,7 +8,8 @@ a plugin/integration system (Stage 2+) and an encrypted vault (Stage 3+).
## Current Status ## Current Status
**Stage 3 — Memory Database: complete** (2026-05-18) **Stage 3 — Memory Database: complete** (2026-05-18)
Next: Stage 4Vault Encryption **Stage 6Daemon infrastructure: in progress** (`feat/daemon` branch)
Next: Stage 4 — Vault Encryption (skipped for now); messaging bots (Stage 6 remainder)
## Project Roadmap ## Project Roadmap
@@ -19,11 +20,11 @@ memory in `~/.pyra/memory/`, and hard security boundaries around the vault.
### Stage 2 — Plugin Framework ✅ COMPLETE ### Stage 2 — Plugin Framework ✅ COMPLETE
- `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py` - `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py`
- `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra - `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra
- `src/pyra/daemon/` stub (CLI surface only) - `src/pyra/daemon/` stub (CLI surface only; daemon itself is Stage 6)
- Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig` - Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig`
- Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup - Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup
- Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands - Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands
- CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` stubs - CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` (stubs at Stage 2; implemented in Stage 6)
### Stage 3 — Memory Database ✅ COMPLETE ### Stage 3 — Memory Database ✅ COMPLETE
- `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables - `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables
@@ -99,6 +100,7 @@ the vault under namespaced keys (`plugin:{name}:{key}`).
| `config/manager.py` | ruamel.yaml round-trip config read/write, chmod 600 enforced | | `config/manager.py` | ruamel.yaml round-trip config read/write, chmod 600 enforced |
| `config/dirs.py` | `bootstrap()` — creates `~/.pyra/` tree, checks vault sentinel every startup | | `config/dirs.py` | `bootstrap()` — creates `~/.pyra/` tree, checks vault sentinel every startup |
| `chat/session.py` | prompt_toolkit REPL loop, AI tool-use loop, plugin slash commands | | `chat/session.py` | prompt_toolkit REPL loop, AI tool-use loop, plugin slash commands |
| `chat/planner.py` | `TaskPlanner` — multi-step plan approval loop, per-step AI execution and verification |
| `chat/renderer.py` | Streaming + non-streaming markdown via rich, injection warning panel | | `chat/renderer.py` | Streaming + non-streaming markdown via rich, injection warning panel |
| `chat/history.py` | Conversation list, token budget trimming, tool message support | | `chat/history.py` | Conversation list, token budget trimming, tool message support |
| `memory/database.py` | SQLite+FTS5 — `init_db()`, `upsert()`, `remove()`, `search()`, `list_all()`, `migrate_from_files()` | | `memory/database.py` | SQLite+FTS5 — `init_db()`, `upsert()`, `remove()`, `search()`, `list_all()`, `migrate_from_files()` |
@@ -116,7 +118,11 @@ the vault under namespaced keys (`plugin:{name}:{key}`).
| `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log | | `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log |
| `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` | | `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` |
| `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) | | `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) |
| `daemon/__init__.py` | Daemon package stub (implementation in Stage 2.4) | | `daemon/pid.py` | Atomic PID file — write, read, stale detection (POSIX + Windows), context manager |
| `daemon/ipc.py` | IPC transport — Unix socket chmod 600 + UID-check (Linux/macOS) or TCP loopback + port file (Windows); newline-delimited JSON protocol |
| `daemon/service.py` | OS service file generation + install/uninstall — launchd plist (macOS), systemd user unit (Linux), schtasks XML (Windows) |
| `daemon/core.py` | asyncio event loop entry point, `PluginSupervisor` (per-task restart, max 10×, 5s back-off, reload), IPC command dispatch, signal handling |
| `daemon/__init__.py` | Public daemon API exports |
### Runtime: `~/.pyra/` ### Runtime: `~/.pyra/`
@@ -243,7 +249,7 @@ uv pip install -e ".[all-plugins]" # Everything
## Running Tests ## Running Tests
```bash ```bash
pytest tests/ -v # all unit + security tests (161 tests) pytest tests/ -v # all unit + security tests
pytest tests/integration/test_lmstudio.py # requires LM Studio at localhost:1234 pytest tests/integration/test_lmstudio.py # requires LM Studio at localhost:1234
``` ```
@@ -492,6 +498,40 @@ Dataclass: `MemoryFile(name, path, category, size_bytes, modified)`
| `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` | | `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` |
| `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing | | `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing |
#### `daemon.core`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `run_foreground` | `() -> None` | Entry point for `pyra daemon run` — loads config + plugins, writes PID file, runs asyncio loop |
| `start_background` | `() -> None` | Spawns `pyra daemon run` as a detached subprocess (`start_new_session` on POSIX, `DETACHED_PROCESS` on Windows) |
#### `daemon.pid`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `resolve_pid_path` | `(cfg_path: str) -> Path` | Expand `~` and resolve to absolute Path |
#### `daemon.ipc`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `send_command` | `(address, msg, timeout=5.0) -> IpcResponse` | Synchronous CLI helper — `asyncio.run(IpcClient.send(...))` |
| `get_socket_path` | `(cfg: str) -> Path` | Expand `~` and return Unix socket path |
| `is_unix_socket` | `() -> bool` | True on Linux/macOS (`sys.platform != 'nt'`) |
| `get_port_file_path` | `() -> Path` | Path to `~/.pyra/daemon.port` (Windows TCP port file) |
#### `daemon.service`
| Function | Signature | Purpose |
|----------|-----------|---------|
| `detect_platform` | `() -> Literal["macos","linux","windows"]` | Detect current OS |
| `find_pyra_executable` | `() -> str` | `shutil.which("pyra")` → sibling fallback → `sys.executable -m pyra` |
| `install_service` | `() -> None` | Generate + register OS service (reads config for log/pid paths) |
| `uninstall_service` | `() -> None` | Deregister OS service |
| `render_launchd_plist` | `(exe, log_file, pid_file) -> str` | macOS plist template |
| `render_systemd_unit` | `(exe, log_file) -> str` | Linux systemd unit template |
| `render_schtasks_xml` | `(exe) -> str` | Windows Task Scheduler XML template (write as UTF-16) |
#### `chat.renderer` — rendering functions and shared `console` #### `chat.renderer` — rendering functions and shared `console`
Import `console` from here; do not create a second `rich.Console()` in new code. Import `console` from here; do not create a second `rich.Console()` in new code.
@@ -514,7 +554,7 @@ Import `console` from here; do not create a second `rich.Console()` in new code.
| `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` | | `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` |
| `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` | | `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` |
| `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` | | `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` |
| `DaemonConfig` | `config.schema` | `daemon:` block | | `DaemonConfig` | `config.schema` | `daemon:` block — `enabled`, `socket_path`, `log_file`, `pid_file`, `ipc_port` |
| `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` | | `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` |
| `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` | | `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` |
| `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget | | `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget |
@@ -524,3 +564,6 @@ Import `console` from here; do not create a second `rich.Console()` in new code.
| `Tool` | `plugins.base` | Dataclass — `name`, `description`, `parameters` (JSON Schema), `handler`, `requires_approval` | | `Tool` | `plugins.base` | Dataclass — `name`, `description`, `parameters` (JSON Schema), `handler`, `requires_approval` |
| `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface | | `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface |
| `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this | | `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this |
| `TaskPlanner` | `chat.planner` | Multi-step plan runner; `make_tool_handler()` returns the callable wired into the chat session; presents plan for user approval, executes each step via litellm with up to 5 tool-use iterations, verifies output before proceeding |
| `PluginSupervisor` | `daemon.core` | asyncio supervisor — `add_task(name, factory)`, `start()`, `stop()`, `reload()`, `status()`; restarts crashed tasks up to 10× with 5s back-off |
| `PidFile` | `daemon.pid` | `write()` (atomic), `read()`, `is_stale()`, `remove()`, context manager; `PidFileError(OSError)` raised when live PID already exists |
+57 -10
View File
@@ -1,7 +1,7 @@
# Pyra # Pyra
A personal AI assistant CLI with vault-first security. Combines multi-provider AI chat with A personal AI assistant CLI with vault-first security. Combines multi-provider AI chat,
long-term memory and (coming) automation skills. long-term memory, and an extensible plugin system.
## Quick Start ## Quick Start
@@ -31,6 +31,17 @@ pyra chat # start talking
| `pyra memory read <name>` | Read a memory file | | `pyra memory read <name>` | Read a memory file |
| `pyra memory write <name> <content>` | Write a memory file | | `pyra memory write <name> <content>` | Write a memory file |
| `pyra memory append <name> <content>` | Append to a memory file | | `pyra memory append <name> <content>` | Append to a memory file |
| `pyra plugin list` | List installed and available plugins |
| `pyra plugin install <name>` | Install a bundled plugin |
| `pyra plugin enable <name>` | Enable an installed plugin |
| `pyra plugin disable <name>` | Disable a plugin (keeps it installed) |
| `pyra plugin setup <name>` | Run a plugin's credential setup wizard |
| `pyra daemon start` | Start the background daemon *(Stage 6, not yet implemented)* |
| `pyra daemon stop` | Stop the running daemon *(Stage 6, not yet implemented)* |
| `pyra daemon status` | Show daemon status *(Stage 6, not yet implemented)* |
| `pyra daemon restart` | Restart the daemon *(Stage 6, not yet implemented)* |
| `pyra daemon install` | Register Pyra as a system service *(Stage 6, not yet implemented)* |
| `pyra daemon uninstall` | Remove the system service *(Stage 6, not yet implemented)* |
### In-chat slash commands ### In-chat slash commands
@@ -38,6 +49,7 @@ pyra chat # start talking
|---------|-------------| |---------|-------------|
| `/help` | Show available commands | | `/help` | Show available commands |
| `/memory list` | List memory files | | `/memory list` | List memory files |
| `/config` | Open the configuration TUI |
| `/clear` | Clear conversation history | | `/clear` | Clear conversation history |
| `/quit` or `/exit` | Exit Pyra | | `/quit` or `/exit` | Exit Pyra |
@@ -48,16 +60,41 @@ pyra chat # start talking
- **Prompt injection scanner** — warns on suspicious AI output, logs to `~/.pyra/security.log` - **Prompt injection scanner** — warns on suspicious AI output, logs to `~/.pyra/security.log`
- **Path sandboxing** — the AI can only reference memory files by name; traversal is blocked - **Path sandboxing** — the AI can only reference memory files by name; traversal is blocked
## Plugins
Pyra has an extensible plugin system. Bundled plugins are shipped with Pyra and installed on
demand; third-party plugins can be dropped into `~/.pyra/plugins/` directly.
Each plugin is a directory containing a `manifest.json` and a `plugin.py`. Plugin credentials
are stored in the vault under namespaced keys (`plugin:<name>:<key>`) — never in `config.yaml`.
```bash
pyra plugin list # see what's available
pyra plugin install <name> # copy a bundled plugin to ~/.pyra/plugins/
pyra plugin setup <name> # enter credentials (stored in vault)
pyra plugin enable <name> # activate for the next chat session
```
## Multi-step Planning
When given a complex task the AI can propose a **multi-step plan** using the built-in
`plan_and_execute` tool. Pyra prints the plan and asks for approval before executing
anything. Each step runs as a separate AI call with access to enabled plugin tools; each
result is verified before moving on to the next step. You can decline the plan or
interrupt at any point.
## Memory ## Memory
Pyra reads your memory files at the start of each session and injects them as context. Pyra reads your memory files at the start of each session and injects them as context.
Files are plain Markdown stored in `~/.pyra/memory/`: Files are plain Markdown stored in `~/.pyra/memory/`, indexed by a SQLite full-text search
database (`memory.db`) for fast in-chat lookup.
``` ```
~/.pyra/memory/ ~/.pyra/memory/
├── user/profile.md ← who you are ├── user/profile.md ← who you are
├── context/ ← ongoing projects ├── context/ ← ongoing projects
── knowledge/ ← general notes ── knowledge/ ← general notes
└── memory.db ← FTS5 search index (auto-managed)
``` ```
## `~/.pyra/` Directory ## `~/.pyra/` Directory
@@ -67,15 +104,25 @@ Files are plain Markdown stored in `~/.pyra/memory/`:
├── config.yaml ← provider + model (no secrets) ├── config.yaml ← provider + model (no secrets)
├── security.log ← injection event log ├── security.log ← injection event log
├── memory/ ← AI-readable long-term memory ├── memory/ ← AI-readable long-term memory
├── skills/ ← automation scripts (Stage 2) │ └── memory.db ← SQLite FTS5 search index
├── plugins/ ← installed plugins
│ └── <name>/
│ ├── manifest.json
│ └── plugin.py
├── logs/ ← execution logs
│ ├── tool_executions.log
│ └── plugin_errors.log
└── vault/ ← secure, AI-inaccessible storage └── vault/ ← secure, AI-inaccessible storage
└── secrets/api_keys.json └── secrets/api_keys.json
``` ```
## Roadmap ## Roadmap
- **Stage 1** (now): Core CLI, multi-provider chat, memory, vault security - **Stage 1** Core CLI multi-provider chat, memory, vault security
- **Stage 2**: Skills — shell/PowerShell/Python automations with user approval gates - **Stage 2** ✅ Plugin Framework — extensible tools, slash commands, approval gates
- **Stage 3**: Vault encryption with `age` - **Stage 3** ✅ Memory Database — SQLite + FTS5 full-text search index
- **Stage 4**: Security audit sub-agent - **Stage 4** Vault Encryption — `age`-based encryption of `~/.pyra/vault/secrets/`
- **Stage 5**: Web UI, embedding-based memory search - **Stage 5** Skills System — YAML-defined multi-plugin workflows with event triggers
- **Stage 6** Daemon + Messaging Bots — always-on asyncio daemon, Matrix/Telegram/Signal bots
- **Stage 7** Security Audit Sub-agent — automated scanning for injection, CVEs, permission drift
- **Stage 8** Web UI — optional local interface, embedding-based memory search
+8
View File
@@ -35,6 +35,12 @@ gdrive = ["google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0"]
onedrive = ["msal>=1.28.0"] onedrive = ["msal>=1.28.0"]
dropbox = ["dropbox>=12.0.0"] dropbox = ["dropbox>=12.0.0"]
daemon = ["aiofiles>=23.0.0"] daemon = ["aiofiles>=23.0.0"]
email = [
"imap-tools>=1.7.0",
"google-api-python-client>=2.120.0",
"google-auth-oauthlib>=1.2.0",
"O365>=2.0.36",
]
all-plugins = [ all-plugins = [
"caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6", "caldav>=1.3.0", "webdav4>=0.9.0", "vobject>=0.9.6",
"matrix-nio>=0.24.0", "aiofiles>=23.0.0", "matrix-nio>=0.24.0", "aiofiles>=23.0.0",
@@ -44,6 +50,8 @@ all-plugins = [
"google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0", "google-api-python-client>=2.120.0", "google-auth-oauthlib>=1.2.0",
"msal>=1.28.0", "msal>=1.28.0",
"dropbox>=12.0.0", "dropbox>=12.0.0",
"imap-tools>=1.7.0",
"O365>=2.0.36",
] ]
[project.scripts] [project.scripts]
@@ -0,0 +1,12 @@
{
"name": "email",
"version": "1.0.0",
"description": "Full email management — read, send, search, sort, and create filter rules. Supports Gmail, Microsoft 365, ProtonMail (Bridge), and any IMAP provider. Background monitoring pushes new-email summaries to your configured messaging bot.",
"author": "pyra",
"requires": [
"imap-tools>=1.7.0",
"google-api-python-client>=2.120.0",
"google-auth-oauthlib>=1.2.0",
"O365>=2.0.36"
]
}
File diff suppressed because it is too large Load Diff
+11
View File
@@ -25,6 +25,7 @@ from pyra.plugins.executor import ToolExecutor
from pyra.plugins.registry import PluginRegistry from pyra.plugins.registry import PluginRegistry
from pyra.security.injection import scan_response from pyra.security.injection import scan_response
from pyra.setup.providers import get_provider from pyra.setup.providers import get_provider
from pyra.setup.wizard import fetch_loaded_models
from pyra.utils.paths import pyra_home from pyra.utils.paths import pyra_home
_HISTORY_FILE = pyra_home() / ".chat_history" _HISTORY_FILE = pyra_home() / ".chat_history"
@@ -176,6 +177,16 @@ def start_chat() -> None:
"[dim]Type /help for commands, /quit to exit.[/dim]" "[dim]Type /help for commands, /quit to exit.[/dim]"
) )
if provider.group == "Local":
loaded = fetch_loaded_models(provider)
if not loaded:
render_info(f"No model currently loaded in {provider.display_name}.")
elif cfg.ai.model not in loaded:
render_info(
f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
f"Loaded: {', '.join(loaded)}"
)
_flags: dict = {"use_tools": True} _flags: dict = {"use_tools": True}
while True: while True:
+113 -12
View File
@@ -174,7 +174,7 @@ def plugin_install(name: str) -> None:
install_bundled_plugin(name, bundled_dir, plugins_dir) install_bundled_plugin(name, bundled_dir, plugins_dir)
console.print(f"[green]Installed:[/green] {name}") console.print(f"[green]Installed:[/green] {name}")
console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]") console.print(f" Enable: [dim]pyra plugin enable {name}[/dim]")
console.print(f" Confirm: [dim]pyra plugin setup {name}[/dim]") console.print(f" Configure: [dim]pyra plugin setup {name}[/dim]")
except FileNotFoundError as exc: except FileNotFoundError as exc:
console.print(f"[red]Error:[/red] {exc}") console.print(f"[red]Error:[/red] {exc}")
except Exception as exc: except Exception as exc:
@@ -266,43 +266,144 @@ def daemon() -> None:
_bootstrap_or_exit() _bootstrap_or_exit()
@daemon.command("run", hidden=True)
def daemon_run() -> None:
"""Run daemon in foreground (used by service manager)."""
from pyra.daemon.core import run_foreground
run_foreground()
@daemon.command("start") @daemon.command("start")
def daemon_start() -> None: def daemon_start() -> None:
"""Start the Pyra daemon in the background.""" """Start the Pyra daemon in the background."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.core import start_background
try:
start_background()
except FileNotFoundError:
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
@daemon.command("stop") @daemon.command("stop")
def daemon_stop() -> None: def daemon_stop() -> None:
"""Stop the running Pyra daemon.""" """Stop the running Pyra daemon."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") _daemon_ipc("stop", success_msg="Daemon stopped.")
@daemon.command("status") @daemon.command("status")
def daemon_status() -> None: def daemon_status() -> None:
"""Show daemon status.""" """Show daemon status."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") _daemon_ipc("status")
@daemon.command("restart") @daemon.command("restart")
def daemon_restart() -> None: def daemon_restart() -> None:
"""Restart the Pyra daemon.""" """Restart the Pyra daemon."""
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") import time
from pyra.daemon.core import start_background
_daemon_ipc("stop", success_msg=None)
time.sleep(1.5)
try:
start_background()
except FileNotFoundError:
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
@daemon.command("install") @daemon.command("install")
def daemon_install() -> None: def daemon_install() -> None:
"""Install Pyra as a system service (launchd/systemd).""" """Install Pyra as a system service (launchd/systemd/schtasks)."""
console.print("[yellow]Daemon service install (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.service import detect_platform, install_service
try:
install_service()
console.print(f"[green]Service installed[/green] ({detect_platform()}).")
except Exception as exc:
console.print(f"[red]Install failed:[/red] {exc}")
@daemon.command("uninstall") @daemon.command("uninstall")
def daemon_uninstall() -> None: def daemon_uninstall() -> None:
"""Remove the Pyra system service.""" """Remove the Pyra system service."""
console.print("[yellow]Daemon service uninstall (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.service import uninstall_service
try:
uninstall_service()
console.print("[green]Service removed.[/green]")
except Exception as exc:
console.print(f"[red]Uninstall failed:[/red] {exc}")
@daemon.command("run", hidden=True) def _daemon_ipc(cmd: str, *, success_msg: str | None = None) -> None:
def daemon_run() -> None: """Send a command to the running daemon via IPC and render the response."""
"""Run daemon in foreground (used by service manager).""" from pyra.config.manager import load_config
console.print("[yellow]Daemon (Stage 2.4) is not yet implemented.[/yellow]") from pyra.daemon.ipc import (
get_socket_path,
is_unix_socket,
get_port_file_path,
send_command,
)
try:
cfg = load_config()
except FileNotFoundError:
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
return
if is_unix_socket():
address = get_socket_path(cfg.daemon.socket_path)
else:
port = _read_windows_port()
if port is None:
console.print("[yellow]Daemon is not running.[/yellow]")
return
address = ("127.0.0.1", port)
try:
resp = send_command(address, {"cmd": cmd})
except (ConnectionRefusedError, FileNotFoundError, OSError):
console.print("[yellow]Daemon is not running.[/yellow]")
return
except ConnectionResetError:
console.print("[red]Permission denied:[/red] daemon rejected connection.")
return
except TimeoutError:
console.print("[red]Daemon did not respond in time.[/red]")
return
if not resp.get("ok"):
console.print(f"[red]Error:[/red] {resp.get('data', {}).get('error', 'unknown')}")
return
if cmd == "status":
_render_daemon_status(resp["data"])
elif success_msg:
console.print(f"[green]{success_msg}[/green]")
def _read_windows_port() -> int | None:
from pyra.daemon.ipc import get_port_file_path
try:
return int(get_port_file_path().read_text().strip())
except (FileNotFoundError, ValueError):
return None
def _render_daemon_status(data: dict) -> None:
from rich.table import Table
uptime = data.get("uptime", 0.0)
pid = data.get("pid", "?")
tasks = data.get("tasks", [])
hours, rem = divmod(int(uptime), 3600)
mins, secs = divmod(rem, 60)
uptime_str = f"{hours}h {mins}m {secs}s" if hours else f"{mins}m {secs}s"
console.print(f"[bold green]Daemon running[/bold green] — PID {pid}, uptime {uptime_str}")
if tasks:
table = Table("Task", "Alive", "Restarts", "Last error", show_header=True)
for t in tasks:
alive = "[green]yes[/green]" if t.get("alive") else "[red]no[/red]"
error = t.get("last_error") or ""
table.add_row(t.get("name", "?"), alive, str(t.get("restart_count", 0)), error)
console.print(table)
else:
console.print("[dim]No plugin tasks registered.[/dim]")
+1
View File
@@ -36,6 +36,7 @@ class DaemonConfig(BaseModel):
socket_path: str = "~/.pyra/daemon.sock" socket_path: str = "~/.pyra/daemon.sock"
log_file: str = "~/.pyra/daemon.log" log_file: str = "~/.pyra/daemon.log"
pid_file: str = "~/.pyra/daemon.pid" pid_file: str = "~/.pyra/daemon.pid"
ipc_port: int = 0 # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port
class PyraConfig(BaseModel): class PyraConfig(BaseModel):
+24
View File
@@ -0,0 +1,24 @@
"""Pyra background daemon package."""
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
from pyra.daemon.events import publish, subscribe_forever
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
from pyra.daemon.service import detect_platform, install_service, uninstall_service
__all__ = [
"run_foreground",
"start_background",
"PluginSupervisor",
"publish",
"subscribe_forever",
"IpcClient",
"IpcServer",
"send_command",
"PidFile",
"PidFileError",
"resolve_pid_path",
"detect_platform",
"install_service",
"uninstall_service",
]
+313
View File
@@ -0,0 +1,313 @@
"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling."""
from __future__ import annotations
import asyncio
import logging
import logging.handlers
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Coroutine
from pyra.utils.paths import pyra_home, safe_chmod
_log = logging.getLogger("pyra.daemon")
_start_time: float = 0.0
# ── Plugin task supervisor ────────────────────────────────────────────────────
@dataclass
class TaskRecord:
name: str
coro_factory: Callable[[], Coroutine] # type: ignore[type-arg]
task: asyncio.Task | None = field(default=None, repr=False)
restart_count: int = 0
last_error: str | None = None
def is_alive(self) -> bool:
return self.task is not None and not self.task.done()
class PluginSupervisor:
_RESTART_DELAY: float = 5.0
_MAX_RESTARTS: int = 10
def __init__(self) -> None:
self._records: list[TaskRecord] = []
self._shutdown = asyncio.Event()
def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
self._records.append(TaskRecord(name=name, coro_factory=factory))
async def start(self) -> None:
for record in self._records:
record.task = asyncio.create_task(
self._supervise(record), name=record.name
)
_log.info("Supervisor started with %d plugin task(s).", len(self._records))
async def run_until_shutdown(self) -> None:
await self._shutdown.wait()
_log.info("Shutdown requested — stopping supervisor.")
def request_shutdown(self) -> None:
self._shutdown.set()
async def stop(self) -> None:
for record in self._records:
if record.task and not record.task.done():
record.task.cancel()
tasks = [r.task for r in self._records if r.task]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def reload(self) -> None:
"""Cancel all running tasks and restart them with fresh coroutines."""
for record in self._records:
if record.task and not record.task.done():
record.task.cancel()
try:
await record.task
except (asyncio.CancelledError, Exception):
pass
record.restart_count = 0
record.last_error = None
record.task = asyncio.create_task(
self._supervise(record), name=record.name
)
_log.info("Reloaded %d plugin task(s).", len(self._records))
def status(self) -> list[dict]:
return [
{
"name": r.name,
"alive": r.is_alive(),
"restart_count": r.restart_count,
"last_error": r.last_error,
}
for r in self._records
]
async def _supervise(self, record: TaskRecord) -> None:
while not self._shutdown.is_set():
try:
await record.coro_factory()
_log.info("Plugin task %s completed normally.", record.name)
return
except asyncio.CancelledError:
return
except Exception as exc:
record.restart_count += 1
record.last_error = f"{type(exc).__name__}: {exc}"
_log.error(
"Plugin task %s crashed (restart #%d): %s",
record.name, record.restart_count, exc,
exc_info=True,
)
if record.restart_count >= self._MAX_RESTARTS:
_log.critical(
"Plugin task %s exceeded max restarts (%d). Giving up.",
record.name, self._MAX_RESTARTS,
)
return
try:
await asyncio.wait_for(
asyncio.sleep(self._RESTART_DELAY),
timeout=self._RESTART_DELAY + 1,
)
except asyncio.CancelledError:
return
# ── IPC command dispatch ──────────────────────────────────────────────────────
def _make_ipc_handler(supervisor: PluginSupervisor):
async def handler(msg: dict) -> dict:
cmd = msg.get("cmd", "")
match cmd:
case "ping":
return {"ok": True, "data": {"pong": True}}
case "status":
return {
"ok": True,
"data": {
"uptime": time.monotonic() - _start_time,
"pid": os.getpid(),
"tasks": supervisor.status(),
},
}
case "stop":
supervisor.request_shutdown()
return {"ok": True, "data": {}}
case "reload":
await supervisor.reload()
return {"ok": True, "data": {"tasks_reloaded": len(supervisor._records)}}
case _:
return {"ok": False, "data": {"error": f"unknown command: {cmd}"}}
return handler
# ── Main async entrypoint ─────────────────────────────────────────────────────
async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None:
from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket
# Install signal handlers now that the event loop is running.
_install_signal_handlers(supervisor)
if is_unix_socket():
address = get_socket_path(cfg.daemon.socket_path)
else:
address = ("127.0.0.1", cfg.daemon.ipc_port)
server = IpcServer(address, _make_ipc_handler(supervisor))
await supervisor.start()
async with asyncio.TaskGroup() as tg:
tg.create_task(server.start(), name="ipc_server")
tg.create_task(supervisor.run_until_shutdown(), name="shutdown_waiter")
await server.stop()
await supervisor.stop()
# ── Foreground entry point (pyra daemon run) ──────────────────────────────────
def run_foreground() -> None:
"""Run the daemon in the foreground. Called by `pyra daemon run`."""
from pyra.config.manager import load_config
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
from pyra.plugins.registry import PluginRegistry
global _start_time
cfg = load_config()
_setup_logging(cfg.daemon.log_file)
pid_path = resolve_pid_path(cfg.daemon.pid_file)
pid_file = PidFile(pid_path)
existing = pid_file.read()
if existing is not None and not pid_file.is_stale():
_log.error("Daemon already running (PID %d). Exiting.", existing)
sys.exit(1)
registry = PluginRegistry()
from pyra.utils.paths import pyra_home as _pyra_home
plugins_dir = _pyra_home() / "plugins"
if plugins_dir.exists():
registry.load_all(plugins_dir, cfg.plugins.enabled)
supervisor = PluginSupervisor()
for name, factory in registry.get_daemon_task_factories():
supervisor.add_task(name, factory)
_start_time = time.monotonic()
try:
with pid_file:
_log.info("Pyra daemon starting (PID %d).", os.getpid())
try:
asyncio.run(_run_daemon(cfg, supervisor))
except KeyboardInterrupt:
pass
_log.info("Pyra daemon stopped.")
except PidFileError as exc:
_log.error("Could not acquire PID file: %s", exc)
sys.exit(1)
# ── Background spawn (pyra daemon start) ─────────────────────────────────────
def start_background() -> None:
"""Spawn `pyra daemon run` as a detached background process."""
from pyra.config.manager import load_config
from pyra.daemon.pid import PidFile, resolve_pid_path
from pyra.daemon.service import find_pyra_executable
cfg = load_config()
pid_path = resolve_pid_path(cfg.daemon.pid_file)
pid_file = PidFile(pid_path)
existing = pid_file.read()
if existing is not None and not pid_file.is_stale():
from pyra.chat.renderer import console
console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]")
return
exe = find_pyra_executable()
log_path = Path(cfg.daemon.log_file).expanduser()
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a") as log_fh:
if sys.platform == "win32":
DETACHED_PROCESS = 0x00000008
CREATE_NEW_PROCESS_GROUP = 0x00000200
subprocess.Popen(
[exe, "daemon", "run"],
creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP,
stdout=log_fh,
stderr=log_fh,
close_fds=True,
)
else:
subprocess.Popen(
[exe, "daemon", "run"],
start_new_session=True,
stdout=log_fh,
stderr=log_fh,
stdin=subprocess.DEVNULL,
close_fds=True,
)
from pyra.chat.renderer import console
# Poll the PID file for up to 3 seconds to confirm startup.
for _ in range(30):
time.sleep(0.1)
pid = pid_file.read()
if pid is not None:
console.print(f"[green]Daemon started (PID {pid}).[/green]")
return
console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]")
# ── Signal handling ───────────────────────────────────────────────────────────
def _install_signal_handlers(supervisor: PluginSupervisor) -> None:
if sys.platform == "win32":
signal.signal(signal.SIGTERM, lambda *_: supervisor.request_shutdown())
return
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGTERM, supervisor.request_shutdown)
loop.add_signal_handler(signal.SIGHUP, supervisor.request_shutdown)
# ── Logging setup ─────────────────────────────────────────────────────────────
def _setup_logging(log_file_str: str) -> None:
log_path = Path(log_file_str).expanduser()
log_path.parent.mkdir(parents=True, exist_ok=True)
handler = logging.handlers.RotatingFileHandler(
log_path, maxBytes=5 * 1024 * 1024, backupCount=3
)
handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
)
root = logging.getLogger("pyra")
root.addHandler(handler)
root.setLevel(logging.INFO)
safe_chmod(log_path, 0o600)
+46
View File
@@ -0,0 +1,46 @@
"""Async notification bus for inter-plugin communication in the daemon.
Plugins publish events to a shared asyncio.Queue; other plugins (e.g. messaging
bots) consume them via subscribe_forever(). No direct plugin-to-plugin imports
are needed — both sides just use this module.
Event shape (by convention):
{"type": "new_email", "priority": int, "from": str, "subject": str,
"summary": str, "uid": str, "folder": str}
{"type": "new_message", "bot": str, "user_id": str, "text": str}
"""
from __future__ import annotations
import asyncio
from typing import Any, AsyncGenerator
_queue: asyncio.Queue[dict[str, Any]] | None = None
def get_queue() -> asyncio.Queue[dict[str, Any]]:
global _queue
if _queue is None:
_queue = asyncio.Queue(maxsize=200)
return _queue
async def publish(event: dict[str, Any]) -> None:
"""Emit an event. Drops silently if the queue is full (daemon is overloaded)."""
q = get_queue()
try:
q.put_nowait(event)
except asyncio.QueueFull:
pass
async def subscribe_forever() -> AsyncGenerator[dict[str, Any], None]:
"""Async generator — yields events as they arrive. Intended for daemon tasks."""
q = get_queue()
while True:
yield await q.get()
def reset() -> None:
"""Discard the current queue and create a fresh one. FOR TESTS ONLY."""
global _queue
_queue = None
+241
View File
@@ -0,0 +1,241 @@
"""IPC transport for the Pyra daemon.
Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked).
Windows: TCP loopback on an OS-assigned port; actual port written to
~/.pyra/daemon.port so clients can connect without knowing it ahead
of time.
"""
from __future__ import annotations
import asyncio
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any, Awaitable, Callable
# ── Protocol types ────────────────────────────────────────────────────────────
IpcMessage = dict[str, Any] # must have "cmd" key
IpcResponse = dict[str, Any] # must have "ok" and "data" keys
# ── Encode / decode ───────────────────────────────────────────────────────────
def encode_message(msg: IpcMessage) -> bytes:
return (json.dumps(msg) + "\n").encode()
def decode_message(line: bytes) -> IpcMessage:
try:
return json.loads(line.rstrip(b"\n"))
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid IPC message: {exc}") from exc
# ── Address helpers ───────────────────────────────────────────────────────────
def is_unix_socket() -> bool:
return sys.platform != "win32"
def get_socket_path(cfg_socket_path: str) -> Path:
"""Expand ~ and return the Unix socket path."""
return Path(cfg_socket_path).expanduser()
def get_port_file_path() -> Path:
from pyra.utils.paths import pyra_home
return pyra_home() / "daemon.port"
def _read_windows_port() -> int | None:
port_file = get_port_file_path()
try:
return int(port_file.read_text().strip())
except (FileNotFoundError, ValueError):
return None
# ── Server ────────────────────────────────────────────────────────────────────
class IpcServer:
def __init__(
self,
address: Path | tuple[str, int],
handler: Callable[[IpcMessage], Awaitable[IpcResponse]],
) -> None:
self._address = address
self._handler = handler
self._server: asyncio.Server | None = None
async def start(self) -> None:
if is_unix_socket():
assert isinstance(self._address, Path)
sock_path = self._address
if sock_path.exists():
sock_path.unlink()
self._server = await asyncio.start_unix_server(
self._handle_client, path=str(sock_path)
)
os.chmod(sock_path, 0o600)
else:
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
self._server = await asyncio.start_server(
self._handle_client, host=host, port=port
)
actual_port = self._server.sockets[0].getsockname()[1]
port_file = get_port_file_path()
port_file.write_text(str(actual_port))
await self._server.start_serving()
async def stop(self) -> None:
if self._server is not None:
self._server.close()
try:
await asyncio.wait_for(self._server.wait_closed(), timeout=5.0)
except asyncio.TimeoutError:
pass
if is_unix_socket() and isinstance(self._address, Path):
try:
self._address.unlink()
except FileNotFoundError:
pass
else:
port_file = get_port_file_path()
try:
port_file.unlink()
except FileNotFoundError:
pass
async def _handle_client(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
try:
if is_unix_socket() and not self._check_peer_uid(writer):
writer.close()
return
line = await asyncio.wait_for(reader.readline(), timeout=5.0)
if not line:
return
try:
msg = decode_message(line)
except ValueError:
resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}}
else:
resp = await self._handler(msg)
writer.write(encode_message(resp))
await writer.drain()
except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError):
pass
finally:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool:
"""Return True if the peer's UID matches ours. Falls back to True on error."""
try:
peer_uid = _get_peer_uid(writer)
if peer_uid is None:
return True # can't determine — allow (socket perms are the guard)
return peer_uid == os.getuid()
except Exception:
return True # don't crash the server on unexpected errors
# ── Client ────────────────────────────────────────────────────────────────────
class IpcClient:
def __init__(self, address: Path | tuple[str, int]) -> None:
self._address = address
async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse:
if is_unix_socket():
assert isinstance(self._address, Path)
reader, writer = await asyncio.wait_for(
asyncio.open_unix_connection(str(self._address)), timeout=timeout
)
else:
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(host, port), timeout=timeout
)
try:
writer.write(encode_message(msg))
await writer.drain()
line = await asyncio.wait_for(reader.readline(), timeout=timeout)
return decode_message(line)
finally:
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
def send_command(
address: Path | tuple[str, int],
msg: IpcMessage,
timeout: float = 5.0,
) -> IpcResponse:
"""Synchronous wrapper around IpcClient.send() for CLI callers."""
return asyncio.run(IpcClient(address).send(msg, timeout=timeout))
# ── Peer UID detection ────────────────────────────────────────────────────────
def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None:
"""Return the connecting peer's UID, or None if unavailable."""
try:
sock = writer.get_extra_info("socket")
if sock is None:
return None
if sys.platform == "linux":
# SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; }
SO_PEERCRED = 17
cred = sock.getsockopt(
socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i")
)
_pid, uid, _gid = struct.unpack("3i", cred)
return uid
if sys.platform == "darwin":
return _macos_peer_uid(sock.fileno())
except Exception:
pass
return None
def socket_module(): # lazy import to avoid top-level import on non-Unix
import socket
return socket
def _macos_peer_uid(fd: int) -> int | None:
"""Use getpeereid(2) via ctypes to retrieve the peer UID on macOS."""
import ctypes
import ctypes.util
libc_name = ctypes.util.find_library("c")
if not libc_name:
return None
libc = ctypes.CDLL(libc_name)
euid = ctypes.c_uint32(0)
egid = ctypes.c_uint32(0)
if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0:
return None
return euid.value
+94
View File
@@ -0,0 +1,94 @@
"""PID file management for the Pyra daemon."""
from __future__ import annotations
import os
import sys
from pathlib import Path
class PidFileError(OSError):
"""Raised when a PID file operation fails due to a live conflicting process."""
class PidFile:
def __init__(self, path: Path) -> None:
self._path = path
def write(self) -> None:
"""Write the current PID atomically.
Raises PidFileError if a non-stale PID file already exists.
"""
existing = self.read()
if existing is not None and not self.is_stale():
raise PidFileError(
f"Daemon already running with PID {existing} "
f"(PID file: {self._path})"
)
tmp = self._path.with_suffix(".pid.tmp")
tmp.write_text(str(os.getpid()))
tmp.replace(self._path)
def read(self) -> int | None:
"""Return the PID from the file, or None if the file is absent or unreadable."""
try:
return int(self._path.read_text().strip())
except (FileNotFoundError, ValueError):
return None
def is_stale(self) -> bool:
"""True when the PID file exists but the process no longer runs."""
pid = self.read()
if pid is None:
return False
return not _process_is_alive(pid)
def remove(self) -> None:
"""Delete the PID file, ignoring FileNotFoundError."""
try:
self._path.unlink()
except FileNotFoundError:
pass
def __enter__(self) -> "PidFile":
self.write()
return self
def __exit__(self, *_: object) -> None:
self.remove()
def resolve_pid_path(cfg_path: str) -> Path:
"""Expand ~ and return an absolute Path."""
return Path(cfg_path).expanduser().resolve()
# ── Platform-specific process liveness check ─────────────────────────────────
def _process_is_alive(pid: int) -> bool:
if sys.platform == "win32":
return _win_process_is_alive(pid)
return _posix_process_is_alive(pid)
def _posix_process_is_alive(pid: int) -> bool:
try:
os.kill(pid, 0)
return True
except ProcessLookupError:
return False
except PermissionError:
# Process exists but is owned by another user — still alive.
return True
def _win_process_is_alive(pid: int) -> bool:
import ctypes
SYNCHRONIZE = 0x00100000
handle = ctypes.windll.kernel32.OpenProcess(SYNCHRONIZE, False, pid) # type: ignore[attr-defined]
if handle == 0:
return False
ctypes.windll.kernel32.CloseHandle(handle) # type: ignore[attr-defined]
return True
+212
View File
@@ -0,0 +1,212 @@
"""OS-specific service file generation and install/uninstall for the Pyra daemon."""
from __future__ import annotations
import platform
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Literal
from pyra.utils.paths import safe_chmod
def detect_platform() -> Literal["macos", "linux", "windows"]:
s = platform.system()
if s == "Darwin":
return "macos"
if s == "Linux":
return "linux"
if s == "Windows":
return "windows"
raise RuntimeError(f"Unsupported platform: {s}")
def find_pyra_executable() -> str:
"""Return the full path to the active pyra executable.
Tries, in order:
1. shutil.which("pyra") — works when pyra is on PATH (activated venv)
2. sys.executable's sibling "pyra" script — covers editable installs
3. Fallback: sys.executable -m pyra
"""
found = shutil.which("pyra")
if found:
return found
sibling = Path(sys.executable).parent / "pyra"
if sibling.exists():
return str(sibling)
return f"{sys.executable} -m pyra"
# ── Template generators ───────────────────────────────────────────────────────
def render_launchd_plist(exe: str, log_file: str, pid_file: str) -> str:
log = str(Path(log_file).expanduser())
return f"""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.pyra.daemon</string>
<key>ProgramArguments</key>
<array>
<string>{exe}</string>
<string>daemon</string>
<string>run</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<true/>
<key>StandardOutPath</key>
<string>{log}</string>
<key>StandardErrorPath</key>
<string>{log}</string>
<key>ProcessType</key>
<string>Background</string>
</dict>
</plist>
"""
def render_systemd_unit(exe: str, log_file: str) -> str:
log = str(Path(log_file).expanduser())
return f"""[Unit]
Description=Pyra Personal AI Assistant Daemon
After=default.target
[Service]
Type=simple
ExecStart={exe} daemon run
Restart=on-failure
RestartSec=5s
StandardOutput=append:{log}
StandardError=append:{log}
[Install]
WantedBy=default.target
"""
def render_schtasks_xml(exe: str) -> str:
return f"""<?xml version="1.0" encoding="UTF-16"?>
<Task version="1.2" xmlns="http://schemas.microsoft.com/windows/2004/02/mit/task">
<RegistrationInfo>
<Description>Pyra Personal AI Assistant Daemon</Description>
</RegistrationInfo>
<Triggers>
<LogonTrigger>
<Enabled>true</Enabled>
</LogonTrigger>
</Triggers>
<Settings>
<MultipleInstancesPolicy>IgnoreNew</MultipleInstancesPolicy>
<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>
<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>
<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>
<RestartOnFailure>
<Interval>PT1M</Interval>
<Count>999</Count>
</RestartOnFailure>
</Settings>
<Actions Context="Author">
<Exec>
<Command>{exe}</Command>
<Arguments>daemon run</Arguments>
</Exec>
</Actions>
</Task>
"""
# ── Install / uninstall ───────────────────────────────────────────────────────
def install_service() -> None:
"""Generate and register the OS service for the current platform."""
from pyra.config.manager import load_config
cfg = load_config()
exe = find_pyra_executable()
plat = detect_platform()
if plat == "macos":
_install_launchd(exe, cfg.daemon.log_file, cfg.daemon.pid_file)
elif plat == "linux":
_install_systemd(exe, cfg.daemon.log_file)
else:
_install_windows(exe)
def uninstall_service() -> None:
"""Deregister the OS service for the current platform."""
plat = detect_platform()
if plat == "macos":
_uninstall_launchd()
elif plat == "linux":
_uninstall_systemd()
else:
_uninstall_windows()
# ── macOS launchd ─────────────────────────────────────────────────────────────
_PLIST_PATH = Path.home() / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
def _install_launchd(exe: str, log_file: str, pid_file: str) -> None:
_PLIST_PATH.parent.mkdir(parents=True, exist_ok=True)
_PLIST_PATH.write_text(render_launchd_plist(exe, log_file, pid_file))
safe_chmod(_PLIST_PATH, 0o644) # launchd requires 644, not 600
subprocess.run(["launchctl", "load", str(_PLIST_PATH)], check=True)
def _uninstall_launchd() -> None:
if _PLIST_PATH.exists():
subprocess.run(["launchctl", "unload", str(_PLIST_PATH)], check=False)
_PLIST_PATH.unlink()
# ── Linux systemd ─────────────────────────────────────────────────────────────
_SYSTEMD_UNIT = Path.home() / ".config" / "systemd" / "user" / "pyra.service"
def _install_systemd(exe: str, log_file: str) -> None:
_SYSTEMD_UNIT.parent.mkdir(parents=True, exist_ok=True)
_SYSTEMD_UNIT.write_text(render_systemd_unit(exe, log_file))
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
subprocess.run(["systemctl", "--user", "enable", "pyra"], check=True)
def _uninstall_systemd() -> None:
subprocess.run(
["systemctl", "--user", "disable", "--now", "pyra"], check=False
)
if _SYSTEMD_UNIT.exists():
_SYSTEMD_UNIT.unlink()
subprocess.run(["systemctl", "--user", "daemon-reload"], check=False)
# ── Windows Task Scheduler ────────────────────────────────────────────────────
def _install_windows(exe: str) -> None:
from pyra.utils.paths import pyra_home
xml_path = pyra_home() / "daemon_task.xml"
# schtasks /Create /XML requires UTF-16 encoding
xml_path.write_text(render_schtasks_xml(exe), encoding="utf-16")
subprocess.run(
["schtasks", "/Create", "/TN", "PyraAssistant", "/XML", str(xml_path), "/F"],
check=True,
)
def _uninstall_windows() -> None:
subprocess.run(
["schtasks", "/Delete", "/TN", "PyraAssistant", "/F"], check=False
)
+26
View File
@@ -75,6 +75,32 @@ class PluginRegistry:
pass pass
return tasks return tasks
def get_daemon_task_factories(
self,
) -> list[tuple[str, Callable[[], Coroutine]]]: # type: ignore[type-arg]
"""Return (name, factory) pairs for all plugin daemon tasks.
Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine,
enabling the supervisor to restart crashed tasks without changing the plugin
protocol.
"""
factories: list[tuple[str, Callable[[], Coroutine]]] = [] # type: ignore[type-arg]
for plugin in self._plugins.values():
try:
initial = plugin.daemon_tasks()
n_tasks = len(initial)
for c in initial:
c.close() # prevent "coroutine never awaited" RuntimeWarning
except Exception:
continue
for i in range(n_tasks):
name = f"{plugin.name}.task_{i}"
# Capture plugin and index by value so each closure is independent.
def _factory(p=plugin, idx=i) -> Coroutine: # type: ignore[type-arg]
return p.daemon_tasks()[idx]
factories.append((name, _factory))
return factories
def find_tool(self, name: str) -> Tool | None: def find_tool(self, name: str) -> Tool | None:
return self._tools.get(name) return self._tools.get(name)
+362 -53
View File
@@ -1,3 +1,6 @@
import contextlib
import json
import httpx import httpx
import questionary import questionary
from rich.console import Console from rich.console import Console
@@ -7,6 +10,7 @@ from rich.text import Text
from pyra.config.manager import save_config from pyra.config.manager import save_config
from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig from pyra.config.schema import GeneralConfig, ProviderConfig, PyraConfig
from pyra.setup.providers import PROVIDERS, Provider, get_provider from pyra.setup.providers import PROVIDERS, Provider, get_provider
from pyra.utils.paths import pyra_home, safe_chmod
console = Console() console = Console()
@@ -20,6 +24,72 @@ _USE_CASE_PLUGINS: dict[str, list[str]] = {
} }
_DRAFT_FILE = "setup.draft.json"
def _draft_path():
return pyra_home() / _DRAFT_FILE
def _save_draft(state: dict) -> None:
path = _draft_path()
path.write_text(json.dumps(state, indent=2))
safe_chmod(path, 0o600)
def _load_draft() -> dict | None:
path = _draft_path()
if not path.exists():
return None
try:
return json.loads(path.read_text())
except Exception:
return None
def _delete_draft() -> None:
with contextlib.suppress(FileNotFoundError):
_draft_path().unlink()
def _mark_step_done(state: dict, step: str) -> None:
state.setdefault("completed_steps", [])
if step not in state["completed_steps"]:
state["completed_steps"].append(step)
def _offer_resume(draft: dict) -> bool:
"""Show a summary of the incomplete setup and ask Resume / Start fresh."""
completed = draft.get("completed_steps", [])
step_display = {
"profile": f"Profile: {draft.get('user_name', '?')}",
"provider": f"Provider: {draft.get('provider_id', '?')}",
"model": f"Model: {draft.get('model', '?')}",
"api_key": "API key: stored in vault",
"connection": "Connection: test passed",
}
lines = ["[bold]An incomplete setup was found.[/bold]\n"]
for step, label in step_display.items():
if step in completed:
lines.append(f" [green]✓[/green] {label}")
else:
lines.append(f" [dim]○ {label.split(':')[0].strip()}: pending[/dim]")
console.print(Panel("\n".join(lines), title="Incomplete setup", border_style="yellow"))
console.print()
action = questionary.select(
"What would you like to do?",
choices=[
questionary.Choice("Resume from where you left off", value="resume"),
questionary.Choice("Start fresh", value="fresh"),
],
).ask()
if action is None:
raise SystemExit(0)
return action == "resume"
def run_setup() -> None: def run_setup() -> None:
console.print(Panel( console.print(Panel(
Text("Welcome to Pyra Setup", justify="center", style="bold cyan"), Text("Welcome to Pyra Setup", justify="center", style="bold cyan"),
@@ -28,36 +98,92 @@ def run_setup() -> None:
)) ))
console.print() console.print()
user_name, purpose, use_cases = _collect_user_profile() state: dict = {}
draft = _load_draft()
if draft:
if _offer_resume(draft):
state = draft
else:
_delete_draft()
provider = _choose_provider() try:
model = _choose_model(provider) # ── Step 1: profile ────────────────────────────────────────────────
if "profile" in state.get("completed_steps", []):
user_name = state["user_name"]
purpose = state["purpose"]
use_cases = state["use_cases"]
console.print(f" [dim]✓ Profile: {user_name}[/dim]")
else:
user_name, purpose, use_cases = _collect_user_profile()
state.update(user_name=user_name, purpose=purpose, use_cases=use_cases)
_mark_step_done(state, "profile")
_save_draft(state)
if provider.requires_key: # ── Step 2: provider ───────────────────────────────────────────────
_collect_api_key(provider) if "provider" in state.get("completed_steps", []):
provider = get_provider(state["provider_id"])
console.print(f" [dim]✓ Provider: {provider.display_name}[/dim]")
else:
provider = _choose_provider()
state.update(provider_id=provider.id)
_mark_step_done(state, "provider")
_save_draft(state)
_test_connection(provider, model) # ── Step 3: model ──────────────────────────────────────────────────
if "model" in state.get("completed_steps", []):
model = state["model"]
console.print(f" [dim]✓ Model: {model}[/dim]")
else:
model = _choose_model(provider)
state.update(model=model)
_mark_step_done(state, "model")
_save_draft(state)
cfg = PyraConfig( # ── Step 4: API key ────────────────────────────────────────────────
ai=ProviderConfig( if "api_key" not in state.get("completed_steps", []) and provider.requires_key:
provider_id=provider.id, from pyra.vault.reader import get_key as _get_key
model=model, if not _get_key(provider.id):
base_url=provider.base_url, _collect_api_key(provider)
), _mark_step_done(state, "api_key")
general=GeneralConfig(user_name=user_name, purpose=purpose), _save_draft(state)
)
save_config(cfg)
_suggest_plugins(use_cases) # ── Step 5: connection test ────────────────────────────────────────
if "connection" not in state.get("completed_steps", []):
model = _test_connection(provider, model)
state["model"] = model
_mark_step_done(state, "connection")
_save_draft(state)
console.print() # ── Finalise ───────────────────────────────────────────────────────
console.print(Panel( cfg = PyraConfig(
f"[green]Setup complete![/green]\n\n" ai=ProviderConfig(
f"Provider: [bold]{provider.display_name}[/bold]\n" provider_id=provider.id,
f"Model: [bold]{model}[/bold]\n\n" model=model,
"Run [bold cyan]pyra chat[/bold cyan] to start talking.", base_url=provider.base_url,
border_style="green", ),
)) general=GeneralConfig(user_name=user_name, purpose=purpose),
)
save_config(cfg)
_delete_draft()
_suggest_plugins(use_cases)
console.print()
console.print(Panel(
f"[green]Setup complete![/green]\n\n"
f"Provider: [bold]{provider.display_name}[/bold]\n"
f"Model: [bold]{model}[/bold]\n\n"
"Run [bold cyan]pyra chat[/bold cyan] to start talking.",
border_style="green",
))
except SystemExit:
if state.get("completed_steps"):
console.print()
console.print(
" [dim]Setup paused — run [bold]pyra setup[/bold] to resume.[/dim]"
)
raise
def _collect_user_profile() -> tuple[str, str, list[str]]: def _collect_user_profile() -> tuple[str, str, list[str]]:
@@ -135,26 +261,160 @@ def _choose_provider() -> Provider:
if provider.connectivity_check: if provider.connectivity_check:
_check_local_server(provider) _check_local_server(provider)
_show_local_model_status(provider)
return provider return provider
def _classify_error(exc: Exception) -> tuple[str, str]:
"""Return (short_label, resolution_hint) for a provider or network error."""
name = type(exc).__name__
module = type(exc).__module__ or ""
is_llm = "litellm" in module or "openai" in module
if is_llm:
if "AuthenticationError" in name:
return (
"Invalid API key",
"The provider rejected your API key.\n"
"Double-check it on the provider dashboard and re-enter it.",
)
if "NotFoundError" in name:
return (
"Model not found",
"The model name doesn't exist for this provider.\n"
"Check the exact model identifier on the provider's model list.",
)
if "RateLimitError" in name:
return (
"Rate limit reached",
"You've exceeded the provider's request rate.\n"
"Wait a few seconds and retry.",
)
if "ServiceUnavailable" in name:
return (
"Service temporarily unavailable",
"The provider's servers returned a 5xx error.\n"
"This is usually transient — wait a minute and retry.",
)
if "APIConnectionError" in name or "ConnectError" in name:
return (
"Cannot reach provider",
"A network error prevented the connection.\n"
"Check your internet connection and firewall settings.",
)
if "Timeout" in name:
return (
"Request timed out",
"The provider did not respond in time.\n"
"The service may be overloaded — retry in a moment.",
)
if "BadRequestError" in name or "InvalidRequest" in name:
return (
"Bad request",
"The request was rejected — the model name or parameters may be wrong.\n"
"Verify the exact model identifier.",
)
return ("Provider error", str(exc)[:300])
if "httpx" in module:
if "ConnectError" in name or "ConnectTimeout" in name:
return (
"Server not reachable",
"Could not connect to the local server.\n"
"Make sure it is running and listening on the expected address.",
)
if "Timeout" in name:
return (
"Connection timed out",
"The local server did not respond in time.\n"
"It may still be starting up — wait a moment and retry.",
)
if "HTTPStatusError" in name:
code = getattr(getattr(exc, "response", None), "status_code", 0)
if code == 401:
return ("Unauthorized (401)", "The server requires credentials that were not provided.")
if code == 404:
return ("Endpoint not found (404)", "The API endpoint was not found — check the server version.")
return (f"HTTP {code}", f"The server returned an unexpected error (HTTP {code}).")
return ("Unexpected error", str(exc)[:300])
def _check_local_server(provider: Provider) -> None: def _check_local_server(provider: Provider) -> None:
console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ") while True:
try:
resp = httpx.get(provider.connectivity_check, timeout=3.0)
resp.raise_for_status()
console.print("[green]✓[/green]")
except Exception:
console.print("[yellow]✗ (server not reachable)[/yellow]")
console.print( console.print(
f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n" f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" "
f" Make sure {provider.display_name} is running before using Pyra."
) )
try:
resp = httpx.get(provider.connectivity_check, timeout=3.0)
resp.raise_for_status()
console.print("[green]✓[/green]")
return
except Exception as exc:
label, hint = _classify_error(exc)
console.print("[yellow]✗[/yellow]")
console.print()
console.print(Panel(
f"[bold yellow]{label}[/bold yellow]\n\n{hint}",
title="Connection problem",
border_style="yellow",
))
action = questionary.select(
"How would you like to proceed?",
choices=[
questionary.Choice("Retry", value="retry"),
questionary.Choice(
"Continue anyway (model list may be unavailable)", value="continue"
),
questionary.Choice("Abort setup", value="abort"),
],
).ask()
if action is None or action == "abort":
raise SystemExit(0)
if action == "continue":
return
# "retry" → loop
def fetch_loaded_models(provider: Provider) -> list[str]:
"""Return models currently loaded in RAM from a local provider's API."""
if not provider.base_url:
return []
try:
if provider.id == "ollama":
resp = httpx.get(f"{provider.base_url}/api/ps", timeout=3.0)
resp.raise_for_status()
return [m["name"] for m in resp.json().get("models", [])]
elif provider.id == "lmstudio":
resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
resp.raise_for_status()
return [
m["id"] for m in resp.json().get("data", [])
if m.get("state") == "loaded"
]
else: # llamacpp — /models returns only the active loaded model
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
resp.raise_for_status()
return [m["id"] for m in resp.json().get("data", [])]
except Exception:
return []
def _show_local_model_status(provider: Provider) -> None:
"""Print a one-line status showing which models are currently loaded."""
models = fetch_loaded_models(provider)
if not models:
console.print(" [yellow]No model currently loaded[/yellow]")
elif len(models) == 1:
console.print(f" [green]Loaded model:[/green] {models[0]}")
else:
names = ", ".join(models)
console.print(f" [green]{len(models)} models loaded:[/green] {names}")
def _fetch_local_models(provider: Provider) -> list[str]: def _fetch_local_models(provider: Provider) -> list[str]:
"""Return currently loaded/available models from a local provider's API.""" """Return currently loaded models from a local provider's API."""
if not provider.base_url: if not provider.base_url:
return [] return []
try: try:
@@ -162,6 +422,13 @@ def _fetch_local_models(provider: Provider) -> list[str]:
resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0) resp = httpx.get(f"{provider.base_url}/api/tags", timeout=3.0)
resp.raise_for_status() resp.raise_for_status()
return [m["name"] for m in resp.json().get("models", [])] return [m["name"] for m in resp.json().get("models", [])]
elif provider.id == "lmstudio":
resp = httpx.get("http://localhost:1234/api/v0/models", timeout=3.0)
resp.raise_for_status()
return [
m["id"] for m in resp.json().get("data", [])
if m.get("state") == "loaded"
]
else: else:
resp = httpx.get(f"{provider.base_url}/models", timeout=3.0) resp = httpx.get(f"{provider.base_url}/models", timeout=3.0)
resp.raise_for_status() resp.raise_for_status()
@@ -264,26 +531,68 @@ def _collect_api_key(provider: Provider) -> None:
console.print(" [green]✓ Key stored in vault[/green]") console.print(" [green]✓ Key stored in vault[/green]")
def _test_connection(provider: Provider, model: str) -> None: def _test_connection(provider: Provider, model: str) -> str:
from pyra.vault.reader import get_key from pyra.vault.reader import get_key
console.print("\n Running test call...", end=" ") while True:
try: console.print("\n Running connection test...", end=" ")
import litellm try:
import litellm
# Local providers don't need a real key but litellm still requires the field api_key = get_key(provider.id) if provider.requires_key else "local"
api_key = get_key(provider.id) if provider.requires_key else "local" kwargs: dict = {
kwargs: dict = { "model": f"{provider.litellm_prefix}{model}",
"model": f"{provider.litellm_prefix}{model}", "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
"messages": [{"role": "user", "content": "Reply with exactly: OK"}], "max_tokens": 10,
"max_tokens": 10, "api_key": api_key,
"api_key": api_key, }
} if provider.base_url:
if provider.base_url: kwargs["api_base"] = provider.base_url
kwargs["api_base"] = provider.base_url
litellm.completion(**kwargs) litellm.completion(**kwargs)
console.print("[green]✓ Connection OK[/green]") console.print("[green]✓ Connection OK[/green]")
except Exception as exc: return model
console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]")
console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]") except Exception as exc:
label, hint = _classify_error(exc)
console.print("[red]✗[/red]")
console.print()
console.print(Panel(
f"[bold red]{label}[/bold red]\n\n{hint}",
title="Test call failed",
border_style="red",
))
exc_name = type(exc).__name__
is_auth_error = "AuthenticationError" in exc_name
is_model_error = any(
kw in exc_name for kw in ("NotFoundError", "BadRequestError", "InvalidRequest")
)
choices = [questionary.Choice("Retry", value="retry")]
if is_model_error:
choices.append(questionary.Choice("Change model", value="change_model"))
if provider.requires_key and is_auth_error:
choices.append(questionary.Choice("Re-enter API key", value="rekey"))
choices += [
questionary.Choice("Skip test and continue setup", value="skip"),
questionary.Choice("Abort setup", value="abort"),
]
action = questionary.select(
"How would you like to proceed?",
choices=choices,
).ask()
if action is None or action == "abort":
raise SystemExit(0)
if action == "skip":
console.print(
" [dim]Test skipped — run [bold]pyra setup[/bold] again if chat doesn't work.[/dim]"
)
return model
if action == "change_model":
model = _choose_model(provider)
elif action == "rekey":
_collect_api_key(provider)
# loop → retry (with possibly new model or key)
+226
View File
@@ -0,0 +1,226 @@
"""Unit tests for the daemon core — PluginSupervisor and IPC handler dispatch."""
from __future__ import annotations
import asyncio
import pytest
from pyra.daemon.core import PluginSupervisor, _make_ipc_handler
# ── Helpers ───────────────────────────────────────────────────────────────────
async def _drain(n: int = 20) -> None:
"""Yield to the event loop n times to let scheduled tasks run."""
for _ in range(n):
await asyncio.sleep(0)
# ── PluginSupervisor — lifecycle ──────────────────────────────────────────────
async def test_supervisor_empty_starts_and_stops_cleanly() -> None:
sup = PluginSupervisor()
await sup.start()
await sup.stop()
assert sup.status() == []
async def test_supervisor_runs_task_to_completion() -> None:
done = asyncio.Event()
async def task():
done.set()
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("t", task)
await sup.start()
await asyncio.wait_for(done.wait(), timeout=1.0)
await sup.stop()
assert sup._records[0].restart_count == 0
assert sup._records[0].last_error is None
async def test_supervisor_restarts_crashed_task() -> None:
call_count = 0
completed = asyncio.Event()
async def flaky():
nonlocal call_count
call_count += 1
if call_count == 1:
raise RuntimeError("first call fails")
completed.set()
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("flaky", flaky)
await sup.start()
await asyncio.wait_for(completed.wait(), timeout=1.0)
await sup.stop()
assert sup._records[0].restart_count == 1
assert "RuntimeError" in (sup._records[0].last_error or "")
async def test_supervisor_gives_up_after_max_restarts() -> None:
async def always_fails():
raise ValueError("always")
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup._MAX_RESTARTS = 3
sup.add_task("failing", always_fails)
await sup.start()
# Allow enough iterations for 3 restarts + give-up.
for _ in range(200):
await asyncio.sleep(0)
if sup._records[0].task and sup._records[0].task.done():
break
await sup.stop()
assert sup._records[0].restart_count == 3
assert sup._records[0].last_error is not None
# ── PluginSupervisor — status ─────────────────────────────────────────────────
async def test_supervisor_status_returns_correct_shape() -> None:
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
async def noop():
pass
sup.add_task("noop", noop)
await sup.start()
await _drain()
statuses = sup.status()
assert len(statuses) == 1
s = statuses[0]
assert set(s.keys()) == {"name", "alive", "restart_count", "last_error"}
assert s["name"] == "noop"
assert isinstance(s["alive"], bool)
assert isinstance(s["restart_count"], int)
await sup.stop()
async def test_supervisor_status_empty_when_no_tasks() -> None:
sup = PluginSupervisor()
await sup.start()
assert sup.status() == []
await sup.stop()
# ── PluginSupervisor — reload ─────────────────────────────────────────────────
async def test_supervisor_reload_restarts_tasks() -> None:
call_count = 0
async def counting():
nonlocal call_count
call_count += 1
# Hang until cancelled so reload can cancel it.
await asyncio.sleep(10)
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("c", counting)
await sup.start()
await _drain()
assert call_count == 1
await sup.reload()
await _drain()
# After reload, the task should have been restarted (called a second time).
assert call_count == 2
assert sup._records[0].restart_count == 0 # reset by reload
await sup.stop()
async def test_supervisor_reload_resets_restart_count() -> None:
call_count = 0
async def flaky():
nonlocal call_count
call_count += 1
if call_count <= 2:
raise RuntimeError("crash")
await asyncio.sleep(10)
sup = PluginSupervisor()
sup._RESTART_DELAY = 0.0
sup.add_task("f", flaky)
await sup.start()
# Wait for 2 crashes to accumulate.
for _ in range(200):
await asyncio.sleep(0)
if sup._records[0].restart_count >= 2:
break
assert sup._records[0].restart_count == 2
await sup.reload()
# Reload must reset the counter.
assert sup._records[0].restart_count == 0
await sup.stop()
# ── IPC command handler ───────────────────────────────────────────────────────
async def test_ipc_handler_ping() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "ping"})
assert resp["ok"] is True
assert resp["data"]["pong"] is True
async def test_ipc_handler_status_shape() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "status"})
assert resp["ok"] is True
assert "uptime" in resp["data"]
assert "pid" in resp["data"]
assert "tasks" in resp["data"]
assert isinstance(resp["data"]["tasks"], list)
async def test_ipc_handler_stop_signals_shutdown() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
assert not sup._shutdown.is_set()
resp = await handler({"cmd": "stop"})
assert resp["ok"] is True
assert sup._shutdown.is_set()
async def test_ipc_handler_reload_returns_task_count() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "reload"})
assert resp["ok"] is True
assert resp["data"]["tasks_reloaded"] == 0
async def test_ipc_handler_unknown_command() -> None:
sup = PluginSupervisor()
handler = _make_ipc_handler(sup)
resp = await handler({"cmd": "bogus"})
assert resp["ok"] is False
assert "error" in resp["data"]
assert "bogus" in resp["data"]["error"]
+162
View File
@@ -0,0 +1,162 @@
"""Unit tests for the IPC layer."""
from __future__ import annotations
import asyncio
import os
import sys
import tempfile
from pathlib import Path
import pytest
from pyra.daemon.ipc import (
IpcClient,
IpcMessage,
IpcResponse,
IpcServer,
decode_message,
encode_message,
is_unix_socket,
)
@pytest.fixture
def sock_path():
"""Short socket path that fits within macOS's 104-char AF_UNIX limit."""
with tempfile.TemporaryDirectory(dir="/tmp") as d:
yield Path(d) / "t.sock"
# ── Protocol encode / decode ──────────────────────────────────────────────────
def test_encode_appends_newline() -> None:
data = encode_message({"cmd": "ping"})
assert data.endswith(b"\n")
def test_encode_is_valid_json() -> None:
import json
data = encode_message({"cmd": "status", "extra": 42})
assert json.loads(data) == {"cmd": "status", "extra": 42}
def test_decode_roundtrip() -> None:
msg: IpcMessage = {"cmd": "stop"}
assert decode_message(encode_message(msg)) == msg
def test_decode_strips_newline() -> None:
assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop"
def test_decode_raises_on_bad_json() -> None:
with pytest.raises(ValueError, match="Invalid IPC message"):
decode_message(b"not json\n")
def test_decode_raises_on_empty_line() -> None:
with pytest.raises(ValueError):
decode_message(b"\n")
# ── is_unix_socket ────────────────────────────────────────────────────────────
def test_is_unix_socket_matches_platform() -> None:
if sys.platform == "win32":
assert not is_unix_socket()
else:
assert is_unix_socket()
# ── Server + client roundtrip (Unix only) ─────────────────────────────────────
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_server_client_ping(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {"pong": True}}
server = IpcServer(sock_path, handler)
await server.start()
try:
resp = await IpcClient(sock_path).send({"cmd": "ping"})
assert resp["ok"] is True
assert resp["data"]["pong"] is True
finally:
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {}}
server = IpcServer(sock_path, handler)
await server.start()
try:
reader, writer = await asyncio.open_unix_connection(str(sock_path))
writer.write(b"not valid json\n")
await writer.drain()
line = await asyncio.wait_for(reader.readline(), timeout=3.0)
resp = decode_message(line)
assert resp["ok"] is False
assert "error" in resp["data"]
finally:
try:
writer.close()
except Exception:
pass
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_handler_response_returned_to_client(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
if msg.get("cmd") == "status":
return {"ok": True, "data": {"uptime": 99.0}}
return {"ok": False, "data": {"error": "unknown"}}
server = IpcServer(sock_path, handler)
await server.start()
try:
resp = await IpcClient(sock_path).send({"cmd": "status"})
assert resp["ok"] is True
assert resp["data"]["uptime"] == 99.0
resp2 = await IpcClient(sock_path).send({"cmd": "bogus"})
assert resp2["ok"] is False
finally:
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_client_raises_when_no_server(sock_path: Path) -> None:
client = IpcClient(sock_path)
with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)):
await client.send({"cmd": "ping"})
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_socket_file_chmod_600(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {}}
server = IpcServer(sock_path, handler)
await server.start()
try:
mode = oct(sock_path.stat().st_mode & 0o777)
assert mode == oct(0o600), f"Expected 0o600, got {mode}"
finally:
await server.stop()
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
async def test_stop_removes_socket_file(sock_path: Path) -> None:
async def handler(msg: IpcMessage) -> IpcResponse:
return {"ok": True, "data": {}}
server = IpcServer(sock_path, handler)
await server.start()
assert sock_path.exists()
await server.stop()
assert not sock_path.exists()
+103
View File
@@ -0,0 +1,103 @@
"""Unit tests for daemon PID file management."""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
def test_write_creates_file(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write()
assert (tmp_path / "daemon.pid").exists()
assert int((tmp_path / "daemon.pid").read_text().strip()) == os.getpid()
def test_read_returns_none_when_absent(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
assert p.read() is None
def test_read_returns_pid_when_present(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("12345")
p = PidFile(pid_file)
assert p.read() == 12345
def test_read_returns_none_on_bad_content(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("not-a-number")
p = PidFile(pid_file)
assert p.read() is None
def test_is_stale_false_for_self(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write()
assert not p.is_stale()
def test_is_stale_true_for_dead_pid(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("999999999") # unrealistically large PID
p = PidFile(pid_file)
assert p.is_stale()
def test_is_stale_false_when_file_absent(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
assert not p.is_stale()
def test_remove_deletes_file(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write()
p.remove()
assert not (tmp_path / "daemon.pid").exists()
def test_remove_is_idempotent(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.remove() # must not raise
def test_context_manager_writes_and_removes(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
p = PidFile(pid_file)
with p:
assert pid_file.exists()
assert int(pid_file.read_text().strip()) == os.getpid()
assert not pid_file.exists()
def test_write_raises_when_live_pid_exists(tmp_path: Path) -> None:
p = PidFile(tmp_path / "daemon.pid")
p.write() # writes self PID (which is alive)
p2 = PidFile(tmp_path / "daemon.pid")
with pytest.raises(PidFileError):
p2.write()
def test_write_succeeds_over_stale_pid(tmp_path: Path) -> None:
pid_file = tmp_path / "daemon.pid"
pid_file.write_text("999999999") # stale
p = PidFile(pid_file)
p.write() # should not raise
assert int(pid_file.read_text().strip()) == os.getpid()
def test_resolve_pid_path_expands_tilde() -> None:
result = resolve_pid_path("~/.pyra/daemon.pid")
assert not str(result).startswith("~")
assert result.is_absolute()
def test_resolve_pid_path_absolute_unchanged(tmp_path: Path) -> None:
path = tmp_path / "daemon.pid"
result = resolve_pid_path(str(path))
assert result == path
+189
View File
@@ -0,0 +1,189 @@
"""Unit tests for daemon service file generation and platform detection."""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from pyra.daemon.service import (
detect_platform,
find_pyra_executable,
render_launchd_plist,
render_systemd_unit,
render_schtasks_xml,
)
# ── Template rendering ────────────────────────────────────────────────────────
def test_render_launchd_plist_contains_exe() -> None:
xml = render_launchd_plist("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
assert "/usr/local/bin/pyra" in xml
assert "<string>daemon</string>" in xml
assert "<string>run</string>" in xml
assert "com.pyra.daemon" in xml
assert "<true/>" in xml # KeepAlive and RunAtLoad
def test_render_launchd_plist_expands_log_tilde() -> None:
xml = render_launchd_plist("/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
assert "~" not in xml
def test_render_systemd_unit_contains_exe() -> None:
unit = render_systemd_unit("/usr/local/bin/pyra", "~/.pyra/daemon.log")
assert "ExecStart=/usr/local/bin/pyra daemon run" in unit
assert "Restart=on-failure" in unit
assert "Type=simple" in unit
assert "WantedBy=default.target" in unit
def test_render_systemd_unit_expands_log_tilde() -> None:
unit = render_systemd_unit("/bin/pyra", "~/.pyra/daemon.log")
assert "~" not in unit
def test_render_schtasks_xml_contains_exe() -> None:
xml = render_schtasks_xml("C:\\Users\\test\\pyra.exe")
assert "C:\\Users\\test\\pyra.exe" in xml
assert "LogonTrigger" in xml
assert "daemon run" in xml
assert "IgnoreNew" in xml
def test_render_schtasks_xml_no_time_limit() -> None:
xml = render_schtasks_xml("pyra.exe")
assert "PT0S" in xml # ExecutionTimeLimit=PT0S means unlimited
# ── Platform detection ────────────────────────────────────────────────────────
def test_detect_platform_returns_known_value() -> None:
result = detect_platform()
assert result in ("macos", "linux", "windows")
@pytest.mark.parametrize("system,expected", [
("Darwin", "macos"),
("Linux", "linux"),
("Windows", "windows"),
])
def test_detect_platform_maps_correctly(system: str, expected: str) -> None:
with patch("platform.system", return_value=system):
assert detect_platform() == expected
def test_detect_platform_raises_on_unknown() -> None:
with patch("platform.system", return_value="FreeBSD"):
with pytest.raises(RuntimeError, match="Unsupported platform"):
detect_platform()
# ── Executable detection ──────────────────────────────────────────────────────
def test_find_pyra_executable_returns_string() -> None:
result = find_pyra_executable()
assert isinstance(result, str)
assert len(result) > 0
def test_find_pyra_executable_uses_which_when_available(tmp_path: Path) -> None:
fake_pyra = tmp_path / "pyra"
fake_pyra.touch()
with patch("shutil.which", return_value=str(fake_pyra)):
assert find_pyra_executable() == str(fake_pyra)
def test_find_pyra_executable_falls_back_to_sibling(tmp_path: Path) -> None:
fake_python = tmp_path / "python3"
fake_pyra = tmp_path / "pyra"
fake_pyra.touch()
with patch("shutil.which", return_value=None):
with patch("sys.executable", str(fake_python)):
assert find_pyra_executable() == str(fake_pyra)
def test_find_pyra_executable_falls_back_to_module(tmp_path: Path) -> None:
fake_python = tmp_path / "python3"
with patch("shutil.which", return_value=None):
with patch("sys.executable", str(fake_python)):
result = find_pyra_executable()
assert result == f"{fake_python} -m pyra"
# ── Install / uninstall (subprocess mocked) ───────────────────────────────────
@pytest.mark.skipif(sys.platform == "win32", reason="launchd install is macOS-only")
def test_install_launchd_writes_plist_and_calls_launchctl(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
calls: list[list[str]] = []
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
svc._install_launchd("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
assert plist_path.exists()
assert "com.pyra.daemon" in plist_path.read_text()
assert any("launchctl" in c[0] for c in calls)
@pytest.mark.skipif(sys.platform == "win32", reason="systemd install is Linux-only")
def test_install_systemd_writes_unit_and_calls_systemctl(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
calls: list[list[str]] = []
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
svc._install_systemd("/usr/local/bin/pyra", "~/.pyra/daemon.log")
assert unit_path.exists()
assert "ExecStart" in unit_path.read_text()
assert any("systemctl" in c[0] for c in calls)
@pytest.mark.skipif(sys.platform == "win32", reason="launchd uninstall is macOS-only")
def test_uninstall_launchd_removes_plist(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
plist_path.parent.mkdir(parents=True)
plist_path.write_text("<plist/>")
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
svc._uninstall_launchd()
assert not plist_path.exists()
@pytest.mark.skipif(sys.platform == "win32", reason="systemd uninstall is Linux-only")
def test_uninstall_systemd_removes_unit(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
import pyra.daemon.service as svc
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
unit_path.parent.mkdir(parents=True)
unit_path.write_text("[Service]")
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
svc._uninstall_systemd()
assert not unit_path.exists()
+384
View File
@@ -0,0 +1,384 @@
"""Unit tests for the email plugin — pure-logic helpers, no network calls."""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
# Import helpers directly — they depend only on stdlib
from pyra.bundled_plugins.email.plugin import (
EmailMessage,
FilterRule,
_build_imap_search,
_decode_header,
_gmail_action_summary,
_gmail_criteria_summary,
_normalize_to_gmail,
_normalize_to_outlook,
_outlook_actions_summary,
_parse_raw_message,
_strip_html,
)
# ── _strip_html ────────────────────────────────────────────────────────────────
def test_strip_html_removes_tags():
result = _strip_html("<p>Hello <b>world</b></p>")
assert "<" not in result
assert "Hello" in result
assert "world" in result
def test_strip_html_decodes_entities():
result = _strip_html("&lt;script&gt; &amp; &quot;test&quot;")
assert "<script>" in result
assert "&" in result
def test_strip_html_removes_style_and_script():
html = "<style>body{color:red}</style><script>alert(1)</script><p>Keep this</p>"
result = _strip_html(html)
assert "color" not in result
assert "alert" not in result
assert "Keep this" in result
def test_strip_html_plain_text_unchanged():
result = _strip_html("Hello, world!")
assert result == "Hello, world!"
# ── _decode_header ─────────────────────────────────────────────────────────────
def test_decode_header_plain():
assert _decode_header("Hello") == "Hello"
def test_decode_header_encoded():
# RFC 2047 base64-encoded UTF-8
encoded = "=?utf-8?b?SGVsbG8gV29ybGQ=?="
assert _decode_header(encoded) == "Hello World"
def test_decode_header_empty():
assert _decode_header("") == ""
# ── _parse_raw_message ─────────────────────────────────────────────────────────
def _make_raw_email(
from_addr: str = "sender@example.com",
to_addr: str = "recipient@example.com",
subject: str = "Test Subject",
body: str = "Hello from test.",
message_id: str = "<test123@example.com>",
) -> bytes:
return (
f"From: {from_addr}\r\n"
f"To: {to_addr}\r\n"
f"Subject: {subject}\r\n"
f"Date: Mon, 01 Jan 2024 12:00:00 +0000\r\n"
f"Message-ID: {message_id}\r\n"
f"MIME-Version: 1.0\r\n"
f"Content-Type: text/plain; charset=utf-8\r\n"
f"\r\n"
f"{body}\r\n"
).encode()
def test_parse_raw_message_basic_fields():
raw = _make_raw_email()
msg = _parse_raw_message(raw, uid="42", folder="INBOX", is_read=False)
assert msg.uid == "42"
assert msg.folder == "INBOX"
assert msg.from_addr == "sender@example.com"
assert "recipient@example.com" in msg.to_addrs
assert msg.subject == "Test Subject"
assert msg.body_text == "Hello from test."
assert msg.is_read is False
assert msg.has_attachments is False
assert msg.attachments == []
assert msg.message_id == "<test123@example.com>"
def test_parse_raw_message_snippet_truncated():
long_body = "A" * 500
raw = _make_raw_email(body=long_body)
msg = _parse_raw_message(raw, uid="1", folder="INBOX", is_read=True)
assert len(msg.snippet) <= 200
def test_parse_raw_message_body_truncated_at_8000():
huge_body = "x" * 10000
raw = _make_raw_email(body=huge_body)
msg = _parse_raw_message(raw, uid="1", folder="INBOX", is_read=False)
assert len(msg.body_text) <= 8030 # 8000 + "[...truncated]"
assert "truncated" in msg.body_text
def test_parse_raw_message_html_stripped():
raw = _make_raw_email(body="<html><body><p>Plain text content</p></body></html>")
# Create HTML part manually
html_raw = (
"From: a@b.com\r\nTo: c@d.com\r\nSubject: Test\r\n"
"MIME-Version: 1.0\r\nContent-Type: text/html; charset=utf-8\r\n\r\n"
"<html><body><p>Plain text content</p></body></html>\r\n"
).encode()
msg = _parse_raw_message(html_raw, uid="1", folder="INBOX", is_read=False)
assert "<" not in msg.body_text
assert "Plain text content" in msg.body_text
# ── _build_imap_search ─────────────────────────────────────────────────────────
def test_build_imap_search_unread():
from imap_tools import AND
criteria = _build_imap_search("unread invoices")
# Should produce an AND with seen=False
assert criteria is not None
def test_build_imap_search_from():
criteria = _build_imap_search("from:boss@company.com")
assert criteria is not None
def test_build_imap_search_subject():
criteria = _build_imap_search("subject: meeting notes")
assert criteria is not None
def test_build_imap_search_fallback():
criteria = _build_imap_search("random search terms")
assert criteria is not None
# ── Gmail rule normalisation ───────────────────────────────────────────────────
def test_normalize_to_gmail_from_condition():
criteria, action = _normalize_to_gmail({"from": "boss@company.com"}, {"mark_read": True})
assert criteria.get("from") == "boss@company.com"
assert "UNREAD" in action.get("removeLabelIds", [])
def test_normalize_to_gmail_move_to():
criteria, action = _normalize_to_gmail({"subject": "invoice"}, {"move_to": "Bills"})
assert criteria.get("subject") == "invoice"
assert "Bills" in action.get("addLabelIds", [])
assert "INBOX" in action.get("removeLabelIds", [])
def test_normalize_to_gmail_mark_important():
_, action = _normalize_to_gmail({}, {"mark_important": True})
assert "IMPORTANT" in action.get("addLabelIds", [])
def test_normalize_to_gmail_forward():
_, action = _normalize_to_gmail({}, {"forward_to": "archive@example.com"})
assert action.get("forward") == "archive@example.com"
def test_gmail_criteria_summary_empty():
assert _gmail_criteria_summary({}) == "(any)"
def test_gmail_criteria_summary_from():
assert "from=boss" in _gmail_criteria_summary({"from": "boss@company.com"})
def test_gmail_action_summary_empty():
assert _gmail_action_summary({}) == "(no action)"
# ── Outlook rule normalisation ─────────────────────────────────────────────────
def test_normalize_to_outlook_from():
body = _normalize_to_outlook({"from": "a@b.com"}, {"move_to": "Work"})
from_addrs = body["conditions"].get("fromAddresses", [])
assert any("a@b.com" in str(a) for a in from_addrs)
assert body["actions"].get("moveToFolder") == "Work"
def test_normalize_to_outlook_subject_contains():
body = _normalize_to_outlook({"subject": "invoice"}, {"mark_read": True})
assert "invoice" in body["conditions"].get("subjectContains", [])
assert body["actions"].get("markAsRead") is True
def test_normalize_to_outlook_mark_important():
body = _normalize_to_outlook({}, {"mark_important": True})
assert body["actions"].get("markImportance") == "high"
def test_normalize_to_outlook_delete():
body = _normalize_to_outlook({}, {"delete": True})
assert body["actions"].get("delete") is True
# ── email_move folder-not-found path ──────────────────────────────────────────
def test_email_move_returns_error_when_folder_missing(tmp_pyra_home):
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
# Inject a mock provider with known folders
mock_provider = MagicMock()
mock_provider.list_folders.return_value = ["INBOX", "Sent", "Trash"]
plugin._provider_instance = mock_provider
result = plugin._tool_move("uid123", "NonExistent", "INBOX")
assert "does not exist" in result.lower()
assert "email_create_folder" in result
mock_provider.move_message.assert_not_called()
def test_email_move_succeeds_when_folder_exists(tmp_pyra_home):
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
mock_provider = MagicMock()
mock_provider.list_folders.return_value = ["INBOX", "Work", "Newsletters"]
plugin._provider_instance = mock_provider
result = plugin._tool_move("uid456", "Work", "INBOX")
assert "moved" in result.lower()
mock_provider.move_message.assert_called_once_with("uid456", "INBOX", "Work")
# ── email_list_rules not-supported path ───────────────────────────────────────
def test_email_list_rules_not_supported(tmp_pyra_home):
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
mock_provider = MagicMock()
mock_provider.list_rules.side_effect = NotImplementedError
plugin._provider_instance = mock_provider
result = plugin._tool_list_rules()
assert "not supported" in result.lower()
# ── daemon/events integration ─────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_events_publish_and_subscribe():
from pyra.daemon import events
events.reset()
await events.publish({"type": "new_email", "subject": "Test"})
received = []
async for event in events.subscribe_forever():
received.append(event)
break # only need one
assert received[0]["type"] == "new_email"
assert received[0]["subject"] == "Test"
events.reset()
@pytest.mark.asyncio
async def test_events_queue_full_drops_silently():
from pyra.daemon import events
events.reset()
# Fill the queue
for i in range(200):
await events.publish({"n": i})
# This should not raise even though queue is full
await events.publish({"n": 999})
events.reset()
# ── ProtonMail Bridge connectivity check (mocked) ─────────────────────────────
def test_protonmail_setup_aborts_when_bridge_unreachable(tmp_pyra_home):
"""_setup_protonmail should abort gracefully when Bridge is not running."""
import socket
from unittest.mock import patch, MagicMock
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
console = MagicMock()
vault_writer = MagicMock()
with patch("socket.create_connection", side_effect=ConnectionRefusedError):
plugin._setup_protonmail(console, vault_writer, "user@proton.me")
# Should not store any vault key if Bridge is unreachable
vault_writer.assert_not_called()
# ── messaging bot recommendation ──────────────────────────────────────────────
def test_check_messaging_bot_warns_when_no_bot(tmp_pyra_home):
from pyra.bundled_plugins.email.plugin import EmailPlugin
from unittest.mock import MagicMock, patch
from pyra.config.schema import PyraConfig, ProviderConfig, PluginConfig
plugin = EmailPlugin()
console = MagicMock()
cfg = PyraConfig(ai=ProviderConfig(provider_id="lmstudio", model="test"))
cfg.plugins = PluginConfig(enabled=[]) # no bots
with patch("pyra.bundled_plugins.email.plugin.EmailPlugin._load_settings", return_value={}), \
patch("pyra.config.manager.load_config", return_value=cfg):
plugin._check_messaging_bot(console)
# Should have printed something (Panel) recommending a bot
console.print.assert_called()
# ── Tool list completeness ─────────────────────────────────────────────────────
def test_plugin_exposes_16_tools():
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
# on_load with no-op vault reader
plugin.on_load(lambda _: None)
tools = plugin.tools()
tool_names = [t.name for t in tools]
assert len(tools) == 16
expected = {
"email_list_folder", "email_read", "email_send", "email_reply",
"email_forward", "email_move", "email_delete", "email_mark_read",
"email_search", "email_list_folders", "email_create_folder",
"email_inbox_summary", "email_list_rules", "email_create_rule",
"email_delete_rule", "email_bulk_action",
}
assert set(tool_names) == expected
def test_write_tools_require_approval():
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
plugin.on_load(lambda _: None)
tools = {t.name: t for t in plugin.tools()}
for name in ["email_send", "email_reply", "email_forward", "email_move",
"email_delete", "email_create_folder", "email_create_rule",
"email_delete_rule", "email_bulk_action"]:
assert tools[name].requires_approval, f"{name} should require approval"
def test_read_tools_no_approval():
from pyra.bundled_plugins.email.plugin import EmailPlugin
plugin = EmailPlugin()
plugin.on_load(lambda _: None)
tools = {t.name: t for t in plugin.tools()}
for name in ["email_list_folder", "email_read", "email_mark_read",
"email_search", "email_list_folders", "email_inbox_summary",
"email_list_rules"]:
assert not tools[name].requires_approval, f"{name} should NOT require approval"
+349 -3
View File
@@ -96,14 +96,34 @@ def test_suggest_plugins_multiple_categories(monkeypatch):
# ── _fetch_local_models ──────────────────────────────────────────────────────── # ── _fetch_local_models ────────────────────────────────────────────────────────
def test_fetch_local_models_lmstudio_returns_model_ids(monkeypatch): def test_fetch_local_models_lmstudio_returns_loaded_model_ids(monkeypatch):
import pyra.setup.wizard as wiz import pyra.setup.wizard as wiz
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.json.return_value = {"data": [{"id": "gemma-4b"}, {"id": "llama3"}]} mock_resp.json.return_value = {
"data": [
{"id": "gemma-4b", "state": "loaded"},
{"id": "llama3", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp) monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
from pyra.setup.providers import get_provider from pyra.setup.providers import get_provider
assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b", "llama3"] assert wiz._fetch_local_models(get_provider("lmstudio")) == ["gemma-4b"]
def test_fetch_local_models_lmstudio_filters_unloaded(monkeypatch):
import pyra.setup.wizard as wiz
mock_resp = MagicMock()
mock_resp.json.return_value = {
"data": [
{"id": "model-a", "state": "not_loaded"},
{"id": "model-b", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
from pyra.setup.providers import get_provider
assert wiz._fetch_local_models(get_provider("lmstudio")) == []
def test_fetch_local_models_ollama_returns_model_names(monkeypatch): def test_fetch_local_models_ollama_returns_model_names(monkeypatch):
@@ -181,3 +201,329 @@ def test_load_lmstudio_model_returns_false_on_exception(monkeypatch):
import pyra.setup.wizard as wiz import pyra.setup.wizard as wiz
monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout"))) monkeypatch.setattr(wiz.httpx, "post", MagicMock(side_effect=Exception("timeout")))
assert wiz._load_lmstudio_model("gemma-4b") is False assert wiz._load_lmstudio_model("gemma-4b") is False
# ── _classify_error ───────────────────────────────────────────────────────────
def _fake_llm_exc(name: str) -> Exception:
"""Create a fake litellm exception with the given class name."""
cls = type(name, (Exception,), {"__module__": "litellm.exceptions"})
return cls("test error")
def test_classify_auth_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("AuthenticationError"))
assert "key" in label.lower()
assert "provider" in hint.lower() or "dashboard" in hint.lower()
def test_classify_not_found_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("NotFoundError"))
assert "model" in label.lower()
assert "model" in hint.lower()
def test_classify_rate_limit_error():
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(_fake_llm_exc("RateLimitError"))
assert "rate" in label.lower()
def test_classify_service_unavailable():
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(_fake_llm_exc("ServiceUnavailableError"))
assert "unavailable" in label.lower()
def test_classify_api_connection_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("APIConnectionError"))
assert "reach" in label.lower() or "network" in label.lower() or "connect" in label.lower()
assert "internet" in hint.lower() or "network" in hint.lower()
def test_classify_timeout_error():
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(_fake_llm_exc("TimeoutError"))
assert "time" in label.lower()
def test_classify_bad_request_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(_fake_llm_exc("BadRequestError"))
assert "request" in label.lower() or "bad" in label.lower()
assert "model" in hint.lower()
def test_classify_httpx_connect_error():
import httpx
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(httpx.ConnectError("refused"))
assert "reach" in label.lower() or "reachable" in label.lower()
assert "running" in hint.lower() or "listening" in hint.lower()
def test_classify_httpx_timeout():
import httpx
from pyra.setup.wizard import _classify_error
label, _ = _classify_error(httpx.TimeoutException("timeout"))
assert "time" in label.lower()
def test_classify_generic_error():
from pyra.setup.wizard import _classify_error
label, hint = _classify_error(ValueError("something went wrong"))
assert len(label) > 0
assert "something went wrong" in hint
def test_classify_error_always_returns_two_strings():
from pyra.setup.wizard import _classify_error
for exc in (RuntimeError("boom"), OSError("disk"), KeyError("k")):
label, hint = _classify_error(exc)
assert isinstance(label, str) and len(label) > 0
assert isinstance(hint, str) and len(hint) > 0
# ── _check_local_server retry behaviour ──────────────────────────────────────
def _silent_print(*a, **kw):
pass
def test_check_local_server_success(monkeypatch):
import pyra.setup.wizard as wiz
mock_resp = MagicMock()
mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda *a, **kw: mock_resp)
monkeypatch.setattr(wiz.console, "print", _silent_print)
from pyra.setup.providers import get_provider
wiz._check_local_server(get_provider("lmstudio")) # must not raise
def test_check_local_server_retry_then_success(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
call_count = {"n": 0}
def flaky_get(*a, **kw):
call_count["n"] += 1
if call_count["n"] == 1:
raise ConnectionError("refused")
m = MagicMock()
m.raise_for_status = lambda: None
return m
monkeypatch.setattr(wiz.httpx, "get", flaky_get)
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "retry"))
wiz._check_local_server(get_provider("lmstudio"))
assert call_count["n"] == 2
def test_check_local_server_abort_raises_system_exit(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "abort"))
with pytest.raises(SystemExit):
wiz._check_local_server(get_provider("lmstudio"))
def test_check_local_server_continue_returns(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=ConnectionError("refused")))
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "continue"))
wiz._check_local_server(get_provider("lmstudio")) # must return without raising
# ── draft persistence ─────────────────────────────────────────────────────────
def test_save_and_load_draft(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {"completed_steps": ["profile"], "user_name": "Alice"}
wiz._save_draft(state)
loaded = wiz._load_draft()
assert loaded == state
def test_load_draft_returns_none_when_no_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
assert wiz._load_draft() is None
def test_delete_draft_removes_file(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
wiz._delete_draft()
assert wiz._load_draft() is None
def test_delete_draft_is_idempotent(tmp_pyra_home):
import pyra.setup.wizard as wiz
wiz._delete_draft()
wiz._delete_draft()
def test_mark_step_done_appends_once(tmp_pyra_home):
import pyra.setup.wizard as wiz
state = {}
wiz._mark_step_done(state, "profile")
wiz._mark_step_done(state, "profile")
assert state["completed_steps"].count("profile") == 1
def test_draft_file_has_correct_permissions(tmp_pyra_home):
import os
import pyra.setup.wizard as wiz
wiz._save_draft({"completed_steps": []})
path = wiz._draft_path()
if os.name != "nt":
mode = oct(path.stat().st_mode)[-3:]
assert mode == "600"
# ── fetch_loaded_models ───────────────────────────────────────────────────────
def test_fetch_loaded_models_ollama_uses_api_ps(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
mock_resp = MagicMock()
mock_resp.json.return_value = {"models": [{"name": "llama3:latest"}, {"name": "mistral"}]}
mock_resp.raise_for_status = lambda: None
calls = []
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
result = wiz.fetch_loaded_models(get_provider("ollama"))
assert result == ["llama3:latest", "mistral"]
assert any("/api/ps" in u for u in calls)
def test_fetch_loaded_models_lmstudio_uses_beta_api_and_filters(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
mock_resp = MagicMock()
mock_resp.json.return_value = {
"data": [
{"id": "gemma-4b", "state": "loaded"},
{"id": "llama3", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None
calls = []
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: (calls.append(url), mock_resp)[1])
result = wiz.fetch_loaded_models(get_provider("lmstudio"))
assert result == ["gemma-4b"]
assert any("/api/v0/models" in u for u in calls)
def test_fetch_loaded_models_lmstudio_filters_unloaded(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
mock_resp = MagicMock()
mock_resp.json.return_value = {
"data": [
{"id": "model-a", "state": "not_loaded"},
{"id": "model-b", "state": "not_loaded"},
]
}
mock_resp.raise_for_status = lambda: None
monkeypatch.setattr(wiz.httpx, "get", lambda url, **kw: mock_resp)
assert wiz.fetch_loaded_models(get_provider("lmstudio")) == []
def test_fetch_loaded_models_returns_empty_on_error(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz.httpx, "get", MagicMock(side_effect=Exception("conn refused")))
assert wiz.fetch_loaded_models(get_provider("ollama")) == []
def test_fetch_loaded_models_returns_empty_when_no_base_url():
import pyra.setup.wizard as wiz
from pyra.setup.providers import Provider
provider = Provider(
id="test", display_name="Test", requires_key=False,
default_model="x", litellm_prefix="openai/", group="Local",
)
assert wiz.fetch_loaded_models(provider) == []
# ── _show_local_model_status ──────────────────────────────────────────────────
def test_show_local_model_status_one_model(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["gemma-4b"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("gemma-4b" in s for s in printed)
def test_show_local_model_status_none(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: [])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
assert any("No model" in s for s in printed)
def test_show_local_model_status_multiple(monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
monkeypatch.setattr(wiz, "fetch_loaded_models", lambda p: ["a", "b", "c"])
printed = []
monkeypatch.setattr(wiz.console, "print", lambda *a, **kw: printed.append(str(a)))
wiz._show_local_model_status(get_provider("lmstudio"))
combined = " ".join(printed)
assert "3" in combined or ("a" in combined and "b" in combined)
# ── _test_connection model re-entry ───────────────────────────────────────────
def test_test_connection_change_model(tmp_pyra_home, monkeypatch):
import pyra.setup.wizard as wiz
from pyra.setup.providers import get_provider
call_count = {"n": 0}
class FakeNotFound(Exception):
pass
FakeNotFound.__module__ = "litellm.exceptions"
FakeNotFound.__name__ = "NotFoundError"
def fake_completion(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
raise FakeNotFound("model not found")
import litellm
monkeypatch.setattr(litellm, "completion", fake_completion)
monkeypatch.setattr(wiz.console, "print", _silent_print)
monkeypatch.setattr(wiz.questionary, "select",
lambda *a, **kw: MagicMock(ask=lambda: "change_model"))
monkeypatch.setattr(wiz.questionary, "text",
lambda *a, **kw: MagicMock(ask=lambda: "new-model"))
monkeypatch.setattr(wiz, "_fetch_local_models", lambda p: [])
monkeypatch.setattr("pyra.vault.reader.get_key", lambda k: "sk-test")
provider = get_provider("anthropic")
result = wiz._test_connection(provider, "old-model")
assert result == "new-model"
assert call_count["n"] == 2