diff --git a/CLAUDE.md b/CLAUDE.md index 85cd00b..2f42633 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,7 +8,8 @@ a plugin/integration system (Stage 2+) and an encrypted vault (Stage 3+). ## Current Status **Stage 3 — Memory Database: complete** (2026-05-18) -Next: Stage 4 — Vault Encryption +**Stage 6 — Daemon infrastructure: in progress** (`feat/daemon` branch) +Next: Stage 4 — Vault Encryption (skipped for now); messaging bots (Stage 6 remainder) ## Project Roadmap @@ -19,11 +20,11 @@ memory in `~/.pyra/memory/`, and hard security boundaries around the vault. ### Stage 2 — Plugin Framework ✅ COMPLETE - `src/pyra/plugins/` package: `base.py`, `loader.py`, `registry.py`, `executor.py`, `install.py` - `src/pyra/bundled_plugins/` — ships bundled plugin scripts with pyra -- `src/pyra/daemon/` stub (CLI surface only) +- `src/pyra/daemon/` stub (CLI surface only; daemon itself is Stage 6) - Config: `PluginConfig` + `DaemonConfig` added to `PyraConfig` - Bootstrap: `~/.pyra/plugins/` and `~/.pyra/logs/` created on startup - Chat session: AI tool-use loop (up to 10 iterations), approval gate, plugin slash commands -- CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` stubs +- CLI: `pyra plugin list/install/enable/disable/setup`, `pyra daemon *` (stubs at Stage 2; implemented in Stage 6) ### Stage 3 — Memory Database ✅ COMPLETE - `src/pyra/memory/database.py`: SQLite + FTS5 via `memory_meta` + `memory_fts` tables @@ -117,7 +118,11 @@ the vault under namespaced keys (`plugin:{name}:{key}`). | `plugins/executor.py` | Approval gate: scan args → prompt → execute → scan result → log | | `plugins/install.py` | Copies bundled plugins to `~/.pyra/plugins/` | | `bundled_plugins/` | Standalone plugin scripts shipped with pyra (installed on demand) | -| `daemon/__init__.py` | Daemon package stub (implementation in Stage 2.4) | +| `daemon/pid.py` | Atomic PID file — write, read, stale detection (POSIX + Windows), context manager | +| `daemon/ipc.py` | IPC transport — Unix socket chmod 600 + UID-check (Linux/macOS) or TCP loopback + port file (Windows); newline-delimited JSON protocol | +| `daemon/service.py` | OS service file generation + install/uninstall — launchd plist (macOS), systemd user unit (Linux), schtasks XML (Windows) | +| `daemon/core.py` | asyncio event loop entry point, `PluginSupervisor` (per-task restart, max 10×, 5s back-off, reload), IPC command dispatch, signal handling | +| `daemon/__init__.py` | Public daemon API exports | ### Runtime: `~/.pyra/` @@ -493,6 +498,40 @@ Dataclass: `MemoryFile(name, path, category, size_bytes, modified)` | `list_bundled_plugins` | `(bundled_dir: Path) -> list[str]` | Names of all bundled plugins that have a `manifest.json` | | `read_manifest` | `(plugin_dir: Path) -> dict` | Reads `manifest.json`; returns `{}` if missing | +#### `daemon.core` + +| Function | Signature | Purpose | +|----------|-----------|---------| +| `run_foreground` | `() -> None` | Entry point for `pyra daemon run` — loads config + plugins, writes PID file, runs asyncio loop | +| `start_background` | `() -> None` | Spawns `pyra daemon run` as a detached subprocess (`start_new_session` on POSIX, `DETACHED_PROCESS` on Windows) | + +#### `daemon.pid` + +| Function | Signature | Purpose | +|----------|-----------|---------| +| `resolve_pid_path` | `(cfg_path: str) -> Path` | Expand `~` and resolve to absolute Path | + +#### `daemon.ipc` + +| Function | Signature | Purpose | +|----------|-----------|---------| +| `send_command` | `(address, msg, timeout=5.0) -> IpcResponse` | Synchronous CLI helper — `asyncio.run(IpcClient.send(...))` | +| `get_socket_path` | `(cfg: str) -> Path` | Expand `~` and return Unix socket path | +| `is_unix_socket` | `() -> bool` | True on Linux/macOS (`sys.platform != 'nt'`) | +| `get_port_file_path` | `() -> Path` | Path to `~/.pyra/daemon.port` (Windows TCP port file) | + +#### `daemon.service` + +| Function | Signature | Purpose | +|----------|-----------|---------| +| `detect_platform` | `() -> Literal["macos","linux","windows"]` | Detect current OS | +| `find_pyra_executable` | `() -> str` | `shutil.which("pyra")` → sibling fallback → `sys.executable -m pyra` | +| `install_service` | `() -> None` | Generate + register OS service (reads config for log/pid paths) | +| `uninstall_service` | `() -> None` | Deregister OS service | +| `render_launchd_plist` | `(exe, log_file, pid_file) -> str` | macOS plist template | +| `render_systemd_unit` | `(exe, log_file) -> str` | Linux systemd unit template | +| `render_schtasks_xml` | `(exe) -> str` | Windows Task Scheduler XML template (write as UTF-16) | + #### `chat.renderer` — rendering functions and shared `console` Import `console` from here; do not create a second `rich.Console()` in new code. @@ -515,7 +554,7 @@ Import `console` from here; do not create a second `rich.Console()` in new code. | `GeneralConfig` | `config.schema` | `general:` block — `user_name`, `assistant_name` | | `ProviderConfig` | `config.schema` | `ai:` block — `provider_id`, `model`, `base_url` | | `PluginConfig` | `config.schema` | `plugins:` block — `enabled`, `require_approval`, `log_executions` | -| `DaemonConfig` | `config.schema` | `daemon:` block | +| `DaemonConfig` | `config.schema` | `daemon:` block — `enabled`, `socket_path`, `log_file`, `pid_file`, `ipc_port` | | `MemoryConfig` | `config.schema` | `memory:` block — `max_tokens_in_context`, `auto_load` | | `SecurityConfig` | `config.schema` | `security:` block — `injection_detection`, `log_injections` | | `ConversationHistory` | `chat.history` | Holds message list; builds API payload via `build_for_api()`; trims to token budget | @@ -526,3 +565,5 @@ Import `console` from here; do not create a second `rich.Console()` in new code. | `PyraPlugin` | `plugins.base` | `@runtime_checkable` Protocol — the plugin interface | | `BasePlugin` | `plugins.base` | Concrete base with no-op defaults; plugins should inherit this | | `TaskPlanner` | `chat.planner` | Multi-step plan runner; `make_tool_handler()` returns the callable wired into the chat session; presents plan for user approval, executes each step via litellm with up to 5 tool-use iterations, verifies output before proceeding | +| `PluginSupervisor` | `daemon.core` | asyncio supervisor — `add_task(name, factory)`, `start()`, `stop()`, `reload()`, `status()`; restarts crashed tasks up to 10× with 5s back-off | +| `PidFile` | `daemon.pid` | `write()` (atomic), `read()`, `is_stale()`, `remove()`, context manager; `PidFileError(OSError)` raised when live PID already exists | diff --git a/src/pyra/cli.py b/src/pyra/cli.py index d4f53da..a6456c8 100644 --- a/src/pyra/cli.py +++ b/src/pyra/cli.py @@ -266,43 +266,144 @@ def daemon() -> None: _bootstrap_or_exit() +@daemon.command("run", hidden=True) +def daemon_run() -> None: + """Run daemon in foreground (used by service manager).""" + from pyra.daemon.core import run_foreground + run_foreground() + + @daemon.command("start") def daemon_start() -> None: """Start the Pyra daemon in the background.""" - console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]") + from pyra.daemon.core import start_background + try: + start_background() + except FileNotFoundError: + console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.") @daemon.command("stop") def daemon_stop() -> None: """Stop the running Pyra daemon.""" - console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]") + _daemon_ipc("stop", success_msg="Daemon stopped.") @daemon.command("status") def daemon_status() -> None: """Show daemon status.""" - console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]") + _daemon_ipc("status") @daemon.command("restart") def daemon_restart() -> None: """Restart the Pyra daemon.""" - console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]") + import time + from pyra.daemon.core import start_background + _daemon_ipc("stop", success_msg=None) + time.sleep(1.5) + try: + start_background() + except FileNotFoundError: + console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.") @daemon.command("install") def daemon_install() -> None: - """Install Pyra as a system service (launchd/systemd).""" - console.print("[yellow]Daemon service install (Stage 6) is not yet implemented.[/yellow]") + """Install Pyra as a system service (launchd/systemd/schtasks).""" + from pyra.daemon.service import detect_platform, install_service + try: + install_service() + console.print(f"[green]Service installed[/green] ({detect_platform()}).") + except Exception as exc: + console.print(f"[red]Install failed:[/red] {exc}") @daemon.command("uninstall") def daemon_uninstall() -> None: """Remove the Pyra system service.""" - console.print("[yellow]Daemon service uninstall (Stage 6) is not yet implemented.[/yellow]") + from pyra.daemon.service import uninstall_service + try: + uninstall_service() + console.print("[green]Service removed.[/green]") + except Exception as exc: + console.print(f"[red]Uninstall failed:[/red] {exc}") -@daemon.command("run", hidden=True) -def daemon_run() -> None: - """Run daemon in foreground (used by service manager).""" - console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]") +def _daemon_ipc(cmd: str, *, success_msg: str | None = None) -> None: + """Send a command to the running daemon via IPC and render the response.""" + from pyra.config.manager import load_config + from pyra.daemon.ipc import ( + get_socket_path, + is_unix_socket, + get_port_file_path, + send_command, + ) + + try: + cfg = load_config() + except FileNotFoundError: + console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.") + return + + if is_unix_socket(): + address = get_socket_path(cfg.daemon.socket_path) + else: + port = _read_windows_port() + if port is None: + console.print("[yellow]Daemon is not running.[/yellow]") + return + address = ("127.0.0.1", port) + + try: + resp = send_command(address, {"cmd": cmd}) + except (ConnectionRefusedError, FileNotFoundError, OSError): + console.print("[yellow]Daemon is not running.[/yellow]") + return + except ConnectionResetError: + console.print("[red]Permission denied:[/red] daemon rejected connection.") + return + except TimeoutError: + console.print("[red]Daemon did not respond in time.[/red]") + return + + if not resp.get("ok"): + console.print(f"[red]Error:[/red] {resp.get('data', {}).get('error', 'unknown')}") + return + + if cmd == "status": + _render_daemon_status(resp["data"]) + elif success_msg: + console.print(f"[green]{success_msg}[/green]") + + +def _read_windows_port() -> int | None: + from pyra.daemon.ipc import get_port_file_path + try: + return int(get_port_file_path().read_text().strip()) + except (FileNotFoundError, ValueError): + return None + + +def _render_daemon_status(data: dict) -> None: + from rich.table import Table + + uptime = data.get("uptime", 0.0) + pid = data.get("pid", "?") + tasks = data.get("tasks", []) + + hours, rem = divmod(int(uptime), 3600) + mins, secs = divmod(rem, 60) + uptime_str = f"{hours}h {mins}m {secs}s" if hours else f"{mins}m {secs}s" + + console.print(f"[bold green]Daemon running[/bold green] — PID {pid}, uptime {uptime_str}") + + if tasks: + table = Table("Task", "Alive", "Restarts", "Last error", show_header=True) + for t in tasks: + alive = "[green]yes[/green]" if t.get("alive") else "[red]no[/red]" + error = t.get("last_error") or "—" + table.add_row(t.get("name", "?"), alive, str(t.get("restart_count", 0)), error) + console.print(table) + else: + console.print("[dim]No plugin tasks registered.[/dim]") diff --git a/src/pyra/config/schema.py b/src/pyra/config/schema.py index a945442..03ff51f 100644 --- a/src/pyra/config/schema.py +++ b/src/pyra/config/schema.py @@ -36,6 +36,7 @@ class DaemonConfig(BaseModel): socket_path: str = "~/.pyra/daemon.sock" log_file: str = "~/.pyra/daemon.log" pid_file: str = "~/.pyra/daemon.pid" + ipc_port: int = 0 # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port class PyraConfig(BaseModel): diff --git a/src/pyra/daemon/__init__.py b/src/pyra/daemon/__init__.py index e69de29..2f10e68 100644 --- a/src/pyra/daemon/__init__.py +++ b/src/pyra/daemon/__init__.py @@ -0,0 +1,21 @@ +"""Pyra background daemon package.""" + +from pyra.daemon.core import PluginSupervisor, run_foreground, start_background +from pyra.daemon.ipc import IpcClient, IpcServer, send_command +from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path +from pyra.daemon.service import detect_platform, install_service, uninstall_service + +__all__ = [ + "run_foreground", + "start_background", + "PluginSupervisor", + "IpcClient", + "IpcServer", + "send_command", + "PidFile", + "PidFileError", + "resolve_pid_path", + "detect_platform", + "install_service", + "uninstall_service", +] diff --git a/src/pyra/daemon/core.py b/src/pyra/daemon/core.py new file mode 100644 index 0000000..d516d14 --- /dev/null +++ b/src/pyra/daemon/core.py @@ -0,0 +1,313 @@ +"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling.""" + +from __future__ import annotations + +import asyncio +import logging +import logging.handlers +import os +import signal +import subprocess +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Coroutine + +from pyra.utils.paths import pyra_home, safe_chmod + + +_log = logging.getLogger("pyra.daemon") +_start_time: float = 0.0 + + +# ── Plugin task supervisor ──────────────────────────────────────────────────── + +@dataclass +class TaskRecord: + name: str + coro_factory: Callable[[], Coroutine] # type: ignore[type-arg] + task: asyncio.Task | None = field(default=None, repr=False) + restart_count: int = 0 + last_error: str | None = None + + def is_alive(self) -> bool: + return self.task is not None and not self.task.done() + + +class PluginSupervisor: + _RESTART_DELAY: float = 5.0 + _MAX_RESTARTS: int = 10 + + def __init__(self) -> None: + self._records: list[TaskRecord] = [] + self._shutdown = asyncio.Event() + + def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None: # type: ignore[type-arg] + self._records.append(TaskRecord(name=name, coro_factory=factory)) + + async def start(self) -> None: + for record in self._records: + record.task = asyncio.create_task( + self._supervise(record), name=record.name + ) + _log.info("Supervisor started with %d plugin task(s).", len(self._records)) + + async def run_until_shutdown(self) -> None: + await self._shutdown.wait() + _log.info("Shutdown requested — stopping supervisor.") + + def request_shutdown(self) -> None: + self._shutdown.set() + + async def stop(self) -> None: + for record in self._records: + if record.task and not record.task.done(): + record.task.cancel() + tasks = [r.task for r in self._records if r.task] + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + async def reload(self) -> None: + """Cancel all running tasks and restart them with fresh coroutines.""" + for record in self._records: + if record.task and not record.task.done(): + record.task.cancel() + try: + await record.task + except (asyncio.CancelledError, Exception): + pass + record.restart_count = 0 + record.last_error = None + record.task = asyncio.create_task( + self._supervise(record), name=record.name + ) + _log.info("Reloaded %d plugin task(s).", len(self._records)) + + def status(self) -> list[dict]: + return [ + { + "name": r.name, + "alive": r.is_alive(), + "restart_count": r.restart_count, + "last_error": r.last_error, + } + for r in self._records + ] + + async def _supervise(self, record: TaskRecord) -> None: + while not self._shutdown.is_set(): + try: + await record.coro_factory() + _log.info("Plugin task %s completed normally.", record.name) + return + except asyncio.CancelledError: + return + except Exception as exc: + record.restart_count += 1 + record.last_error = f"{type(exc).__name__}: {exc}" + _log.error( + "Plugin task %s crashed (restart #%d): %s", + record.name, record.restart_count, exc, + exc_info=True, + ) + if record.restart_count >= self._MAX_RESTARTS: + _log.critical( + "Plugin task %s exceeded max restarts (%d). Giving up.", + record.name, self._MAX_RESTARTS, + ) + return + try: + await asyncio.wait_for( + asyncio.sleep(self._RESTART_DELAY), + timeout=self._RESTART_DELAY + 1, + ) + except asyncio.CancelledError: + return + + +# ── IPC command dispatch ────────────────────────────────────────────────────── + +def _make_ipc_handler(supervisor: PluginSupervisor): + async def handler(msg: dict) -> dict: + cmd = msg.get("cmd", "") + match cmd: + case "ping": + return {"ok": True, "data": {"pong": True}} + case "status": + return { + "ok": True, + "data": { + "uptime": time.monotonic() - _start_time, + "pid": os.getpid(), + "tasks": supervisor.status(), + }, + } + case "stop": + supervisor.request_shutdown() + return {"ok": True, "data": {}} + case "reload": + await supervisor.reload() + return {"ok": True, "data": {"tasks_reloaded": len(supervisor._records)}} + case _: + return {"ok": False, "data": {"error": f"unknown command: {cmd}"}} + + return handler + + +# ── Main async entrypoint ───────────────────────────────────────────────────── + +async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None: + from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket + + # Install signal handlers now that the event loop is running. + _install_signal_handlers(supervisor) + + if is_unix_socket(): + address = get_socket_path(cfg.daemon.socket_path) + else: + address = ("127.0.0.1", cfg.daemon.ipc_port) + + server = IpcServer(address, _make_ipc_handler(supervisor)) + + await supervisor.start() + + async with asyncio.TaskGroup() as tg: + tg.create_task(server.start(), name="ipc_server") + tg.create_task(supervisor.run_until_shutdown(), name="shutdown_waiter") + + await server.stop() + await supervisor.stop() + + +# ── Foreground entry point (pyra daemon run) ────────────────────────────────── + +def run_foreground() -> None: + """Run the daemon in the foreground. Called by `pyra daemon run`.""" + from pyra.config.manager import load_config + from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path + from pyra.plugins.registry import PluginRegistry + + global _start_time + + cfg = load_config() + _setup_logging(cfg.daemon.log_file) + pid_path = resolve_pid_path(cfg.daemon.pid_file) + pid_file = PidFile(pid_path) + + existing = pid_file.read() + if existing is not None and not pid_file.is_stale(): + _log.error("Daemon already running (PID %d). Exiting.", existing) + sys.exit(1) + + registry = PluginRegistry() + from pyra.utils.paths import pyra_home as _pyra_home + plugins_dir = _pyra_home() / "plugins" + if plugins_dir.exists(): + registry.load_all(plugins_dir, cfg.plugins.enabled) + + supervisor = PluginSupervisor() + for name, factory in registry.get_daemon_task_factories(): + supervisor.add_task(name, factory) + + _start_time = time.monotonic() + + try: + with pid_file: + _log.info("Pyra daemon starting (PID %d).", os.getpid()) + try: + asyncio.run(_run_daemon(cfg, supervisor)) + except KeyboardInterrupt: + pass + _log.info("Pyra daemon stopped.") + except PidFileError as exc: + _log.error("Could not acquire PID file: %s", exc) + sys.exit(1) + + +# ── Background spawn (pyra daemon start) ───────────────────────────────────── + +def start_background() -> None: + """Spawn `pyra daemon run` as a detached background process.""" + from pyra.config.manager import load_config + from pyra.daemon.pid import PidFile, resolve_pid_path + from pyra.daemon.service import find_pyra_executable + + cfg = load_config() + pid_path = resolve_pid_path(cfg.daemon.pid_file) + pid_file = PidFile(pid_path) + + existing = pid_file.read() + if existing is not None and not pid_file.is_stale(): + from pyra.chat.renderer import console + console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]") + return + + exe = find_pyra_executable() + log_path = Path(cfg.daemon.log_file).expanduser() + log_path.parent.mkdir(parents=True, exist_ok=True) + + with open(log_path, "a") as log_fh: + if sys.platform == "win32": + DETACHED_PROCESS = 0x00000008 + CREATE_NEW_PROCESS_GROUP = 0x00000200 + subprocess.Popen( + [exe, "daemon", "run"], + creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP, + stdout=log_fh, + stderr=log_fh, + close_fds=True, + ) + else: + subprocess.Popen( + [exe, "daemon", "run"], + start_new_session=True, + stdout=log_fh, + stderr=log_fh, + stdin=subprocess.DEVNULL, + close_fds=True, + ) + + from pyra.chat.renderer import console + + # Poll the PID file for up to 3 seconds to confirm startup. + for _ in range(30): + time.sleep(0.1) + pid = pid_file.read() + if pid is not None: + console.print(f"[green]Daemon started (PID {pid}).[/green]") + return + + console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]") + + +# ── Signal handling ─────────────────────────────────────────────────────────── + +def _install_signal_handlers(supervisor: PluginSupervisor) -> None: + if sys.platform == "win32": + signal.signal(signal.SIGTERM, lambda *_: supervisor.request_shutdown()) + return + + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, supervisor.request_shutdown) + loop.add_signal_handler(signal.SIGHUP, supervisor.request_shutdown) + + +# ── Logging setup ───────────────────────────────────────────────────────────── + +def _setup_logging(log_file_str: str) -> None: + log_path = Path(log_file_str).expanduser() + log_path.parent.mkdir(parents=True, exist_ok=True) + + handler = logging.handlers.RotatingFileHandler( + log_path, maxBytes=5 * 1024 * 1024, backupCount=3 + ) + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s") + ) + + root = logging.getLogger("pyra") + root.addHandler(handler) + root.setLevel(logging.INFO) + + safe_chmod(log_path, 0o600) diff --git a/src/pyra/daemon/ipc.py b/src/pyra/daemon/ipc.py new file mode 100644 index 0000000..8c14bd2 --- /dev/null +++ b/src/pyra/daemon/ipc.py @@ -0,0 +1,241 @@ +"""IPC transport for the Pyra daemon. + +Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked). +Windows: TCP loopback on an OS-assigned port; actual port written to + ~/.pyra/daemon.port so clients can connect without knowing it ahead + of time. +""" + +from __future__ import annotations + +import asyncio +import json +import os +import struct +import sys +from pathlib import Path +from typing import Any, Awaitable, Callable + + +# ── Protocol types ──────────────────────────────────────────────────────────── + +IpcMessage = dict[str, Any] # must have "cmd" key +IpcResponse = dict[str, Any] # must have "ok" and "data" keys + + +# ── Encode / decode ─────────────────────────────────────────────────────────── + +def encode_message(msg: IpcMessage) -> bytes: + return (json.dumps(msg) + "\n").encode() + + +def decode_message(line: bytes) -> IpcMessage: + try: + return json.loads(line.rstrip(b"\n")) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid IPC message: {exc}") from exc + + +# ── Address helpers ─────────────────────────────────────────────────────────── + +def is_unix_socket() -> bool: + return sys.platform != "win32" + + +def get_socket_path(cfg_socket_path: str) -> Path: + """Expand ~ and return the Unix socket path.""" + return Path(cfg_socket_path).expanduser() + + +def get_port_file_path() -> Path: + from pyra.utils.paths import pyra_home + return pyra_home() / "daemon.port" + + +def _read_windows_port() -> int | None: + port_file = get_port_file_path() + try: + return int(port_file.read_text().strip()) + except (FileNotFoundError, ValueError): + return None + + +# ── Server ──────────────────────────────────────────────────────────────────── + +class IpcServer: + def __init__( + self, + address: Path | tuple[str, int], + handler: Callable[[IpcMessage], Awaitable[IpcResponse]], + ) -> None: + self._address = address + self._handler = handler + self._server: asyncio.Server | None = None + + async def start(self) -> None: + if is_unix_socket(): + assert isinstance(self._address, Path) + sock_path = self._address + if sock_path.exists(): + sock_path.unlink() + self._server = await asyncio.start_unix_server( + self._handle_client, path=str(sock_path) + ) + os.chmod(sock_path, 0o600) + else: + host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0) + self._server = await asyncio.start_server( + self._handle_client, host=host, port=port + ) + actual_port = self._server.sockets[0].getsockname()[1] + port_file = get_port_file_path() + port_file.write_text(str(actual_port)) + + await self._server.start_serving() + + async def stop(self) -> None: + if self._server is not None: + self._server.close() + try: + await asyncio.wait_for(self._server.wait_closed(), timeout=5.0) + except asyncio.TimeoutError: + pass + if is_unix_socket() and isinstance(self._address, Path): + try: + self._address.unlink() + except FileNotFoundError: + pass + else: + port_file = get_port_file_path() + try: + port_file.unlink() + except FileNotFoundError: + pass + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + if is_unix_socket() and not self._check_peer_uid(writer): + writer.close() + return + + line = await asyncio.wait_for(reader.readline(), timeout=5.0) + if not line: + return + + try: + msg = decode_message(line) + except ValueError: + resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}} + else: + resp = await self._handler(msg) + + writer.write(encode_message(resp)) + await writer.drain() + except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError): + pass + finally: + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool: + """Return True if the peer's UID matches ours. Falls back to True on error.""" + try: + peer_uid = _get_peer_uid(writer) + if peer_uid is None: + return True # can't determine — allow (socket perms are the guard) + return peer_uid == os.getuid() + except Exception: + return True # don't crash the server on unexpected errors + + +# ── Client ──────────────────────────────────────────────────────────────────── + +class IpcClient: + def __init__(self, address: Path | tuple[str, int]) -> None: + self._address = address + + async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse: + if is_unix_socket(): + assert isinstance(self._address, Path) + reader, writer = await asyncio.wait_for( + asyncio.open_unix_connection(str(self._address)), timeout=timeout + ) + else: + host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0) + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=timeout + ) + + try: + writer.write(encode_message(msg)) + await writer.drain() + line = await asyncio.wait_for(reader.readline(), timeout=timeout) + return decode_message(line) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + + +def send_command( + address: Path | tuple[str, int], + msg: IpcMessage, + timeout: float = 5.0, +) -> IpcResponse: + """Synchronous wrapper around IpcClient.send() for CLI callers.""" + return asyncio.run(IpcClient(address).send(msg, timeout=timeout)) + + +# ── Peer UID detection ──────────────────────────────────────────────────────── + +def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None: + """Return the connecting peer's UID, or None if unavailable.""" + try: + sock = writer.get_extra_info("socket") + if sock is None: + return None + + if sys.platform == "linux": + # SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; } + SO_PEERCRED = 17 + cred = sock.getsockopt( + socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i") + ) + _pid, uid, _gid = struct.unpack("3i", cred) + return uid + + if sys.platform == "darwin": + return _macos_peer_uid(sock.fileno()) + + except Exception: + pass + return None + + +def socket_module(): # lazy import to avoid top-level import on non-Unix + import socket + return socket + + +def _macos_peer_uid(fd: int) -> int | None: + """Use getpeereid(2) via ctypes to retrieve the peer UID on macOS.""" + import ctypes + import ctypes.util + + libc_name = ctypes.util.find_library("c") + if not libc_name: + return None + libc = ctypes.CDLL(libc_name) + + euid = ctypes.c_uint32(0) + egid = ctypes.c_uint32(0) + if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0: + return None + return euid.value diff --git a/src/pyra/daemon/pid.py b/src/pyra/daemon/pid.py new file mode 100644 index 0000000..524375f --- /dev/null +++ b/src/pyra/daemon/pid.py @@ -0,0 +1,94 @@ +"""PID file management for the Pyra daemon.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + + +class PidFileError(OSError): + """Raised when a PID file operation fails due to a live conflicting process.""" + + +class PidFile: + def __init__(self, path: Path) -> None: + self._path = path + + def write(self) -> None: + """Write the current PID atomically. + + Raises PidFileError if a non-stale PID file already exists. + """ + existing = self.read() + if existing is not None and not self.is_stale(): + raise PidFileError( + f"Daemon already running with PID {existing} " + f"(PID file: {self._path})" + ) + tmp = self._path.with_suffix(".pid.tmp") + tmp.write_text(str(os.getpid())) + tmp.replace(self._path) + + def read(self) -> int | None: + """Return the PID from the file, or None if the file is absent or unreadable.""" + try: + return int(self._path.read_text().strip()) + except (FileNotFoundError, ValueError): + return None + + def is_stale(self) -> bool: + """True when the PID file exists but the process no longer runs.""" + pid = self.read() + if pid is None: + return False + return not _process_is_alive(pid) + + def remove(self) -> None: + """Delete the PID file, ignoring FileNotFoundError.""" + try: + self._path.unlink() + except FileNotFoundError: + pass + + def __enter__(self) -> "PidFile": + self.write() + return self + + def __exit__(self, *_: object) -> None: + self.remove() + + +def resolve_pid_path(cfg_path: str) -> Path: + """Expand ~ and return an absolute Path.""" + return Path(cfg_path).expanduser().resolve() + + +# ── Platform-specific process liveness check ───────────────────────────────── + +def _process_is_alive(pid: int) -> bool: + if sys.platform == "win32": + return _win_process_is_alive(pid) + return _posix_process_is_alive(pid) + + +def _posix_process_is_alive(pid: int) -> bool: + try: + os.kill(pid, 0) + return True + except ProcessLookupError: + return False + except PermissionError: + # Process exists but is owned by another user — still alive. + return True + + +def _win_process_is_alive(pid: int) -> bool: + import ctypes + + SYNCHRONIZE = 0x00100000 + handle = ctypes.windll.kernel32.OpenProcess(SYNCHRONIZE, False, pid) # type: ignore[attr-defined] + if handle == 0: + return False + ctypes.windll.kernel32.CloseHandle(handle) # type: ignore[attr-defined] + return True diff --git a/src/pyra/daemon/service.py b/src/pyra/daemon/service.py new file mode 100644 index 0000000..6bb0f58 --- /dev/null +++ b/src/pyra/daemon/service.py @@ -0,0 +1,212 @@ +"""OS-specific service file generation and install/uninstall for the Pyra daemon.""" + +from __future__ import annotations + +import platform +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Literal + +from pyra.utils.paths import safe_chmod + + +def detect_platform() -> Literal["macos", "linux", "windows"]: + s = platform.system() + if s == "Darwin": + return "macos" + if s == "Linux": + return "linux" + if s == "Windows": + return "windows" + raise RuntimeError(f"Unsupported platform: {s}") + + +def find_pyra_executable() -> str: + """Return the full path to the active pyra executable. + + Tries, in order: + 1. shutil.which("pyra") — works when pyra is on PATH (activated venv) + 2. sys.executable's sibling "pyra" script — covers editable installs + 3. Fallback: sys.executable -m pyra + """ + found = shutil.which("pyra") + if found: + return found + + sibling = Path(sys.executable).parent / "pyra" + if sibling.exists(): + return str(sibling) + + return f"{sys.executable} -m pyra" + + +# ── Template generators ─────────────────────────────────────────────────────── + +def render_launchd_plist(exe: str, log_file: str, pid_file: str) -> str: + log = str(Path(log_file).expanduser()) + return f""" + + + + Label + com.pyra.daemon + ProgramArguments + + {exe} + daemon + run + + RunAtLoad + + KeepAlive + + StandardOutPath + {log} + StandardErrorPath + {log} + ProcessType + Background + + +""" + + +def render_systemd_unit(exe: str, log_file: str) -> str: + log = str(Path(log_file).expanduser()) + return f"""[Unit] +Description=Pyra Personal AI Assistant Daemon +After=default.target + +[Service] +Type=simple +ExecStart={exe} daemon run +Restart=on-failure +RestartSec=5s +StandardOutput=append:{log} +StandardError=append:{log} + +[Install] +WantedBy=default.target +""" + + +def render_schtasks_xml(exe: str) -> str: + return f""" + + + Pyra Personal AI Assistant Daemon + + + + true + + + + IgnoreNew + false + false + PT0S + + PT1M + 999 + + + + + {exe} + daemon run + + + +""" + + +# ── Install / uninstall ─────────────────────────────────────────────────────── + +def install_service() -> None: + """Generate and register the OS service for the current platform.""" + from pyra.config.manager import load_config + + cfg = load_config() + exe = find_pyra_executable() + plat = detect_platform() + + if plat == "macos": + _install_launchd(exe, cfg.daemon.log_file, cfg.daemon.pid_file) + elif plat == "linux": + _install_systemd(exe, cfg.daemon.log_file) + else: + _install_windows(exe) + + +def uninstall_service() -> None: + """Deregister the OS service for the current platform.""" + plat = detect_platform() + if plat == "macos": + _uninstall_launchd() + elif plat == "linux": + _uninstall_systemd() + else: + _uninstall_windows() + + +# ── macOS launchd ───────────────────────────────────────────────────────────── + +_PLIST_PATH = Path.home() / "Library" / "LaunchAgents" / "com.pyra.daemon.plist" + + +def _install_launchd(exe: str, log_file: str, pid_file: str) -> None: + _PLIST_PATH.parent.mkdir(parents=True, exist_ok=True) + _PLIST_PATH.write_text(render_launchd_plist(exe, log_file, pid_file)) + safe_chmod(_PLIST_PATH, 0o644) # launchd requires 644, not 600 + subprocess.run(["launchctl", "load", str(_PLIST_PATH)], check=True) + + +def _uninstall_launchd() -> None: + if _PLIST_PATH.exists(): + subprocess.run(["launchctl", "unload", str(_PLIST_PATH)], check=False) + _PLIST_PATH.unlink() + + +# ── Linux systemd ───────────────────────────────────────────────────────────── + +_SYSTEMD_UNIT = Path.home() / ".config" / "systemd" / "user" / "pyra.service" + + +def _install_systemd(exe: str, log_file: str) -> None: + _SYSTEMD_UNIT.parent.mkdir(parents=True, exist_ok=True) + _SYSTEMD_UNIT.write_text(render_systemd_unit(exe, log_file)) + subprocess.run(["systemctl", "--user", "daemon-reload"], check=True) + subprocess.run(["systemctl", "--user", "enable", "pyra"], check=True) + + +def _uninstall_systemd() -> None: + subprocess.run( + ["systemctl", "--user", "disable", "--now", "pyra"], check=False + ) + if _SYSTEMD_UNIT.exists(): + _SYSTEMD_UNIT.unlink() + subprocess.run(["systemctl", "--user", "daemon-reload"], check=False) + + +# ── Windows Task Scheduler ──────────────────────────────────────────────────── + +def _install_windows(exe: str) -> None: + from pyra.utils.paths import pyra_home + + xml_path = pyra_home() / "daemon_task.xml" + # schtasks /Create /XML requires UTF-16 encoding + xml_path.write_text(render_schtasks_xml(exe), encoding="utf-16") + subprocess.run( + ["schtasks", "/Create", "/TN", "PyraAssistant", "/XML", str(xml_path), "/F"], + check=True, + ) + + +def _uninstall_windows() -> None: + subprocess.run( + ["schtasks", "/Delete", "/TN", "PyraAssistant", "/F"], check=False + ) diff --git a/src/pyra/plugins/registry.py b/src/pyra/plugins/registry.py index ad83d91..b345a59 100644 --- a/src/pyra/plugins/registry.py +++ b/src/pyra/plugins/registry.py @@ -75,6 +75,32 @@ class PluginRegistry: pass return tasks + def get_daemon_task_factories( + self, + ) -> list[tuple[str, Callable[[], Coroutine]]]: # type: ignore[type-arg] + """Return (name, factory) pairs for all plugin daemon tasks. + + Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine, + enabling the supervisor to restart crashed tasks without changing the plugin + protocol. + """ + factories: list[tuple[str, Callable[[], Coroutine]]] = [] # type: ignore[type-arg] + for plugin in self._plugins.values(): + try: + initial = plugin.daemon_tasks() + n_tasks = len(initial) + for c in initial: + c.close() # prevent "coroutine never awaited" RuntimeWarning + except Exception: + continue + for i in range(n_tasks): + name = f"{plugin.name}.task_{i}" + # Capture plugin and index by value so each closure is independent. + def _factory(p=plugin, idx=i) -> Coroutine: # type: ignore[type-arg] + return p.daemon_tasks()[idx] + factories.append((name, _factory)) + return factories + def find_tool(self, name: str) -> Tool | None: return self._tools.get(name) diff --git a/tests/unit/test_daemon_core.py b/tests/unit/test_daemon_core.py new file mode 100644 index 0000000..2525c64 --- /dev/null +++ b/tests/unit/test_daemon_core.py @@ -0,0 +1,226 @@ +"""Unit tests for the daemon core — PluginSupervisor and IPC handler dispatch.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from pyra.daemon.core import PluginSupervisor, _make_ipc_handler + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +async def _drain(n: int = 20) -> None: + """Yield to the event loop n times to let scheduled tasks run.""" + for _ in range(n): + await asyncio.sleep(0) + + +# ── PluginSupervisor — lifecycle ────────────────────────────────────────────── + +async def test_supervisor_empty_starts_and_stops_cleanly() -> None: + sup = PluginSupervisor() + await sup.start() + await sup.stop() + assert sup.status() == [] + + +async def test_supervisor_runs_task_to_completion() -> None: + done = asyncio.Event() + + async def task(): + done.set() + + sup = PluginSupervisor() + sup._RESTART_DELAY = 0.0 + sup.add_task("t", task) + await sup.start() + + await asyncio.wait_for(done.wait(), timeout=1.0) + await sup.stop() + + assert sup._records[0].restart_count == 0 + assert sup._records[0].last_error is None + + +async def test_supervisor_restarts_crashed_task() -> None: + call_count = 0 + completed = asyncio.Event() + + async def flaky(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("first call fails") + completed.set() + + sup = PluginSupervisor() + sup._RESTART_DELAY = 0.0 + sup.add_task("flaky", flaky) + await sup.start() + + await asyncio.wait_for(completed.wait(), timeout=1.0) + await sup.stop() + + assert sup._records[0].restart_count == 1 + assert "RuntimeError" in (sup._records[0].last_error or "") + + +async def test_supervisor_gives_up_after_max_restarts() -> None: + async def always_fails(): + raise ValueError("always") + + sup = PluginSupervisor() + sup._RESTART_DELAY = 0.0 + sup._MAX_RESTARTS = 3 + sup.add_task("failing", always_fails) + await sup.start() + + # Allow enough iterations for 3 restarts + give-up. + for _ in range(200): + await asyncio.sleep(0) + if sup._records[0].task and sup._records[0].task.done(): + break + + await sup.stop() + + assert sup._records[0].restart_count == 3 + assert sup._records[0].last_error is not None + + +# ── PluginSupervisor — status ───────────────────────────────────────────────── + +async def test_supervisor_status_returns_correct_shape() -> None: + sup = PluginSupervisor() + sup._RESTART_DELAY = 0.0 + + async def noop(): + pass + + sup.add_task("noop", noop) + await sup.start() + await _drain() + + statuses = sup.status() + assert len(statuses) == 1 + s = statuses[0] + assert set(s.keys()) == {"name", "alive", "restart_count", "last_error"} + assert s["name"] == "noop" + assert isinstance(s["alive"], bool) + assert isinstance(s["restart_count"], int) + + await sup.stop() + + +async def test_supervisor_status_empty_when_no_tasks() -> None: + sup = PluginSupervisor() + await sup.start() + assert sup.status() == [] + await sup.stop() + + +# ── PluginSupervisor — reload ───────────────────────────────────────────────── + +async def test_supervisor_reload_restarts_tasks() -> None: + call_count = 0 + + async def counting(): + nonlocal call_count + call_count += 1 + # Hang until cancelled so reload can cancel it. + await asyncio.sleep(10) + + sup = PluginSupervisor() + sup._RESTART_DELAY = 0.0 + sup.add_task("c", counting) + await sup.start() + + await _drain() + assert call_count == 1 + + await sup.reload() + await _drain() + + # After reload, the task should have been restarted (called a second time). + assert call_count == 2 + assert sup._records[0].restart_count == 0 # reset by reload + + await sup.stop() + + +async def test_supervisor_reload_resets_restart_count() -> None: + call_count = 0 + + async def flaky(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise RuntimeError("crash") + await asyncio.sleep(10) + + sup = PluginSupervisor() + sup._RESTART_DELAY = 0.0 + sup.add_task("f", flaky) + await sup.start() + + # Wait for 2 crashes to accumulate. + for _ in range(200): + await asyncio.sleep(0) + if sup._records[0].restart_count >= 2: + break + + assert sup._records[0].restart_count == 2 + + await sup.reload() + # Reload must reset the counter. + assert sup._records[0].restart_count == 0 + + await sup.stop() + + +# ── IPC command handler ─────────────────────────────────────────────────────── + +async def test_ipc_handler_ping() -> None: + sup = PluginSupervisor() + handler = _make_ipc_handler(sup) + resp = await handler({"cmd": "ping"}) + assert resp["ok"] is True + assert resp["data"]["pong"] is True + + +async def test_ipc_handler_status_shape() -> None: + sup = PluginSupervisor() + handler = _make_ipc_handler(sup) + resp = await handler({"cmd": "status"}) + assert resp["ok"] is True + assert "uptime" in resp["data"] + assert "pid" in resp["data"] + assert "tasks" in resp["data"] + assert isinstance(resp["data"]["tasks"], list) + + +async def test_ipc_handler_stop_signals_shutdown() -> None: + sup = PluginSupervisor() + handler = _make_ipc_handler(sup) + assert not sup._shutdown.is_set() + resp = await handler({"cmd": "stop"}) + assert resp["ok"] is True + assert sup._shutdown.is_set() + + +async def test_ipc_handler_reload_returns_task_count() -> None: + sup = PluginSupervisor() + handler = _make_ipc_handler(sup) + resp = await handler({"cmd": "reload"}) + assert resp["ok"] is True + assert resp["data"]["tasks_reloaded"] == 0 + + +async def test_ipc_handler_unknown_command() -> None: + sup = PluginSupervisor() + handler = _make_ipc_handler(sup) + resp = await handler({"cmd": "bogus"}) + assert resp["ok"] is False + assert "error" in resp["data"] + assert "bogus" in resp["data"]["error"] diff --git a/tests/unit/test_daemon_ipc.py b/tests/unit/test_daemon_ipc.py new file mode 100644 index 0000000..4ed466f --- /dev/null +++ b/tests/unit/test_daemon_ipc.py @@ -0,0 +1,162 @@ +"""Unit tests for the IPC layer.""" + +from __future__ import annotations + +import asyncio +import os +import sys +import tempfile +from pathlib import Path + +import pytest + +from pyra.daemon.ipc import ( + IpcClient, + IpcMessage, + IpcResponse, + IpcServer, + decode_message, + encode_message, + is_unix_socket, +) + + +@pytest.fixture +def sock_path(): + """Short socket path that fits within macOS's 104-char AF_UNIX limit.""" + with tempfile.TemporaryDirectory(dir="/tmp") as d: + yield Path(d) / "t.sock" + + +# ── Protocol encode / decode ────────────────────────────────────────────────── + +def test_encode_appends_newline() -> None: + data = encode_message({"cmd": "ping"}) + assert data.endswith(b"\n") + + +def test_encode_is_valid_json() -> None: + import json + data = encode_message({"cmd": "status", "extra": 42}) + assert json.loads(data) == {"cmd": "status", "extra": 42} + + +def test_decode_roundtrip() -> None: + msg: IpcMessage = {"cmd": "stop"} + assert decode_message(encode_message(msg)) == msg + + +def test_decode_strips_newline() -> None: + assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop" + + +def test_decode_raises_on_bad_json() -> None: + with pytest.raises(ValueError, match="Invalid IPC message"): + decode_message(b"not json\n") + + +def test_decode_raises_on_empty_line() -> None: + with pytest.raises(ValueError): + decode_message(b"\n") + + +# ── is_unix_socket ──────────────────────────────────────────────────────────── + +def test_is_unix_socket_matches_platform() -> None: + if sys.platform == "win32": + assert not is_unix_socket() + else: + assert is_unix_socket() + + +# ── Server + client roundtrip (Unix only) ───────────────────────────────────── + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_server_client_ping(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {"pong": True}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + resp = await IpcClient(sock_path).send({"cmd": "ping"}) + assert resp["ok"] is True + assert resp["data"]["pong"] is True + finally: + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + reader, writer = await asyncio.open_unix_connection(str(sock_path)) + writer.write(b"not valid json\n") + await writer.drain() + line = await asyncio.wait_for(reader.readline(), timeout=3.0) + resp = decode_message(line) + assert resp["ok"] is False + assert "error" in resp["data"] + finally: + try: + writer.close() + except Exception: + pass + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_handler_response_returned_to_client(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + if msg.get("cmd") == "status": + return {"ok": True, "data": {"uptime": 99.0}} + return {"ok": False, "data": {"error": "unknown"}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + resp = await IpcClient(sock_path).send({"cmd": "status"}) + assert resp["ok"] is True + assert resp["data"]["uptime"] == 99.0 + + resp2 = await IpcClient(sock_path).send({"cmd": "bogus"}) + assert resp2["ok"] is False + finally: + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_client_raises_when_no_server(sock_path: Path) -> None: + client = IpcClient(sock_path) + with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)): + await client.send({"cmd": "ping"}) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_socket_file_chmod_600(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + mode = oct(sock_path.stat().st_mode & 0o777) + assert mode == oct(0o600), f"Expected 0o600, got {mode}" + finally: + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_stop_removes_socket_file(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {}} + + server = IpcServer(sock_path, handler) + await server.start() + assert sock_path.exists() + await server.stop() + assert not sock_path.exists() diff --git a/tests/unit/test_daemon_pid.py b/tests/unit/test_daemon_pid.py new file mode 100644 index 0000000..7bdb82a --- /dev/null +++ b/tests/unit/test_daemon_pid.py @@ -0,0 +1,103 @@ +"""Unit tests for daemon PID file management.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path + + +def test_write_creates_file(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + p.write() + assert (tmp_path / "daemon.pid").exists() + assert int((tmp_path / "daemon.pid").read_text().strip()) == os.getpid() + + +def test_read_returns_none_when_absent(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + assert p.read() is None + + +def test_read_returns_pid_when_present(tmp_path: Path) -> None: + pid_file = tmp_path / "daemon.pid" + pid_file.write_text("12345") + p = PidFile(pid_file) + assert p.read() == 12345 + + +def test_read_returns_none_on_bad_content(tmp_path: Path) -> None: + pid_file = tmp_path / "daemon.pid" + pid_file.write_text("not-a-number") + p = PidFile(pid_file) + assert p.read() is None + + +def test_is_stale_false_for_self(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + p.write() + assert not p.is_stale() + + +def test_is_stale_true_for_dead_pid(tmp_path: Path) -> None: + pid_file = tmp_path / "daemon.pid" + pid_file.write_text("999999999") # unrealistically large PID + p = PidFile(pid_file) + assert p.is_stale() + + +def test_is_stale_false_when_file_absent(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + assert not p.is_stale() + + +def test_remove_deletes_file(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + p.write() + p.remove() + assert not (tmp_path / "daemon.pid").exists() + + +def test_remove_is_idempotent(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + p.remove() # must not raise + + +def test_context_manager_writes_and_removes(tmp_path: Path) -> None: + pid_file = tmp_path / "daemon.pid" + p = PidFile(pid_file) + with p: + assert pid_file.exists() + assert int(pid_file.read_text().strip()) == os.getpid() + assert not pid_file.exists() + + +def test_write_raises_when_live_pid_exists(tmp_path: Path) -> None: + p = PidFile(tmp_path / "daemon.pid") + p.write() # writes self PID (which is alive) + p2 = PidFile(tmp_path / "daemon.pid") + with pytest.raises(PidFileError): + p2.write() + + +def test_write_succeeds_over_stale_pid(tmp_path: Path) -> None: + pid_file = tmp_path / "daemon.pid" + pid_file.write_text("999999999") # stale + p = PidFile(pid_file) + p.write() # should not raise + assert int(pid_file.read_text().strip()) == os.getpid() + + +def test_resolve_pid_path_expands_tilde() -> None: + result = resolve_pid_path("~/.pyra/daemon.pid") + assert not str(result).startswith("~") + assert result.is_absolute() + + +def test_resolve_pid_path_absolute_unchanged(tmp_path: Path) -> None: + path = tmp_path / "daemon.pid" + result = resolve_pid_path(str(path)) + assert result == path diff --git a/tests/unit/test_daemon_service.py b/tests/unit/test_daemon_service.py new file mode 100644 index 0000000..0dc951a --- /dev/null +++ b/tests/unit/test_daemon_service.py @@ -0,0 +1,189 @@ +"""Unit tests for daemon service file generation and platform detection.""" + +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from pyra.daemon.service import ( + detect_platform, + find_pyra_executable, + render_launchd_plist, + render_systemd_unit, + render_schtasks_xml, +) + + +# ── Template rendering ──────────────────────────────────────────────────────── + +def test_render_launchd_plist_contains_exe() -> None: + xml = render_launchd_plist("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid") + assert "/usr/local/bin/pyra" in xml + assert "daemon" in xml + assert "run" in xml + assert "com.pyra.daemon" in xml + assert "" in xml # KeepAlive and RunAtLoad + + +def test_render_launchd_plist_expands_log_tilde() -> None: + xml = render_launchd_plist("/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid") + assert "~" not in xml + + +def test_render_systemd_unit_contains_exe() -> None: + unit = render_systemd_unit("/usr/local/bin/pyra", "~/.pyra/daemon.log") + assert "ExecStart=/usr/local/bin/pyra daemon run" in unit + assert "Restart=on-failure" in unit + assert "Type=simple" in unit + assert "WantedBy=default.target" in unit + + +def test_render_systemd_unit_expands_log_tilde() -> None: + unit = render_systemd_unit("/bin/pyra", "~/.pyra/daemon.log") + assert "~" not in unit + + +def test_render_schtasks_xml_contains_exe() -> None: + xml = render_schtasks_xml("C:\\Users\\test\\pyra.exe") + assert "C:\\Users\\test\\pyra.exe" in xml + assert "LogonTrigger" in xml + assert "daemon run" in xml + assert "IgnoreNew" in xml + + +def test_render_schtasks_xml_no_time_limit() -> None: + xml = render_schtasks_xml("pyra.exe") + assert "PT0S" in xml # ExecutionTimeLimit=PT0S means unlimited + + +# ── Platform detection ──────────────────────────────────────────────────────── + +def test_detect_platform_returns_known_value() -> None: + result = detect_platform() + assert result in ("macos", "linux", "windows") + + +@pytest.mark.parametrize("system,expected", [ + ("Darwin", "macos"), + ("Linux", "linux"), + ("Windows", "windows"), +]) +def test_detect_platform_maps_correctly(system: str, expected: str) -> None: + with patch("platform.system", return_value=system): + assert detect_platform() == expected + + +def test_detect_platform_raises_on_unknown() -> None: + with patch("platform.system", return_value="FreeBSD"): + with pytest.raises(RuntimeError, match="Unsupported platform"): + detect_platform() + + +# ── Executable detection ────────────────────────────────────────────────────── + +def test_find_pyra_executable_returns_string() -> None: + result = find_pyra_executable() + assert isinstance(result, str) + assert len(result) > 0 + + +def test_find_pyra_executable_uses_which_when_available(tmp_path: Path) -> None: + fake_pyra = tmp_path / "pyra" + fake_pyra.touch() + with patch("shutil.which", return_value=str(fake_pyra)): + assert find_pyra_executable() == str(fake_pyra) + + +def test_find_pyra_executable_falls_back_to_sibling(tmp_path: Path) -> None: + fake_python = tmp_path / "python3" + fake_pyra = tmp_path / "pyra" + fake_pyra.touch() + with patch("shutil.which", return_value=None): + with patch("sys.executable", str(fake_python)): + assert find_pyra_executable() == str(fake_pyra) + + +def test_find_pyra_executable_falls_back_to_module(tmp_path: Path) -> None: + fake_python = tmp_path / "python3" + with patch("shutil.which", return_value=None): + with patch("sys.executable", str(fake_python)): + result = find_pyra_executable() + assert result == f"{fake_python} -m pyra" + + +# ── Install / uninstall (subprocess mocked) ─────────────────────────────────── + +@pytest.mark.skipif(sys.platform == "win32", reason="launchd install is macOS-only") +def test_install_launchd_writes_plist_and_calls_launchctl( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import pyra.daemon.service as svc + + plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist" + monkeypatch.setattr(svc, "_PLIST_PATH", plist_path) + + calls: list[list[str]] = [] + monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd)) + + svc._install_launchd("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid") + + assert plist_path.exists() + assert "com.pyra.daemon" in plist_path.read_text() + assert any("launchctl" in c[0] for c in calls) + + +@pytest.mark.skipif(sys.platform == "win32", reason="systemd install is Linux-only") +def test_install_systemd_writes_unit_and_calls_systemctl( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import pyra.daemon.service as svc + + unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service" + monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path) + + calls: list[list[str]] = [] + monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd)) + + svc._install_systemd("/usr/local/bin/pyra", "~/.pyra/daemon.log") + + assert unit_path.exists() + assert "ExecStart" in unit_path.read_text() + assert any("systemctl" in c[0] for c in calls) + + +@pytest.mark.skipif(sys.platform == "win32", reason="launchd uninstall is macOS-only") +def test_uninstall_launchd_removes_plist( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import pyra.daemon.service as svc + + plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist" + plist_path.parent.mkdir(parents=True) + plist_path.write_text("") + monkeypatch.setattr(svc, "_PLIST_PATH", plist_path) + monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None) + + svc._uninstall_launchd() + + assert not plist_path.exists() + + +@pytest.mark.skipif(sys.platform == "win32", reason="systemd uninstall is Linux-only") +def test_uninstall_systemd_removes_unit( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import pyra.daemon.service as svc + + unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service" + unit_path.parent.mkdir(parents=True) + unit_path.write_text("[Service]") + monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path) + monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None) + + svc._uninstall_systemd() + + assert not unit_path.exists()