feat(daemon): add core asyncio daemon, supervisor, and registry factories

- core.py: asyncio event loop entry point, PluginSupervisor with per-task
  restart (up to 10 times, 5s back-off), IPC dispatch, signal handling
  (SIGTERM/SIGHUP on POSIX), RotatingFileHandler, start_background() helper
- daemon/__init__.py: export public API
- plugins/registry.py: add get_daemon_task_factories() so supervisor can
  restart crashed tasks by re-calling plugin.daemon_tasks()[i]
- config/schema.py: add DaemonConfig.ipc_port for Windows TCP loopback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
curo1305
2026-05-19 15:33:57 +02:00
parent d42b8b4a47
commit 3d3ce694b9
4 changed files with 336 additions and 0 deletions
+1
View File
@@ -36,6 +36,7 @@ class DaemonConfig(BaseModel):
socket_path: str = "~/.pyra/daemon.sock"
log_file: str = "~/.pyra/daemon.log"
pid_file: str = "~/.pyra/daemon.pid"
ipc_port: int = 0 # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port
class PyraConfig(BaseModel):
+21
View File
@@ -0,0 +1,21 @@
"""Pyra background daemon package."""
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
from pyra.daemon.service import detect_platform, install_service, uninstall_service
__all__ = [
"run_foreground",
"start_background",
"PluginSupervisor",
"IpcClient",
"IpcServer",
"send_command",
"PidFile",
"PidFileError",
"resolve_pid_path",
"detect_platform",
"install_service",
"uninstall_service",
]
+291
View File
@@ -0,0 +1,291 @@
"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling."""
from __future__ import annotations
import asyncio
import logging
import logging.handlers
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Coroutine
from pyra.utils.paths import pyra_home, safe_chmod
_log = logging.getLogger("pyra.daemon")
_start_time: float = 0.0
# ── Plugin task supervisor ────────────────────────────────────────────────────
@dataclass
class TaskRecord:
name: str
coro_factory: Callable[[], Coroutine] # type: ignore[type-arg]
task: asyncio.Task | None = field(default=None, repr=False)
restart_count: int = 0
last_error: str | None = None
def is_alive(self) -> bool:
return self.task is not None and not self.task.done()
class PluginSupervisor:
_RESTART_DELAY: float = 5.0
_MAX_RESTARTS: int = 10
def __init__(self) -> None:
self._records: list[TaskRecord] = []
self._shutdown = asyncio.Event()
def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
self._records.append(TaskRecord(name=name, coro_factory=factory))
async def start(self) -> None:
for record in self._records:
record.task = asyncio.create_task(
self._supervise(record), name=record.name
)
_log.info("Supervisor started with %d plugin task(s).", len(self._records))
async def run_until_shutdown(self) -> None:
await self._shutdown.wait()
_log.info("Shutdown requested — stopping supervisor.")
def request_shutdown(self) -> None:
self._shutdown.set()
async def stop(self) -> None:
for record in self._records:
if record.task and not record.task.done():
record.task.cancel()
tasks = [r.task for r in self._records if r.task]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def status(self) -> list[dict]:
return [
{
"name": r.name,
"alive": r.is_alive(),
"restart_count": r.restart_count,
"last_error": r.last_error,
}
for r in self._records
]
async def _supervise(self, record: TaskRecord) -> None:
while not self._shutdown.is_set():
try:
await record.coro_factory()
_log.info("Plugin task %s completed normally.", record.name)
return
except asyncio.CancelledError:
return
except Exception as exc:
record.restart_count += 1
record.last_error = f"{type(exc).__name__}: {exc}"
_log.error(
"Plugin task %s crashed (restart #%d): %s",
record.name, record.restart_count, exc,
exc_info=True,
)
if record.restart_count >= self._MAX_RESTARTS:
_log.critical(
"Plugin task %s exceeded max restarts (%d). Giving up.",
record.name, self._MAX_RESTARTS,
)
return
try:
await asyncio.wait_for(
asyncio.sleep(self._RESTART_DELAY),
timeout=self._RESTART_DELAY + 1,
)
except asyncio.CancelledError:
return
# ── IPC command dispatch ──────────────────────────────────────────────────────
def _make_ipc_handler(supervisor: PluginSupervisor):
async def handler(msg: dict) -> dict:
cmd = msg.get("cmd", "")
match cmd:
case "ping":
return {"ok": True, "data": {"pong": True}}
case "status":
return {
"ok": True,
"data": {
"uptime": time.monotonic() - _start_time,
"pid": os.getpid(),
"tasks": supervisor.status(),
},
}
case "stop":
supervisor.request_shutdown()
return {"ok": True, "data": {}}
case "reload":
_log.info("Reload requested via IPC.")
return {"ok": True, "data": {}}
case _:
return {"ok": False, "data": {"error": f"unknown command: {cmd}"}}
return handler
# ── Main async entrypoint ─────────────────────────────────────────────────────
async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None:
from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket
if is_unix_socket():
address = get_socket_path(cfg.daemon.socket_path)
else:
address = ("127.0.0.1", cfg.daemon.ipc_port)
server = IpcServer(address, _make_ipc_handler(supervisor))
await supervisor.start()
async with asyncio.TaskGroup() as tg:
tg.create_task(server.start(), name="ipc_server")
tg.create_task(supervisor.run_until_shutdown(), name="shutdown_waiter")
await server.stop()
await supervisor.stop()
# ── Foreground entry point (pyra daemon run) ──────────────────────────────────
def run_foreground() -> None:
"""Run the daemon in the foreground. Called by `pyra daemon run`."""
from pyra.config.manager import load_config
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
from pyra.plugins.registry import PluginRegistry
global _start_time
cfg = load_config()
_setup_logging(cfg.daemon.log_file)
pid_path = resolve_pid_path(cfg.daemon.pid_file)
pid_file = PidFile(pid_path)
existing = pid_file.read()
if existing is not None and not pid_file.is_stale():
_log.error("Daemon already running (PID %d). Exiting.", existing)
sys.exit(1)
registry = PluginRegistry()
from pyra.utils.paths import pyra_home as _pyra_home
plugins_dir = _pyra_home() / "plugins"
if plugins_dir.exists():
registry.load_all(plugins_dir, cfg.plugins.enabled)
supervisor = PluginSupervisor()
for name, factory in registry.get_daemon_task_factories():
supervisor.add_task(name, factory)
_start_time = time.monotonic()
with pid_file:
_install_signal_handlers(supervisor)
_log.info("Pyra daemon starting (PID %d).", os.getpid())
try:
asyncio.run(_run_daemon(cfg, supervisor))
except KeyboardInterrupt:
pass
_log.info("Pyra daemon stopped.")
# ── Background spawn (pyra daemon start) ─────────────────────────────────────
def start_background() -> None:
"""Spawn `pyra daemon run` as a detached background process."""
from pyra.config.manager import load_config
from pyra.daemon.pid import PidFile, resolve_pid_path
from pyra.daemon.service import find_pyra_executable
cfg = load_config()
pid_path = resolve_pid_path(cfg.daemon.pid_file)
pid_file = PidFile(pid_path)
existing = pid_file.read()
if existing is not None and not pid_file.is_stale():
from pyra.chat.renderer import console
console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]")
return
exe = find_pyra_executable()
log_path = Path(cfg.daemon.log_file).expanduser()
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a") as log_fh:
if sys.platform == "win32":
DETACHED_PROCESS = 0x00000008
CREATE_NEW_PROCESS_GROUP = 0x00000200
subprocess.Popen(
[exe, "daemon", "run"],
creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP,
stdout=log_fh,
stderr=log_fh,
close_fds=True,
)
else:
subprocess.Popen(
[exe, "daemon", "run"],
start_new_session=True,
stdout=log_fh,
stderr=log_fh,
stdin=subprocess.DEVNULL,
close_fds=True,
)
from pyra.chat.renderer import console
# Poll the PID file for up to 3 seconds to confirm startup.
for _ in range(30):
time.sleep(0.1)
pid = pid_file.read()
if pid is not None:
console.print(f"[green]Daemon started (PID {pid}).[/green]")
return
console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]")
# ── Signal handling ───────────────────────────────────────────────────────────
def _install_signal_handlers(supervisor: PluginSupervisor) -> None:
if sys.platform == "win32":
signal.signal(signal.SIGTERM, lambda *_: supervisor.request_shutdown())
return
loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGTERM, supervisor.request_shutdown)
loop.add_signal_handler(signal.SIGHUP, supervisor.request_shutdown)
# ── Logging setup ─────────────────────────────────────────────────────────────
def _setup_logging(log_file_str: str) -> None:
log_path = Path(log_file_str).expanduser()
log_path.parent.mkdir(parents=True, exist_ok=True)
handler = logging.handlers.RotatingFileHandler(
log_path, maxBytes=5 * 1024 * 1024, backupCount=3
)
handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
)
root = logging.getLogger("pyra")
root.addHandler(handler)
root.setLevel(logging.INFO)
safe_chmod(log_path, 0o600)
+23
View File
@@ -75,6 +75,29 @@ class PluginRegistry:
pass
return tasks
def get_daemon_task_factories(
self,
) -> list[tuple[str, Callable[[], Coroutine]]]: # type: ignore[type-arg]
"""Return (name, factory) pairs for all plugin daemon tasks.
Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine,
enabling the supervisor to restart crashed tasks without changing the plugin
protocol.
"""
factories: list[tuple[str, Callable[[], Coroutine]]] = [] # type: ignore[type-arg]
for plugin in self._plugins.values():
try:
n_tasks = len(plugin.daemon_tasks())
except Exception:
continue
for i in range(n_tasks):
name = f"{plugin.name}.task_{i}"
# Capture plugin and index by value so each closure is independent.
def _factory(p=plugin, idx=i) -> Coroutine: # type: ignore[type-arg]
return p.daemon_tasks()[idx]
factories.append((name, _factory))
return factories
def find_tool(self, name: str) -> Tool | None:
return self._tools.get(name)