diff --git a/src/pyra/config/schema.py b/src/pyra/config/schema.py
index a945442..03ff51f 100644
--- a/src/pyra/config/schema.py
+++ b/src/pyra/config/schema.py
@@ -36,6 +36,7 @@ class DaemonConfig(BaseModel):
     socket_path: str = "~/.pyra/daemon.sock"
     log_file: str = "~/.pyra/daemon.log"
     pid_file: str = "~/.pyra/daemon.pid"
+    ipc_port: int = 0  # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port
 
 
 class PyraConfig(BaseModel):
diff --git a/src/pyra/daemon/__init__.py b/src/pyra/daemon/__init__.py
index e69de29..2f10e68 100644
--- a/src/pyra/daemon/__init__.py
+++ b/src/pyra/daemon/__init__.py
@@ -0,0 +1,21 @@
+"""Pyra background daemon package."""
+
+from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
+from pyra.daemon.ipc import IpcClient, IpcServer, send_command
+from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
+from pyra.daemon.service import detect_platform, install_service, uninstall_service
+
+__all__ = [
+    "run_foreground",
+    "start_background",
+    "PluginSupervisor",
+    "IpcClient",
+    "IpcServer",
+    "send_command",
+    "PidFile",
+    "PidFileError",
+    "resolve_pid_path",
+    "detect_platform",
+    "install_service",
+    "uninstall_service",
+]
diff --git a/src/pyra/daemon/core.py b/src/pyra/daemon/core.py
new file mode 100644
index 0000000..ae8115d
--- /dev/null
+++ b/src/pyra/daemon/core.py
@@ -0,0 +1,358 @@
+"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import logging.handlers
+import os
+import signal
+import subprocess
+import sys
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Callable, Coroutine
+
+from pyra.utils.paths import pyra_home, safe_chmod
+
+
+_log = logging.getLogger("pyra.daemon")
+_start_time: float = 0.0
+
+
+# ── Plugin task supervisor ────────────────────────────────────────────────────
+
+@dataclass
+class TaskRecord:
+    """Bookkeeping for one supervised plugin task."""
+
+    name: str
+    # Zero-argument callable; awaited anew on every (re)start of the task.
+    coro_factory: Callable[[], Coroutine]  # type: ignore[type-arg]
+    # The asyncio task running PluginSupervisor._supervise for this record.
+    task: asyncio.Task | None = field(default=None, repr=False)
+    restart_count: int = 0
+    last_error: str | None = None
+
+    def is_alive(self) -> bool:
+        """Return True while the supervising task exists and has not finished."""
+        return self.task is not None and not self.task.done()
+
+
+class PluginSupervisor:
+    """Run plugin coroutines as asyncio tasks and restart them after crashes."""
+
+    _RESTART_DELAY: float = 5.0  # seconds to wait before restarting a crashed task
+    _MAX_RESTARTS: int = 10  # crashes tolerated per task before giving up
+
+    def __init__(self) -> None:
+        self._records: list[TaskRecord] = []
+        self._shutdown = asyncio.Event()
+
+    def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None:  # type: ignore[type-arg]
+        """Register a task; nothing is scheduled until start() is awaited."""
+        self._records.append(TaskRecord(name=name, coro_factory=factory))
+
+    async def start(self) -> None:
+        """Create one supervising asyncio task per registered record."""
+        for record in self._records:
+            record.task = asyncio.create_task(
+                self._supervise(record), name=record.name
+            )
+        _log.info("Supervisor started with %d plugin task(s).", len(self._records))
+
+    async def run_until_shutdown(self) -> None:
+        """Block until request_shutdown() has been called."""
+        await self._shutdown.wait()
+        _log.info("Shutdown requested — stopping supervisor.")
+
+    def request_shutdown(self) -> None:
+        """Signal shutdown; wakes run_until_shutdown() and pending restart delays."""
+        self._shutdown.set()
+
+    async def stop(self) -> None:
+        """Cancel every unfinished supervising task and wait for all of them."""
+        for record in self._records:
+            if record.task and not record.task.done():
+                record.task.cancel()
+        tasks = [r.task for r in self._records if r.task]
+        if tasks:
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+    def status(self) -> list[dict]:
+        """Return one dict per task: name, alive, restart_count, last_error."""
+        return [
+            {
+                "name": r.name,
+                "alive": r.is_alive(),
+                "restart_count": r.restart_count,
+                "last_error": r.last_error,
+            }
+            for r in self._records
+        ]
+
+    async def _supervise(self, record: TaskRecord) -> None:
+        """Run the record's coroutine, restarting it after each crash.
+
+        Returns when the coroutine completes normally, the task is cancelled,
+        shutdown is requested, or the task has crashed _MAX_RESTARTS times.
+        """
+        while not self._shutdown.is_set():
+            try:
+                await record.coro_factory()
+                _log.info("Plugin task %s completed normally.", record.name)
+                return
+            except asyncio.CancelledError:
+                return
+            except Exception as exc:
+                record.restart_count += 1
+                record.last_error = f"{type(exc).__name__}: {exc}"
+                _log.error(
+                    "Plugin task %s crashed (restart #%d): %s",
+                    record.name, record.restart_count, exc,
+                    exc_info=True,
+                )
+                if record.restart_count >= self._MAX_RESTARTS:
+                    _log.critical(
+                        "Plugin task %s exceeded max restarts (%d). Giving up.",
+                        record.name, self._MAX_RESTARTS,
+                    )
+                    return
+                # Back off before restarting, but wake up at once if shutdown
+                # is requested; the loop condition then ends supervision.
+                try:
+                    await asyncio.wait_for(
+                        self._shutdown.wait(), timeout=self._RESTART_DELAY
+                    )
+                except asyncio.TimeoutError:
+                    pass
+                except asyncio.CancelledError:
+                    return
+
+
+# ── IPC command dispatch ──────────────────────────────────────────────────────
+
+def _make_ipc_handler(supervisor: PluginSupervisor):
+    """Build the async handler that maps an IPC message's "cmd" to a reply dict."""
+
+    async def handler(msg: dict) -> dict:
+        cmd = msg.get("cmd", "")
+        match cmd:
+            case "ping":
+                return {"ok": True, "data": {"pong": True}}
+            case "status":
+                return {
+                    "ok": True,
+                    "data": {
+                        "uptime": time.monotonic() - _start_time,
+                        "pid": os.getpid(),
+                        "tasks": supervisor.status(),
+                    },
+                }
+            case "stop":
+                supervisor.request_shutdown()
+                return {"ok": True, "data": {}}
+            case "reload":
+                # Currently only logged and acknowledged; nothing is reloaded.
+                _log.info("Reload requested via IPC.")
+                return {"ok": True, "data": {}}
+            case _:
+                return {"ok": False, "data": {"error": f"unknown command: {cmd}"}}
+
+    return handler
+
+
+# ── Main async entrypoint ─────────────────────────────────────────────────────
+
+async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None:
+    """Serve IPC and supervise plugin tasks until shutdown is requested.
+
+    The IPC server and the supervisor are stopped on every exit path,
+    including cancellation of this coroutine and an IPC server failure.
+    """
+    from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket
+
+    if is_unix_socket():
+        address = get_socket_path(cfg.daemon.socket_path)
+    else:
+        address = ("127.0.0.1", cfg.daemon.ipc_port)
+
+    server = IpcServer(address, _make_ipc_handler(supervisor))
+
+    # Handlers must be bound to the loop asyncio.run() created, so they are
+    # installed here, from inside that loop, not before it exists.
+    _install_signal_handlers(supervisor)
+
+    await supervisor.start()
+
+    server_task = asyncio.create_task(server.start(), name="ipc_server")
+    shutdown_task = asyncio.create_task(
+        supervisor.run_until_shutdown(), name="shutdown_waiter"
+    )
+    try:
+        # Wait for whichever finishes first so that shutdown is honoured even
+        # if server.start() keeps serving until it is cancelled.
+        done, _ = await asyncio.wait(
+            {server_task, shutdown_task}, return_when=asyncio.FIRST_COMPLETED
+        )
+        if server_task in done:
+            server_task.result()  # re-raise an IPC server failure, if any
+        await shutdown_task
+    finally:
+        for task in (server_task, shutdown_task):
+            if not task.done():
+                task.cancel()
+        await asyncio.gather(server_task, shutdown_task, return_exceptions=True)
+        try:
+            await server.stop()
+        finally:
+            await supervisor.stop()
+
+
+# ── Foreground entry point (pyra daemon run) ──────────────────────────────────
+
+def run_foreground() -> None:
+    """Run the daemon in the foreground. Called by `pyra daemon run`."""
+    from pyra.config.manager import load_config
+    from pyra.daemon.pid import PidFile, resolve_pid_path
+    from pyra.plugins.registry import PluginRegistry
+
+    global _start_time
+
+    cfg = load_config()
+    _setup_logging(cfg.daemon.log_file)
+    pid_path = resolve_pid_path(cfg.daemon.pid_file)
+    pid_file = PidFile(pid_path)
+
+    existing = pid_file.read()
+    if existing is not None and not pid_file.is_stale():
+        _log.error("Daemon already running (PID %d). Exiting.", existing)
+        sys.exit(1)
+
+    registry = PluginRegistry()
+    plugins_dir = pyra_home() / "plugins"
+    if plugins_dir.exists():
+        registry.load_all(plugins_dir, cfg.plugins.enabled)
+
+    supervisor = PluginSupervisor()
+    for name, factory in registry.get_daemon_task_factories():
+        supervisor.add_task(name, factory)
+
+    _start_time = time.monotonic()
+
+    with pid_file:
+        _log.info("Pyra daemon starting (PID %d).", os.getpid())
+        try:
+            # Signal handlers are installed by _run_daemon, inside the loop.
+            asyncio.run(_run_daemon(cfg, supervisor))
+        except KeyboardInterrupt:
+            pass
+        except Exception:
+            _log.exception("Pyra daemon crashed.")
+            raise
+    _log.info("Pyra daemon stopped.")
+
+
+# ── Background spawn (pyra daemon start) ─────────────────────────────────────
+
+def start_background() -> None:
+    """Spawn `pyra daemon run` as a detached background process."""
+    from pyra.chat.renderer import console
+    from pyra.config.manager import load_config
+    from pyra.daemon.pid import PidFile, resolve_pid_path
+    from pyra.daemon.service import find_pyra_executable
+
+    cfg = load_config()
+    pid_path = resolve_pid_path(cfg.daemon.pid_file)
+    pid_file = PidFile(pid_path)
+
+    existing = pid_file.read()
+    if existing is not None and not pid_file.is_stale():
+        console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]")
+        return
+
+    exe = find_pyra_executable()
+    log_path = Path(cfg.daemon.log_file).expanduser()
+    log_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Platform-specific options that detach the child from this process.
+    detach_kwargs: dict = {}
+    if sys.platform == "win32":
+        DETACHED_PROCESS = 0x00000008
+        CREATE_NEW_PROCESS_GROUP = 0x00000200
+        detach_kwargs["creationflags"] = DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP
+    else:
+        detach_kwargs["start_new_session"] = True
+
+    with open(log_path, "a") as log_fh:
+        proc = subprocess.Popen(
+            [exe, "daemon", "run"],
+            stdin=subprocess.DEVNULL,
+            stdout=log_fh,
+            stderr=log_fh,
+            close_fds=True,
+            **detach_kwargs,
+        )
+
+    # Poll the PID file for up to 3 seconds to confirm startup.
+    for _ in range(30):
+        time.sleep(0.1)
+        pid = pid_file.read()
+        # A PID file left behind by a dead daemon must not count as success.
+        if pid is not None and not pid_file.is_stale():
+            console.print(f"[green]Daemon started (PID {pid}).[/green]")
+            return
+        exit_code = proc.poll()
+        if exit_code is not None:
+            console.print(
+                f"[red]Daemon exited during startup (exit code {exit_code}). "
+                f"See {log_path}.[/red]"
+            )
+            return
+
+    console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]")
+
+
+# ── Signal handling ───────────────────────────────────────────────────────────
+
+def _install_signal_handlers(supervisor: PluginSupervisor) -> None:
+    """Route termination signals to supervisor.request_shutdown().
+
+    Must be called from inside the running event loop: handlers registered on
+    any other loop object are never dispatched by the loop that runs the daemon.
+    """
+    loop = asyncio.get_running_loop()
+
+    if sys.platform == "win32":
+        # loop.add_signal_handler() is not supported on Windows; use a plain
+        # signal handler and hand the call over to the loop thread-safely.
+        signal.signal(
+            signal.SIGTERM,
+            lambda *_: loop.call_soon_threadsafe(supervisor.request_shutdown),
+        )
+        return
+
+    for signum in (signal.SIGTERM, signal.SIGHUP):
+        loop.add_signal_handler(signum, supervisor.request_shutdown)
+
+
+# ── Logging setup ─────────────────────────────────────────────────────────────
+
+def _setup_logging(log_file_str: str) -> None:
+    """Attach a rotating file handler (5 MiB, 3 backups) to the "pyra" logger."""
+    log_path = Path(log_file_str).expanduser()
+    log_path.parent.mkdir(parents=True, exist_ok=True)
+
+    handler = logging.handlers.RotatingFileHandler(
+        log_path, maxBytes=5 * 1024 * 1024, backupCount=3
+    )
+    handler.setFormatter(
+        logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
+    )
+
+    root = logging.getLogger("pyra")
+    root.addHandler(handler)
+    root.setLevel(logging.INFO)
+
+    safe_chmod(log_path, 0o600)
diff --git a/src/pyra/plugins/registry.py b/src/pyra/plugins/registry.py
index ad83d91..9186904 100644
--- a/src/pyra/plugins/registry.py
+++ b/src/pyra/plugins/registry.py
@@ -75,5 +75,53 @@ class PluginRegistry:
                 pass
         return tasks
 
+    def get_daemon_task_factories(
+        self,
+    ) -> list[tuple[str, Callable[[], Coroutine]]]:  # type: ignore[type-arg]
+        """Return (name, factory) pairs for all plugin daemon tasks.
+
+        Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine,
+        enabling the supervisor to restart crashed tasks without changing the plugin
+        protocol.
+
+        Coroutines that are created but not handed out (the ones used only for
+        counting, and the siblings of the one a factory returns) are closed so
+        they do not trigger "coroutine ... was never awaited" warnings.
+        """
+        import inspect
+        import logging
+
+        def _close_unstarted(coros) -> None:
+            # Only close coroutines that never started running.
+            for coro in coros:
+                if (
+                    inspect.iscoroutine(coro)
+                    and inspect.getcoroutinestate(coro) == inspect.CORO_CREATED
+                ):
+                    coro.close()
+
+        factories: list[tuple[str, Callable[[], Coroutine]]] = []  # type: ignore[type-arg]
+        for plugin in self._plugins.values():
+            try:
+                probe = list(plugin.daemon_tasks())
+            except Exception:
+                # Best effort: a broken plugin must not block the others,
+                # but the failure should be visible in the log.
+                logging.getLogger(__name__).exception(
+                    "Could not enumerate daemon tasks of plugin %s.", getattr(plugin, "name", plugin)
+                )
+                continue
+            _close_unstarted(probe)
+            for i in range(len(probe)):
+                name = f"{plugin.name}.task_{i}"
+                # Capture plugin and index by value so each closure is independent.
+                def _factory(p=plugin, idx=i) -> Coroutine:  # type: ignore[type-arg]
+                    coros = list(p.daemon_tasks())
+                    chosen = coros.pop(idx)
+                    _close_unstarted(coros)
+                    return chosen
+                factories.append((name, _factory))
+        return factories
+
     def find_tool(self, name: str) -> Tool | None:
         return self._tools.get(name)