Compare commits
5 Commits
0e052c4992
...
c41ad0afc6
| Author | SHA1 | Date | |
|---|---|---|---|
| c41ad0afc6 | |||
| 3d3ce694b9 | |||
| d42b8b4a47 | |||
| 513871ef96 | |||
| eaed52006f |
+112
-11
@@ -266,43 +266,144 @@ def daemon() -> None:
|
||||
_bootstrap_or_exit()
|
||||
|
||||
|
||||
@daemon.command("run", hidden=True)
|
||||
def daemon_run() -> None:
|
||||
"""Run daemon in foreground (used by service manager)."""
|
||||
from pyra.daemon.core import run_foreground
|
||||
run_foreground()
|
||||
|
||||
|
||||
@daemon.command("start")
|
||||
def daemon_start() -> None:
|
||||
"""Start the Pyra daemon in the background."""
|
||||
console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]")
|
||||
from pyra.daemon.core import start_background
|
||||
try:
|
||||
start_background()
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
|
||||
|
||||
|
||||
@daemon.command("stop")
|
||||
def daemon_stop() -> None:
|
||||
"""Stop the running Pyra daemon."""
|
||||
console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]")
|
||||
_daemon_ipc("stop", success_msg="Daemon stopped.")
|
||||
|
||||
|
||||
@daemon.command("status")
|
||||
def daemon_status() -> None:
|
||||
"""Show daemon status."""
|
||||
console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]")
|
||||
_daemon_ipc("status")
|
||||
|
||||
|
||||
@daemon.command("restart")
|
||||
def daemon_restart() -> None:
|
||||
"""Restart the Pyra daemon."""
|
||||
console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]")
|
||||
import time
|
||||
from pyra.daemon.core import start_background
|
||||
_daemon_ipc("stop", success_msg=None)
|
||||
time.sleep(1.5)
|
||||
try:
|
||||
start_background()
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
|
||||
|
||||
|
||||
@daemon.command("install")
|
||||
def daemon_install() -> None:
|
||||
"""Install Pyra as a system service (launchd/systemd)."""
|
||||
console.print("[yellow]Daemon service install (Stage 6) is not yet implemented.[/yellow]")
|
||||
"""Install Pyra as a system service (launchd/systemd/schtasks)."""
|
||||
from pyra.daemon.service import detect_platform, install_service
|
||||
try:
|
||||
install_service()
|
||||
console.print(f"[green]Service installed[/green] ({detect_platform()}).")
|
||||
except Exception as exc:
|
||||
console.print(f"[red]Install failed:[/red] {exc}")
|
||||
|
||||
|
||||
@daemon.command("uninstall")
|
||||
def daemon_uninstall() -> None:
|
||||
"""Remove the Pyra system service."""
|
||||
console.print("[yellow]Daemon service uninstall (Stage 6) is not yet implemented.[/yellow]")
|
||||
from pyra.daemon.service import uninstall_service
|
||||
try:
|
||||
uninstall_service()
|
||||
console.print("[green]Service removed.[/green]")
|
||||
except Exception as exc:
|
||||
console.print(f"[red]Uninstall failed:[/red] {exc}")
|
||||
|
||||
|
||||
@daemon.command("run", hidden=True)
|
||||
def daemon_run() -> None:
|
||||
"""Run daemon in foreground (used by service manager)."""
|
||||
console.print("[yellow]Daemon (Stage 6) is not yet implemented.[/yellow]")
|
||||
def _daemon_ipc(cmd: str, *, success_msg: str | None = None) -> None:
|
||||
"""Send a command to the running daemon via IPC and render the response."""
|
||||
from pyra.config.manager import load_config
|
||||
from pyra.daemon.ipc import (
|
||||
get_socket_path,
|
||||
is_unix_socket,
|
||||
get_port_file_path,
|
||||
send_command,
|
||||
)
|
||||
|
||||
try:
|
||||
cfg = load_config()
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] Run [dim]pyra setup[/dim] first.")
|
||||
return
|
||||
|
||||
if is_unix_socket():
|
||||
address = get_socket_path(cfg.daemon.socket_path)
|
||||
else:
|
||||
port = _read_windows_port()
|
||||
if port is None:
|
||||
console.print("[yellow]Daemon is not running.[/yellow]")
|
||||
return
|
||||
address = ("127.0.0.1", port)
|
||||
|
||||
try:
|
||||
resp = send_command(address, {"cmd": cmd})
|
||||
except (ConnectionRefusedError, FileNotFoundError, OSError):
|
||||
console.print("[yellow]Daemon is not running.[/yellow]")
|
||||
return
|
||||
except ConnectionResetError:
|
||||
console.print("[red]Permission denied:[/red] daemon rejected connection.")
|
||||
return
|
||||
except TimeoutError:
|
||||
console.print("[red]Daemon did not respond in time.[/red]")
|
||||
return
|
||||
|
||||
if not resp.get("ok"):
|
||||
console.print(f"[red]Error:[/red] {resp.get('data', {}).get('error', 'unknown')}")
|
||||
return
|
||||
|
||||
if cmd == "status":
|
||||
_render_daemon_status(resp["data"])
|
||||
elif success_msg:
|
||||
console.print(f"[green]{success_msg}[/green]")
|
||||
|
||||
|
||||
def _read_windows_port() -> int | None:
|
||||
from pyra.daemon.ipc import get_port_file_path
|
||||
try:
|
||||
return int(get_port_file_path().read_text().strip())
|
||||
except (FileNotFoundError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _render_daemon_status(data: dict) -> None:
|
||||
from rich.table import Table
|
||||
|
||||
uptime = data.get("uptime", 0.0)
|
||||
pid = data.get("pid", "?")
|
||||
tasks = data.get("tasks", [])
|
||||
|
||||
hours, rem = divmod(int(uptime), 3600)
|
||||
mins, secs = divmod(rem, 60)
|
||||
uptime_str = f"{hours}h {mins}m {secs}s" if hours else f"{mins}m {secs}s"
|
||||
|
||||
console.print(f"[bold green]Daemon running[/bold green] — PID {pid}, uptime {uptime_str}")
|
||||
|
||||
if tasks:
|
||||
table = Table("Task", "Alive", "Restarts", "Last error", show_header=True)
|
||||
for t in tasks:
|
||||
alive = "[green]yes[/green]" if t.get("alive") else "[red]no[/red]"
|
||||
error = t.get("last_error") or "—"
|
||||
table.add_row(t.get("name", "?"), alive, str(t.get("restart_count", 0)), error)
|
||||
console.print(table)
|
||||
else:
|
||||
console.print("[dim]No plugin tasks registered.[/dim]")
|
||||
|
||||
@@ -36,6 +36,7 @@ class DaemonConfig(BaseModel):
|
||||
socket_path: str = "~/.pyra/daemon.sock"
|
||||
log_file: str = "~/.pyra/daemon.log"
|
||||
pid_file: str = "~/.pyra/daemon.pid"
|
||||
ipc_port: int = 0 # Windows TCP loopback: 0 = OS-assigned, written to ~/.pyra/daemon.port
|
||||
|
||||
|
||||
class PyraConfig(BaseModel):
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""Pyra background daemon package."""
|
||||
|
||||
from pyra.daemon.core import PluginSupervisor, run_foreground, start_background
|
||||
from pyra.daemon.ipc import IpcClient, IpcServer, send_command
|
||||
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||
from pyra.daemon.service import detect_platform, install_service, uninstall_service
|
||||
|
||||
__all__ = [
|
||||
"run_foreground",
|
||||
"start_background",
|
||||
"PluginSupervisor",
|
||||
"IpcClient",
|
||||
"IpcServer",
|
||||
"send_command",
|
||||
"PidFile",
|
||||
"PidFileError",
|
||||
"resolve_pid_path",
|
||||
"detect_platform",
|
||||
"install_service",
|
||||
"uninstall_service",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
"""Pyra daemon core — asyncio event loop, plugin task supervisor, signal handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Coroutine
|
||||
|
||||
from pyra.utils.paths import pyra_home, safe_chmod
|
||||
|
||||
|
||||
_log = logging.getLogger("pyra.daemon")
|
||||
_start_time: float = 0.0
|
||||
|
||||
|
||||
# ── Plugin task supervisor ────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class TaskRecord:
|
||||
name: str
|
||||
coro_factory: Callable[[], Coroutine] # type: ignore[type-arg]
|
||||
task: asyncio.Task | None = field(default=None, repr=False)
|
||||
restart_count: int = 0
|
||||
last_error: str | None = None
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
return self.task is not None and not self.task.done()
|
||||
|
||||
|
||||
class PluginSupervisor:
|
||||
_RESTART_DELAY: float = 5.0
|
||||
_MAX_RESTARTS: int = 10
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: list[TaskRecord] = []
|
||||
self._shutdown = asyncio.Event()
|
||||
|
||||
def add_task(self, name: str, factory: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
|
||||
self._records.append(TaskRecord(name=name, coro_factory=factory))
|
||||
|
||||
async def start(self) -> None:
|
||||
for record in self._records:
|
||||
record.task = asyncio.create_task(
|
||||
self._supervise(record), name=record.name
|
||||
)
|
||||
_log.info("Supervisor started with %d plugin task(s).", len(self._records))
|
||||
|
||||
async def run_until_shutdown(self) -> None:
|
||||
await self._shutdown.wait()
|
||||
_log.info("Shutdown requested — stopping supervisor.")
|
||||
|
||||
def request_shutdown(self) -> None:
|
||||
self._shutdown.set()
|
||||
|
||||
async def stop(self) -> None:
|
||||
for record in self._records:
|
||||
if record.task and not record.task.done():
|
||||
record.task.cancel()
|
||||
tasks = [r.task for r in self._records if r.task]
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
def status(self) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"name": r.name,
|
||||
"alive": r.is_alive(),
|
||||
"restart_count": r.restart_count,
|
||||
"last_error": r.last_error,
|
||||
}
|
||||
for r in self._records
|
||||
]
|
||||
|
||||
async def _supervise(self, record: TaskRecord) -> None:
|
||||
while not self._shutdown.is_set():
|
||||
try:
|
||||
await record.coro_factory()
|
||||
_log.info("Plugin task %s completed normally.", record.name)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
record.restart_count += 1
|
||||
record.last_error = f"{type(exc).__name__}: {exc}"
|
||||
_log.error(
|
||||
"Plugin task %s crashed (restart #%d): %s",
|
||||
record.name, record.restart_count, exc,
|
||||
exc_info=True,
|
||||
)
|
||||
if record.restart_count >= self._MAX_RESTARTS:
|
||||
_log.critical(
|
||||
"Plugin task %s exceeded max restarts (%d). Giving up.",
|
||||
record.name, self._MAX_RESTARTS,
|
||||
)
|
||||
return
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.sleep(self._RESTART_DELAY),
|
||||
timeout=self._RESTART_DELAY + 1,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
||||
# ── IPC command dispatch ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_ipc_handler(supervisor: PluginSupervisor):
|
||||
async def handler(msg: dict) -> dict:
|
||||
cmd = msg.get("cmd", "")
|
||||
match cmd:
|
||||
case "ping":
|
||||
return {"ok": True, "data": {"pong": True}}
|
||||
case "status":
|
||||
return {
|
||||
"ok": True,
|
||||
"data": {
|
||||
"uptime": time.monotonic() - _start_time,
|
||||
"pid": os.getpid(),
|
||||
"tasks": supervisor.status(),
|
||||
},
|
||||
}
|
||||
case "stop":
|
||||
supervisor.request_shutdown()
|
||||
return {"ok": True, "data": {}}
|
||||
case "reload":
|
||||
_log.info("Reload requested via IPC.")
|
||||
return {"ok": True, "data": {}}
|
||||
case _:
|
||||
return {"ok": False, "data": {"error": f"unknown command: {cmd}"}}
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
# ── Main async entrypoint ─────────────────────────────────────────────────────
|
||||
|
||||
async def _run_daemon(cfg, supervisor: PluginSupervisor) -> None:
|
||||
from pyra.daemon.ipc import IpcServer, get_socket_path, is_unix_socket
|
||||
|
||||
if is_unix_socket():
|
||||
address = get_socket_path(cfg.daemon.socket_path)
|
||||
else:
|
||||
address = ("127.0.0.1", cfg.daemon.ipc_port)
|
||||
|
||||
server = IpcServer(address, _make_ipc_handler(supervisor))
|
||||
|
||||
await supervisor.start()
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tg.create_task(server.start(), name="ipc_server")
|
||||
tg.create_task(supervisor.run_until_shutdown(), name="shutdown_waiter")
|
||||
|
||||
await server.stop()
|
||||
await supervisor.stop()
|
||||
|
||||
|
||||
# ── Foreground entry point (pyra daemon run) ──────────────────────────────────
|
||||
|
||||
def run_foreground() -> None:
|
||||
"""Run the daemon in the foreground. Called by `pyra daemon run`."""
|
||||
from pyra.config.manager import load_config
|
||||
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||
from pyra.plugins.registry import PluginRegistry
|
||||
|
||||
global _start_time
|
||||
|
||||
cfg = load_config()
|
||||
_setup_logging(cfg.daemon.log_file)
|
||||
pid_path = resolve_pid_path(cfg.daemon.pid_file)
|
||||
pid_file = PidFile(pid_path)
|
||||
|
||||
existing = pid_file.read()
|
||||
if existing is not None and not pid_file.is_stale():
|
||||
_log.error("Daemon already running (PID %d). Exiting.", existing)
|
||||
sys.exit(1)
|
||||
|
||||
registry = PluginRegistry()
|
||||
from pyra.utils.paths import pyra_home as _pyra_home
|
||||
plugins_dir = _pyra_home() / "plugins"
|
||||
if plugins_dir.exists():
|
||||
registry.load_all(plugins_dir, cfg.plugins.enabled)
|
||||
|
||||
supervisor = PluginSupervisor()
|
||||
for name, factory in registry.get_daemon_task_factories():
|
||||
supervisor.add_task(name, factory)
|
||||
|
||||
_start_time = time.monotonic()
|
||||
|
||||
with pid_file:
|
||||
_install_signal_handlers(supervisor)
|
||||
_log.info("Pyra daemon starting (PID %d).", os.getpid())
|
||||
try:
|
||||
asyncio.run(_run_daemon(cfg, supervisor))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
_log.info("Pyra daemon stopped.")
|
||||
|
||||
|
||||
# ── Background spawn (pyra daemon start) ─────────────────────────────────────
|
||||
|
||||
def start_background() -> None:
|
||||
"""Spawn `pyra daemon run` as a detached background process."""
|
||||
from pyra.config.manager import load_config
|
||||
from pyra.daemon.pid import PidFile, resolve_pid_path
|
||||
from pyra.daemon.service import find_pyra_executable
|
||||
|
||||
cfg = load_config()
|
||||
pid_path = resolve_pid_path(cfg.daemon.pid_file)
|
||||
pid_file = PidFile(pid_path)
|
||||
|
||||
existing = pid_file.read()
|
||||
if existing is not None and not pid_file.is_stale():
|
||||
from pyra.chat.renderer import console
|
||||
console.print(f"[yellow]Daemon already running (PID {existing}).[/yellow]")
|
||||
return
|
||||
|
||||
exe = find_pyra_executable()
|
||||
log_path = Path(cfg.daemon.log_file).expanduser()
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(log_path, "a") as log_fh:
|
||||
if sys.platform == "win32":
|
||||
DETACHED_PROCESS = 0x00000008
|
||||
CREATE_NEW_PROCESS_GROUP = 0x00000200
|
||||
subprocess.Popen(
|
||||
[exe, "daemon", "run"],
|
||||
creationflags=DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP,
|
||||
stdout=log_fh,
|
||||
stderr=log_fh,
|
||||
close_fds=True,
|
||||
)
|
||||
else:
|
||||
subprocess.Popen(
|
||||
[exe, "daemon", "run"],
|
||||
start_new_session=True,
|
||||
stdout=log_fh,
|
||||
stderr=log_fh,
|
||||
stdin=subprocess.DEVNULL,
|
||||
close_fds=True,
|
||||
)
|
||||
|
||||
from pyra.chat.renderer import console
|
||||
|
||||
# Poll the PID file for up to 3 seconds to confirm startup.
|
||||
for _ in range(30):
|
||||
time.sleep(0.1)
|
||||
pid = pid_file.read()
|
||||
if pid is not None:
|
||||
console.print(f"[green]Daemon started (PID {pid}).[/green]")
|
||||
return
|
||||
|
||||
console.print("[yellow]Daemon process spawned but PID file not yet written.[/yellow]")
|
||||
|
||||
|
||||
# ── Signal handling ───────────────────────────────────────────────────────────
|
||||
|
||||
def _install_signal_handlers(supervisor: PluginSupervisor) -> None:
|
||||
if sys.platform == "win32":
|
||||
signal.signal(signal.SIGTERM, lambda *_: supervisor.request_shutdown())
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.add_signal_handler(signal.SIGTERM, supervisor.request_shutdown)
|
||||
loop.add_signal_handler(signal.SIGHUP, supervisor.request_shutdown)
|
||||
|
||||
|
||||
# ── Logging setup ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _setup_logging(log_file_str: str) -> None:
|
||||
log_path = Path(log_file_str).expanduser()
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
log_path, maxBytes=5 * 1024 * 1024, backupCount=3
|
||||
)
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
)
|
||||
|
||||
root = logging.getLogger("pyra")
|
||||
root.addHandler(handler)
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
safe_chmod(log_path, 0o600)
|
||||
@@ -0,0 +1,241 @@
|
||||
"""IPC transport for the Pyra daemon.
|
||||
|
||||
Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked).
|
||||
Windows: TCP loopback on an OS-assigned port; actual port written to
|
||||
~/.pyra/daemon.port so clients can connect without knowing it ahead
|
||||
of time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
|
||||
# ── Protocol types ────────────────────────────────────────────────────────────
|
||||
|
||||
IpcMessage = dict[str, Any] # must have "cmd" key
|
||||
IpcResponse = dict[str, Any] # must have "ok" and "data" keys
|
||||
|
||||
|
||||
# ── Encode / decode ───────────────────────────────────────────────────────────
|
||||
|
||||
def encode_message(msg: IpcMessage) -> bytes:
|
||||
return (json.dumps(msg) + "\n").encode()
|
||||
|
||||
|
||||
def decode_message(line: bytes) -> IpcMessage:
|
||||
try:
|
||||
return json.loads(line.rstrip(b"\n"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"Invalid IPC message: {exc}") from exc
|
||||
|
||||
|
||||
# ── Address helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def is_unix_socket() -> bool:
|
||||
return sys.platform != "win32"
|
||||
|
||||
|
||||
def get_socket_path(cfg_socket_path: str) -> Path:
|
||||
"""Expand ~ and return the Unix socket path."""
|
||||
return Path(cfg_socket_path).expanduser()
|
||||
|
||||
|
||||
def get_port_file_path() -> Path:
|
||||
from pyra.utils.paths import pyra_home
|
||||
return pyra_home() / "daemon.port"
|
||||
|
||||
|
||||
def _read_windows_port() -> int | None:
|
||||
port_file = get_port_file_path()
|
||||
try:
|
||||
return int(port_file.read_text().strip())
|
||||
except (FileNotFoundError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
# ── Server ────────────────────────────────────────────────────────────────────
|
||||
|
||||
class IpcServer:
|
||||
def __init__(
|
||||
self,
|
||||
address: Path | tuple[str, int],
|
||||
handler: Callable[[IpcMessage], Awaitable[IpcResponse]],
|
||||
) -> None:
|
||||
self._address = address
|
||||
self._handler = handler
|
||||
self._server: asyncio.Server | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
if is_unix_socket():
|
||||
assert isinstance(self._address, Path)
|
||||
sock_path = self._address
|
||||
if sock_path.exists():
|
||||
sock_path.unlink()
|
||||
self._server = await asyncio.start_unix_server(
|
||||
self._handle_client, path=str(sock_path)
|
||||
)
|
||||
os.chmod(sock_path, 0o600)
|
||||
else:
|
||||
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle_client, host=host, port=port
|
||||
)
|
||||
actual_port = self._server.sockets[0].getsockname()[1]
|
||||
port_file = get_port_file_path()
|
||||
port_file.write_text(str(actual_port))
|
||||
|
||||
await self._server.start_serving()
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server is not None:
|
||||
self._server.close()
|
||||
try:
|
||||
await asyncio.wait_for(self._server.wait_closed(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
if is_unix_socket() and isinstance(self._address, Path):
|
||||
try:
|
||||
self._address.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
else:
|
||||
port_file = get_port_file_path()
|
||||
try:
|
||||
port_file.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
async def _handle_client(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
try:
|
||||
if is_unix_socket() and not self._check_peer_uid(writer):
|
||||
writer.close()
|
||||
return
|
||||
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=5.0)
|
||||
if not line:
|
||||
return
|
||||
|
||||
try:
|
||||
msg = decode_message(line)
|
||||
except ValueError:
|
||||
resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}}
|
||||
else:
|
||||
resp = await self._handler(msg)
|
||||
|
||||
writer.write(encode_message(resp))
|
||||
await writer.drain()
|
||||
except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool:
|
||||
"""Return True if the peer's UID matches ours. Falls back to True on error."""
|
||||
try:
|
||||
peer_uid = _get_peer_uid(writer)
|
||||
if peer_uid is None:
|
||||
return True # can't determine — allow (socket perms are the guard)
|
||||
return peer_uid == os.getuid()
|
||||
except Exception:
|
||||
return True # don't crash the server on unexpected errors
|
||||
|
||||
|
||||
# ── Client ────────────────────────────────────────────────────────────────────
|
||||
|
||||
class IpcClient:
|
||||
def __init__(self, address: Path | tuple[str, int]) -> None:
|
||||
self._address = address
|
||||
|
||||
async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse:
|
||||
if is_unix_socket():
|
||||
assert isinstance(self._address, Path)
|
||||
reader, writer = await asyncio.wait_for(
|
||||
asyncio.open_unix_connection(str(self._address)), timeout=timeout
|
||||
)
|
||||
else:
|
||||
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
|
||||
reader, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(host, port), timeout=timeout
|
||||
)
|
||||
|
||||
try:
|
||||
writer.write(encode_message(msg))
|
||||
await writer.drain()
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=timeout)
|
||||
return decode_message(line)
|
||||
finally:
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def send_command(
|
||||
address: Path | tuple[str, int],
|
||||
msg: IpcMessage,
|
||||
timeout: float = 5.0,
|
||||
) -> IpcResponse:
|
||||
"""Synchronous wrapper around IpcClient.send() for CLI callers."""
|
||||
return asyncio.run(IpcClient(address).send(msg, timeout=timeout))
|
||||
|
||||
|
||||
# ── Peer UID detection ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None:
|
||||
"""Return the connecting peer's UID, or None if unavailable."""
|
||||
try:
|
||||
sock = writer.get_extra_info("socket")
|
||||
if sock is None:
|
||||
return None
|
||||
|
||||
if sys.platform == "linux":
|
||||
# SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; }
|
||||
SO_PEERCRED = 17
|
||||
cred = sock.getsockopt(
|
||||
socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i")
|
||||
)
|
||||
_pid, uid, _gid = struct.unpack("3i", cred)
|
||||
return uid
|
||||
|
||||
if sys.platform == "darwin":
|
||||
return _macos_peer_uid(sock.fileno())
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def socket_module(): # lazy import to avoid top-level import on non-Unix
|
||||
import socket
|
||||
return socket
|
||||
|
||||
|
||||
def _macos_peer_uid(fd: int) -> int | None:
|
||||
"""Use getpeereid(2) via ctypes to retrieve the peer UID on macOS."""
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
|
||||
libc_name = ctypes.util.find_library("c")
|
||||
if not libc_name:
|
||||
return None
|
||||
libc = ctypes.CDLL(libc_name)
|
||||
|
||||
euid = ctypes.c_uint32(0)
|
||||
egid = ctypes.c_uint32(0)
|
||||
if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0:
|
||||
return None
|
||||
return euid.value
|
||||
@@ -0,0 +1,94 @@
|
||||
"""PID file management for the Pyra daemon."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class PidFileError(OSError):
|
||||
"""Raised when a PID file operation fails due to a live conflicting process."""
|
||||
|
||||
|
||||
class PidFile:
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
|
||||
def write(self) -> None:
|
||||
"""Write the current PID atomically.
|
||||
|
||||
Raises PidFileError if a non-stale PID file already exists.
|
||||
"""
|
||||
existing = self.read()
|
||||
if existing is not None and not self.is_stale():
|
||||
raise PidFileError(
|
||||
f"Daemon already running with PID {existing} "
|
||||
f"(PID file: {self._path})"
|
||||
)
|
||||
tmp = self._path.with_suffix(".pid.tmp")
|
||||
tmp.write_text(str(os.getpid()))
|
||||
tmp.replace(self._path)
|
||||
|
||||
def read(self) -> int | None:
|
||||
"""Return the PID from the file, or None if the file is absent or unreadable."""
|
||||
try:
|
||||
return int(self._path.read_text().strip())
|
||||
except (FileNotFoundError, ValueError):
|
||||
return None
|
||||
|
||||
def is_stale(self) -> bool:
|
||||
"""True when the PID file exists but the process no longer runs."""
|
||||
pid = self.read()
|
||||
if pid is None:
|
||||
return False
|
||||
return not _process_is_alive(pid)
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Delete the PID file, ignoring FileNotFoundError."""
|
||||
try:
|
||||
self._path.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "PidFile":
|
||||
self.write()
|
||||
return self
|
||||
|
||||
def __exit__(self, *_: object) -> None:
|
||||
self.remove()
|
||||
|
||||
|
||||
def resolve_pid_path(cfg_path: str) -> Path:
|
||||
"""Expand ~ and return an absolute Path."""
|
||||
return Path(cfg_path).expanduser().resolve()
|
||||
|
||||
|
||||
# ── Platform-specific process liveness check ─────────────────────────────────
|
||||
|
||||
def _process_is_alive(pid: int) -> bool:
|
||||
if sys.platform == "win32":
|
||||
return _win_process_is_alive(pid)
|
||||
return _posix_process_is_alive(pid)
|
||||
|
||||
|
||||
def _posix_process_is_alive(pid: int) -> bool:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
except PermissionError:
|
||||
# Process exists but is owned by another user — still alive.
|
||||
return True
|
||||
|
||||
|
||||
def _win_process_is_alive(pid: int) -> bool:
|
||||
import ctypes
|
||||
|
||||
SYNCHRONIZE = 0x00100000
|
||||
handle = ctypes.windll.kernel32.OpenProcess(SYNCHRONIZE, False, pid) # type: ignore[attr-defined]
|
||||
if handle == 0:
|
||||
return False
|
||||
ctypes.windll.kernel32.CloseHandle(handle) # type: ignore[attr-defined]
|
||||
return True
|
||||
@@ -0,0 +1,212 @@
|
||||
"""OS-specific service file generation and install/uninstall for the Pyra daemon."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pyra.utils.paths import safe_chmod
|
||||
|
||||
|
||||
def detect_platform() -> Literal["macos", "linux", "windows"]:
|
||||
s = platform.system()
|
||||
if s == "Darwin":
|
||||
return "macos"
|
||||
if s == "Linux":
|
||||
return "linux"
|
||||
if s == "Windows":
|
||||
return "windows"
|
||||
raise RuntimeError(f"Unsupported platform: {s}")
|
||||
|
||||
|
||||
def find_pyra_executable() -> str:
|
||||
"""Return the full path to the active pyra executable.
|
||||
|
||||
Tries, in order:
|
||||
1. shutil.which("pyra") — works when pyra is on PATH (activated venv)
|
||||
2. sys.executable's sibling "pyra" script — covers editable installs
|
||||
3. Fallback: sys.executable -m pyra
|
||||
"""
|
||||
found = shutil.which("pyra")
|
||||
if found:
|
||||
return found
|
||||
|
||||
sibling = Path(sys.executable).parent / "pyra"
|
||||
if sibling.exists():
|
||||
return str(sibling)
|
||||
|
||||
return f"{sys.executable} -m pyra"
|
||||
|
||||
|
||||
# ── Template generators ───────────────────────────────────────────────────────
|
||||
|
||||
def render_launchd_plist(exe: str, log_file: str, pid_file: str) -> str:
|
||||
log = str(Path(log_file).expanduser())
|
||||
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
|
||||
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.pyra.daemon</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{exe}</string>
|
||||
<string>daemon</string>
|
||||
<string>run</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>KeepAlive</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>{log}</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>{log}</string>
|
||||
<key>ProcessType</key>
|
||||
<string>Background</string>
|
||||
</dict>
|
||||
</plist>
|
||||
"""
|
||||
|
||||
|
||||
def render_systemd_unit(exe: str, log_file: str) -> str:
|
||||
log = str(Path(log_file).expanduser())
|
||||
return f"""[Unit]
|
||||
Description=Pyra Personal AI Assistant Daemon
|
||||
After=default.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={exe} daemon run
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
StandardOutput=append:{log}
|
||||
StandardError=append:{log}
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
|
||||
def render_schtasks_xml(exe: str) -> str:
|
||||
return f"""<?xml version="1.0" encoding="UTF-16"?>
|
||||
<Task version="1.2" xmlns="http://schemas.microsoft.com/windows/2004/02/mit/task">
|
||||
<RegistrationInfo>
|
||||
<Description>Pyra Personal AI Assistant Daemon</Description>
|
||||
</RegistrationInfo>
|
||||
<Triggers>
|
||||
<LogonTrigger>
|
||||
<Enabled>true</Enabled>
|
||||
</LogonTrigger>
|
||||
</Triggers>
|
||||
<Settings>
|
||||
<MultipleInstancesPolicy>IgnoreNew</MultipleInstancesPolicy>
|
||||
<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>
|
||||
<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>
|
||||
<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>
|
||||
<RestartOnFailure>
|
||||
<Interval>PT1M</Interval>
|
||||
<Count>999</Count>
|
||||
</RestartOnFailure>
|
||||
</Settings>
|
||||
<Actions Context="Author">
|
||||
<Exec>
|
||||
<Command>{exe}</Command>
|
||||
<Arguments>daemon run</Arguments>
|
||||
</Exec>
|
||||
</Actions>
|
||||
</Task>
|
||||
"""
|
||||
|
||||
|
||||
# ── Install / uninstall ───────────────────────────────────────────────────────
|
||||
|
||||
def install_service() -> None:
|
||||
"""Generate and register the OS service for the current platform."""
|
||||
from pyra.config.manager import load_config
|
||||
|
||||
cfg = load_config()
|
||||
exe = find_pyra_executable()
|
||||
plat = detect_platform()
|
||||
|
||||
if plat == "macos":
|
||||
_install_launchd(exe, cfg.daemon.log_file, cfg.daemon.pid_file)
|
||||
elif plat == "linux":
|
||||
_install_systemd(exe, cfg.daemon.log_file)
|
||||
else:
|
||||
_install_windows(exe)
|
||||
|
||||
|
||||
def uninstall_service() -> None:
|
||||
"""Deregister the OS service for the current platform."""
|
||||
plat = detect_platform()
|
||||
if plat == "macos":
|
||||
_uninstall_launchd()
|
||||
elif plat == "linux":
|
||||
_uninstall_systemd()
|
||||
else:
|
||||
_uninstall_windows()
|
||||
|
||||
|
||||
# ── macOS launchd ─────────────────────────────────────────────────────────────
|
||||
|
||||
_PLIST_PATH = Path.home() / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
|
||||
|
||||
|
||||
def _install_launchd(exe: str, log_file: str, pid_file: str) -> None:
|
||||
_PLIST_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
_PLIST_PATH.write_text(render_launchd_plist(exe, log_file, pid_file))
|
||||
safe_chmod(_PLIST_PATH, 0o644) # launchd requires 644, not 600
|
||||
subprocess.run(["launchctl", "load", str(_PLIST_PATH)], check=True)
|
||||
|
||||
|
||||
def _uninstall_launchd() -> None:
|
||||
if _PLIST_PATH.exists():
|
||||
subprocess.run(["launchctl", "unload", str(_PLIST_PATH)], check=False)
|
||||
_PLIST_PATH.unlink()
|
||||
|
||||
|
||||
# ── Linux systemd ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SYSTEMD_UNIT = Path.home() / ".config" / "systemd" / "user" / "pyra.service"
|
||||
|
||||
|
||||
def _install_systemd(exe: str, log_file: str) -> None:
|
||||
_SYSTEMD_UNIT.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SYSTEMD_UNIT.write_text(render_systemd_unit(exe, log_file))
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
subprocess.run(["systemctl", "--user", "enable", "pyra"], check=True)
|
||||
|
||||
|
||||
def _uninstall_systemd() -> None:
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", "--now", "pyra"], check=False
|
||||
)
|
||||
if _SYSTEMD_UNIT.exists():
|
||||
_SYSTEMD_UNIT.unlink()
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=False)
|
||||
|
||||
|
||||
# ── Windows Task Scheduler ────────────────────────────────────────────────────
|
||||
|
||||
def _install_windows(exe: str) -> None:
|
||||
from pyra.utils.paths import pyra_home
|
||||
|
||||
xml_path = pyra_home() / "daemon_task.xml"
|
||||
# schtasks /Create /XML requires UTF-16 encoding
|
||||
xml_path.write_text(render_schtasks_xml(exe), encoding="utf-16")
|
||||
subprocess.run(
|
||||
["schtasks", "/Create", "/TN", "PyraAssistant", "/XML", str(xml_path), "/F"],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def _uninstall_windows() -> None:
|
||||
subprocess.run(
|
||||
["schtasks", "/Delete", "/TN", "PyraAssistant", "/F"], check=False
|
||||
)
|
||||
@@ -75,6 +75,29 @@ class PluginRegistry:
|
||||
pass
|
||||
return tasks
|
||||
|
||||
def get_daemon_task_factories(
|
||||
self,
|
||||
) -> list[tuple[str, Callable[[], Coroutine]]]: # type: ignore[type-arg]
|
||||
"""Return (name, factory) pairs for all plugin daemon tasks.
|
||||
|
||||
Each factory re-calls plugin.daemon_tasks() to produce a fresh coroutine,
|
||||
enabling the supervisor to restart crashed tasks without changing the plugin
|
||||
protocol.
|
||||
"""
|
||||
factories: list[tuple[str, Callable[[], Coroutine]]] = [] # type: ignore[type-arg]
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
n_tasks = len(plugin.daemon_tasks())
|
||||
except Exception:
|
||||
continue
|
||||
for i in range(n_tasks):
|
||||
name = f"{plugin.name}.task_{i}"
|
||||
# Capture plugin and index by value so each closure is independent.
|
||||
def _factory(p=plugin, idx=i) -> Coroutine: # type: ignore[type-arg]
|
||||
return p.daemon_tasks()[idx]
|
||||
factories.append((name, _factory))
|
||||
return factories
|
||||
|
||||
def find_tool(self, name: str) -> Tool | None:
|
||||
return self._tools.get(name)
|
||||
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Unit tests for the IPC layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyra.daemon.ipc import (
|
||||
IpcClient,
|
||||
IpcMessage,
|
||||
IpcResponse,
|
||||
IpcServer,
|
||||
decode_message,
|
||||
encode_message,
|
||||
is_unix_socket,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sock_path():
|
||||
"""Short socket path that fits within macOS's 104-char AF_UNIX limit."""
|
||||
with tempfile.TemporaryDirectory(dir="/tmp") as d:
|
||||
yield Path(d) / "t.sock"
|
||||
|
||||
|
||||
# ── Protocol encode / decode ──────────────────────────────────────────────────
|
||||
|
||||
def test_encode_appends_newline() -> None:
|
||||
data = encode_message({"cmd": "ping"})
|
||||
assert data.endswith(b"\n")
|
||||
|
||||
|
||||
def test_encode_is_valid_json() -> None:
|
||||
import json
|
||||
data = encode_message({"cmd": "status", "extra": 42})
|
||||
assert json.loads(data) == {"cmd": "status", "extra": 42}
|
||||
|
||||
|
||||
def test_decode_roundtrip() -> None:
|
||||
msg: IpcMessage = {"cmd": "stop"}
|
||||
assert decode_message(encode_message(msg)) == msg
|
||||
|
||||
|
||||
def test_decode_strips_newline() -> None:
|
||||
assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop"
|
||||
|
||||
|
||||
def test_decode_raises_on_bad_json() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid IPC message"):
|
||||
decode_message(b"not json\n")
|
||||
|
||||
|
||||
def test_decode_raises_on_empty_line() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
decode_message(b"\n")
|
||||
|
||||
|
||||
# ── is_unix_socket ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_is_unix_socket_matches_platform() -> None:
|
||||
if sys.platform == "win32":
|
||||
assert not is_unix_socket()
|
||||
else:
|
||||
assert is_unix_socket()
|
||||
|
||||
|
||||
# ── Server + client roundtrip (Unix only) ─────────────────────────────────────
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||
async def test_server_client_ping(sock_path: Path) -> None:
|
||||
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||
return {"ok": True, "data": {"pong": True}}
|
||||
|
||||
server = IpcServer(sock_path, handler)
|
||||
await server.start()
|
||||
try:
|
||||
resp = await IpcClient(sock_path).send({"cmd": "ping"})
|
||||
assert resp["ok"] is True
|
||||
assert resp["data"]["pong"] is True
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||
async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None:
|
||||
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||
return {"ok": True, "data": {}}
|
||||
|
||||
server = IpcServer(sock_path, handler)
|
||||
await server.start()
|
||||
try:
|
||||
reader, writer = await asyncio.open_unix_connection(str(sock_path))
|
||||
writer.write(b"not valid json\n")
|
||||
await writer.drain()
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=3.0)
|
||||
resp = decode_message(line)
|
||||
assert resp["ok"] is False
|
||||
assert "error" in resp["data"]
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
await server.stop()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||
async def test_handler_response_returned_to_client(sock_path: Path) -> None:
|
||||
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||
if msg.get("cmd") == "status":
|
||||
return {"ok": True, "data": {"uptime": 99.0}}
|
||||
return {"ok": False, "data": {"error": "unknown"}}
|
||||
|
||||
server = IpcServer(sock_path, handler)
|
||||
await server.start()
|
||||
try:
|
||||
resp = await IpcClient(sock_path).send({"cmd": "status"})
|
||||
assert resp["ok"] is True
|
||||
assert resp["data"]["uptime"] == 99.0
|
||||
|
||||
resp2 = await IpcClient(sock_path).send({"cmd": "bogus"})
|
||||
assert resp2["ok"] is False
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||
async def test_client_raises_when_no_server(sock_path: Path) -> None:
|
||||
client = IpcClient(sock_path)
|
||||
with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)):
|
||||
await client.send({"cmd": "ping"})
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||
async def test_socket_file_chmod_600(sock_path: Path) -> None:
|
||||
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||
return {"ok": True, "data": {}}
|
||||
|
||||
server = IpcServer(sock_path, handler)
|
||||
await server.start()
|
||||
try:
|
||||
mode = oct(sock_path.stat().st_mode & 0o777)
|
||||
assert mode == oct(0o600), f"Expected 0o600, got {mode}"
|
||||
finally:
|
||||
await server.stop()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||
async def test_stop_removes_socket_file(sock_path: Path) -> None:
|
||||
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||
return {"ok": True, "data": {}}
|
||||
|
||||
server = IpcServer(sock_path, handler)
|
||||
await server.start()
|
||||
assert sock_path.exists()
|
||||
await server.stop()
|
||||
assert not sock_path.exists()
|
||||
@@ -0,0 +1,103 @@
|
||||
"""Unit tests for daemon PID file management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyra.daemon.pid import PidFile, PidFileError, resolve_pid_path
|
||||
|
||||
|
||||
def test_write_creates_file(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
p.write()
|
||||
assert (tmp_path / "daemon.pid").exists()
|
||||
assert int((tmp_path / "daemon.pid").read_text().strip()) == os.getpid()
|
||||
|
||||
|
||||
def test_read_returns_none_when_absent(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
assert p.read() is None
|
||||
|
||||
|
||||
def test_read_returns_pid_when_present(tmp_path: Path) -> None:
|
||||
pid_file = tmp_path / "daemon.pid"
|
||||
pid_file.write_text("12345")
|
||||
p = PidFile(pid_file)
|
||||
assert p.read() == 12345
|
||||
|
||||
|
||||
def test_read_returns_none_on_bad_content(tmp_path: Path) -> None:
|
||||
pid_file = tmp_path / "daemon.pid"
|
||||
pid_file.write_text("not-a-number")
|
||||
p = PidFile(pid_file)
|
||||
assert p.read() is None
|
||||
|
||||
|
||||
def test_is_stale_false_for_self(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
p.write()
|
||||
assert not p.is_stale()
|
||||
|
||||
|
||||
def test_is_stale_true_for_dead_pid(tmp_path: Path) -> None:
|
||||
pid_file = tmp_path / "daemon.pid"
|
||||
pid_file.write_text("999999999") # unrealistically large PID
|
||||
p = PidFile(pid_file)
|
||||
assert p.is_stale()
|
||||
|
||||
|
||||
def test_is_stale_false_when_file_absent(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
assert not p.is_stale()
|
||||
|
||||
|
||||
def test_remove_deletes_file(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
p.write()
|
||||
p.remove()
|
||||
assert not (tmp_path / "daemon.pid").exists()
|
||||
|
||||
|
||||
def test_remove_is_idempotent(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
p.remove() # must not raise
|
||||
|
||||
|
||||
def test_context_manager_writes_and_removes(tmp_path: Path) -> None:
|
||||
pid_file = tmp_path / "daemon.pid"
|
||||
p = PidFile(pid_file)
|
||||
with p:
|
||||
assert pid_file.exists()
|
||||
assert int(pid_file.read_text().strip()) == os.getpid()
|
||||
assert not pid_file.exists()
|
||||
|
||||
|
||||
def test_write_raises_when_live_pid_exists(tmp_path: Path) -> None:
|
||||
p = PidFile(tmp_path / "daemon.pid")
|
||||
p.write() # writes self PID (which is alive)
|
||||
p2 = PidFile(tmp_path / "daemon.pid")
|
||||
with pytest.raises(PidFileError):
|
||||
p2.write()
|
||||
|
||||
|
||||
def test_write_succeeds_over_stale_pid(tmp_path: Path) -> None:
|
||||
pid_file = tmp_path / "daemon.pid"
|
||||
pid_file.write_text("999999999") # stale
|
||||
p = PidFile(pid_file)
|
||||
p.write() # should not raise
|
||||
assert int(pid_file.read_text().strip()) == os.getpid()
|
||||
|
||||
|
||||
def test_resolve_pid_path_expands_tilde() -> None:
|
||||
result = resolve_pid_path("~/.pyra/daemon.pid")
|
||||
assert not str(result).startswith("~")
|
||||
assert result.is_absolute()
|
||||
|
||||
|
||||
def test_resolve_pid_path_absolute_unchanged(tmp_path: Path) -> None:
|
||||
path = tmp_path / "daemon.pid"
|
||||
result = resolve_pid_path(str(path))
|
||||
assert result == path
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Unit tests for daemon service file generation and platform detection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pyra.daemon.service import (
|
||||
detect_platform,
|
||||
find_pyra_executable,
|
||||
render_launchd_plist,
|
||||
render_systemd_unit,
|
||||
render_schtasks_xml,
|
||||
)
|
||||
|
||||
|
||||
# ── Template rendering ────────────────────────────────────────────────────────
|
||||
|
||||
def test_render_launchd_plist_contains_exe() -> None:
|
||||
xml = render_launchd_plist("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
|
||||
assert "/usr/local/bin/pyra" in xml
|
||||
assert "<string>daemon</string>" in xml
|
||||
assert "<string>run</string>" in xml
|
||||
assert "com.pyra.daemon" in xml
|
||||
assert "<true/>" in xml # KeepAlive and RunAtLoad
|
||||
|
||||
|
||||
def test_render_launchd_plist_expands_log_tilde() -> None:
|
||||
xml = render_launchd_plist("/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
|
||||
assert "~" not in xml
|
||||
|
||||
|
||||
def test_render_systemd_unit_contains_exe() -> None:
|
||||
unit = render_systemd_unit("/usr/local/bin/pyra", "~/.pyra/daemon.log")
|
||||
assert "ExecStart=/usr/local/bin/pyra daemon run" in unit
|
||||
assert "Restart=on-failure" in unit
|
||||
assert "Type=simple" in unit
|
||||
assert "WantedBy=default.target" in unit
|
||||
|
||||
|
||||
def test_render_systemd_unit_expands_log_tilde() -> None:
|
||||
unit = render_systemd_unit("/bin/pyra", "~/.pyra/daemon.log")
|
||||
assert "~" not in unit
|
||||
|
||||
|
||||
def test_render_schtasks_xml_contains_exe() -> None:
|
||||
xml = render_schtasks_xml("C:\\Users\\test\\pyra.exe")
|
||||
assert "C:\\Users\\test\\pyra.exe" in xml
|
||||
assert "LogonTrigger" in xml
|
||||
assert "daemon run" in xml
|
||||
assert "IgnoreNew" in xml
|
||||
|
||||
|
||||
def test_render_schtasks_xml_no_time_limit() -> None:
|
||||
xml = render_schtasks_xml("pyra.exe")
|
||||
assert "PT0S" in xml # ExecutionTimeLimit=PT0S means unlimited
|
||||
|
||||
|
||||
# ── Platform detection ────────────────────────────────────────────────────────
|
||||
|
||||
def test_detect_platform_returns_known_value() -> None:
|
||||
result = detect_platform()
|
||||
assert result in ("macos", "linux", "windows")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("system,expected", [
|
||||
("Darwin", "macos"),
|
||||
("Linux", "linux"),
|
||||
("Windows", "windows"),
|
||||
])
|
||||
def test_detect_platform_maps_correctly(system: str, expected: str) -> None:
|
||||
with patch("platform.system", return_value=system):
|
||||
assert detect_platform() == expected
|
||||
|
||||
|
||||
def test_detect_platform_raises_on_unknown() -> None:
|
||||
with patch("platform.system", return_value="FreeBSD"):
|
||||
with pytest.raises(RuntimeError, match="Unsupported platform"):
|
||||
detect_platform()
|
||||
|
||||
|
||||
# ── Executable detection ──────────────────────────────────────────────────────
|
||||
|
||||
def test_find_pyra_executable_returns_string() -> None:
|
||||
result = find_pyra_executable()
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_find_pyra_executable_uses_which_when_available(tmp_path: Path) -> None:
|
||||
fake_pyra = tmp_path / "pyra"
|
||||
fake_pyra.touch()
|
||||
with patch("shutil.which", return_value=str(fake_pyra)):
|
||||
assert find_pyra_executable() == str(fake_pyra)
|
||||
|
||||
|
||||
def test_find_pyra_executable_falls_back_to_sibling(tmp_path: Path) -> None:
|
||||
fake_python = tmp_path / "python3"
|
||||
fake_pyra = tmp_path / "pyra"
|
||||
fake_pyra.touch()
|
||||
with patch("shutil.which", return_value=None):
|
||||
with patch("sys.executable", str(fake_python)):
|
||||
assert find_pyra_executable() == str(fake_pyra)
|
||||
|
||||
|
||||
def test_find_pyra_executable_falls_back_to_module(tmp_path: Path) -> None:
|
||||
fake_python = tmp_path / "python3"
|
||||
with patch("shutil.which", return_value=None):
|
||||
with patch("sys.executable", str(fake_python)):
|
||||
result = find_pyra_executable()
|
||||
assert result == f"{fake_python} -m pyra"
|
||||
|
||||
|
||||
# ── Install / uninstall (subprocess mocked) ───────────────────────────────────
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="launchd install is macOS-only")
|
||||
def test_install_launchd_writes_plist_and_calls_launchctl(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
import pyra.daemon.service as svc
|
||||
|
||||
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
|
||||
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
|
||||
|
||||
calls: list[list[str]] = []
|
||||
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
|
||||
|
||||
svc._install_launchd("/usr/local/bin/pyra", "~/.pyra/daemon.log", "~/.pyra/daemon.pid")
|
||||
|
||||
assert plist_path.exists()
|
||||
assert "com.pyra.daemon" in plist_path.read_text()
|
||||
assert any("launchctl" in c[0] for c in calls)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="systemd install is Linux-only")
|
||||
def test_install_systemd_writes_unit_and_calls_systemctl(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
import pyra.daemon.service as svc
|
||||
|
||||
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
|
||||
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
|
||||
|
||||
calls: list[list[str]] = []
|
||||
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: calls.append(cmd))
|
||||
|
||||
svc._install_systemd("/usr/local/bin/pyra", "~/.pyra/daemon.log")
|
||||
|
||||
assert unit_path.exists()
|
||||
assert "ExecStart" in unit_path.read_text()
|
||||
assert any("systemctl" in c[0] for c in calls)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="launchd uninstall is macOS-only")
|
||||
def test_uninstall_launchd_removes_plist(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
import pyra.daemon.service as svc
|
||||
|
||||
plist_path = tmp_path / "Library" / "LaunchAgents" / "com.pyra.daemon.plist"
|
||||
plist_path.parent.mkdir(parents=True)
|
||||
plist_path.write_text("<plist/>")
|
||||
monkeypatch.setattr(svc, "_PLIST_PATH", plist_path)
|
||||
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
|
||||
|
||||
svc._uninstall_launchd()
|
||||
|
||||
assert not plist_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="systemd uninstall is Linux-only")
|
||||
def test_uninstall_systemd_removes_unit(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
import pyra.daemon.service as svc
|
||||
|
||||
unit_path = tmp_path / ".config" / "systemd" / "user" / "pyra.service"
|
||||
unit_path.parent.mkdir(parents=True)
|
||||
unit_path.write_text("[Service]")
|
||||
monkeypatch.setattr(svc, "_SYSTEMD_UNIT", unit_path)
|
||||
monkeypatch.setattr(subprocess, "run", lambda cmd, **kw: None)
|
||||
|
||||
svc._uninstall_systemd()
|
||||
|
||||
assert not unit_path.exists()
|
||||
Reference in New Issue
Block a user