From d42b8b4a4712d73f4f228a3b9e7ee2fe9da2a401 Mon Sep 17 00:00:00 2001 From: curo1305 Date: Tue, 19 May 2026 15:26:25 +0200 Subject: [PATCH] feat(daemon): add IPC transport module Newline-delimited JSON over Unix socket (macOS/Linux, chmod 600, UID-checked via SO_PEERCRED/getpeereid) with TCP loopback fallback on Windows. Port written to ~/.pyra/daemon.port for Windows clients. Sync send_command() wrapper for CLI. Co-Authored-By: Claude Sonnet 4.6 --- src/pyra/daemon/ipc.py | 241 ++++++++++++++++++++++++++++++++++ tests/unit/test_daemon_ipc.py | 162 +++++++++++++++++++++++ 2 files changed, 403 insertions(+) create mode 100644 src/pyra/daemon/ipc.py create mode 100644 tests/unit/test_daemon_ipc.py diff --git a/src/pyra/daemon/ipc.py b/src/pyra/daemon/ipc.py new file mode 100644 index 0000000..8c14bd2 --- /dev/null +++ b/src/pyra/daemon/ipc.py @@ -0,0 +1,241 @@ +"""IPC transport for the Pyra daemon. + +Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked). +Windows: TCP loopback on an OS-assigned port; actual port written to + ~/.pyra/daemon.port so clients can connect without knowing it ahead + of time. +""" + +from __future__ import annotations + +import asyncio +import json +import os +import struct +import sys +from pathlib import Path +from typing import Any, Awaitable, Callable + + +# ── Protocol types ──────────────────────────────────────────────────────────── + +IpcMessage = dict[str, Any] # must have "cmd" key +IpcResponse = dict[str, Any] # must have "ok" and "data" keys + + +# ── Encode / decode ─────────────────────────────────────────────────────────── + +def encode_message(msg: IpcMessage) -> bytes: + return (json.dumps(msg) + "\n").encode() + + +def decode_message(line: bytes) -> IpcMessage: + try: + return json.loads(line.rstrip(b"\n")) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid IPC message: {exc}") from exc + + +# ── Address helpers ─────────────────────────────────────────────────────────── + +def is_unix_socket() -> bool: + return sys.platform != "win32" + + +def get_socket_path(cfg_socket_path: str) -> Path: + """Expand ~ and return the Unix socket path.""" + return Path(cfg_socket_path).expanduser() + + +def get_port_file_path() -> Path: + from pyra.utils.paths import pyra_home + return pyra_home() / "daemon.port" + + +def _read_windows_port() -> int | None: + port_file = get_port_file_path() + try: + return int(port_file.read_text().strip()) + except (FileNotFoundError, ValueError): + return None + + +# ── Server ──────────────────────────────────────────────────────────────────── + +class IpcServer: + def __init__( + self, + address: Path | tuple[str, int], + handler: Callable[[IpcMessage], Awaitable[IpcResponse]], + ) -> None: + self._address = address + self._handler = handler + self._server: asyncio.Server | None = None + + async def start(self) -> None: + if is_unix_socket(): + assert isinstance(self._address, Path) + sock_path = self._address + if sock_path.exists(): + sock_path.unlink() + self._server = await asyncio.start_unix_server( + self._handle_client, path=str(sock_path) + ) + os.chmod(sock_path, 0o600) + else: + host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0) + self._server = await asyncio.start_server( + self._handle_client, host=host, port=port + ) + actual_port = self._server.sockets[0].getsockname()[1] + port_file = get_port_file_path() + port_file.write_text(str(actual_port)) + + await self._server.start_serving() + + async def stop(self) -> None: + if self._server is not None: + self._server.close() + try: + await asyncio.wait_for(self._server.wait_closed(), timeout=5.0) + except asyncio.TimeoutError: + pass + if is_unix_socket() and isinstance(self._address, Path): + try: + self._address.unlink() + except FileNotFoundError: + pass + else: + port_file = get_port_file_path() + try: + port_file.unlink() + except FileNotFoundError: + pass + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + if is_unix_socket() and not self._check_peer_uid(writer): + writer.close() + return + + line = await asyncio.wait_for(reader.readline(), timeout=5.0) + if not line: + return + + try: + msg = decode_message(line) + except ValueError: + resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}} + else: + resp = await self._handler(msg) + + writer.write(encode_message(resp)) + await writer.drain() + except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError): + pass + finally: + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool: + """Return True if the peer's UID matches ours. Falls back to True on error.""" + try: + peer_uid = _get_peer_uid(writer) + if peer_uid is None: + return True # can't determine — allow (socket perms are the guard) + return peer_uid == os.getuid() + except Exception: + return True # don't crash the server on unexpected errors + + +# ── Client ──────────────────────────────────────────────────────────────────── + +class IpcClient: + def __init__(self, address: Path | tuple[str, int]) -> None: + self._address = address + + async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse: + if is_unix_socket(): + assert isinstance(self._address, Path) + reader, writer = await asyncio.wait_for( + asyncio.open_unix_connection(str(self._address)), timeout=timeout + ) + else: + host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0) + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port), timeout=timeout + ) + + try: + writer.write(encode_message(msg)) + await writer.drain() + line = await asyncio.wait_for(reader.readline(), timeout=timeout) + return decode_message(line) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + + +def send_command( + address: Path | tuple[str, int], + msg: IpcMessage, + timeout: float = 5.0, +) -> IpcResponse: + """Synchronous wrapper around IpcClient.send() for CLI callers.""" + return asyncio.run(IpcClient(address).send(msg, timeout=timeout)) + + +# ── Peer UID detection ──────────────────────────────────────────────────────── + +def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None: + """Return the connecting peer's UID, or None if unavailable.""" + try: + sock = writer.get_extra_info("socket") + if sock is None: + return None + + if sys.platform == "linux": + # SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; } + SO_PEERCRED = 17 + cred = sock.getsockopt( + socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i") + ) + _pid, uid, _gid = struct.unpack("3i", cred) + return uid + + if sys.platform == "darwin": + return _macos_peer_uid(sock.fileno()) + + except Exception: + pass + return None + + +def socket_module(): # lazy import to avoid top-level import on non-Unix + import socket + return socket + + +def _macos_peer_uid(fd: int) -> int | None: + """Use getpeereid(2) via ctypes to retrieve the peer UID on macOS.""" + import ctypes + import ctypes.util + + libc_name = ctypes.util.find_library("c") + if not libc_name: + return None + libc = ctypes.CDLL(libc_name) + + euid = ctypes.c_uint32(0) + egid = ctypes.c_uint32(0) + if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0: + return None + return euid.value diff --git a/tests/unit/test_daemon_ipc.py b/tests/unit/test_daemon_ipc.py new file mode 100644 index 0000000..4ed466f --- /dev/null +++ b/tests/unit/test_daemon_ipc.py @@ -0,0 +1,162 @@ +"""Unit tests for the IPC layer.""" + +from __future__ import annotations + +import asyncio +import os +import sys +import tempfile +from pathlib import Path + +import pytest + +from pyra.daemon.ipc import ( + IpcClient, + IpcMessage, + IpcResponse, + IpcServer, + decode_message, + encode_message, + is_unix_socket, +) + + +@pytest.fixture +def sock_path(): + """Short socket path that fits within macOS's 104-char AF_UNIX limit.""" + with tempfile.TemporaryDirectory(dir="/tmp") as d: + yield Path(d) / "t.sock" + + +# ── Protocol encode / decode ────────────────────────────────────────────────── + +def test_encode_appends_newline() -> None: + data = encode_message({"cmd": "ping"}) + assert data.endswith(b"\n") + + +def test_encode_is_valid_json() -> None: + import json + data = encode_message({"cmd": "status", "extra": 42}) + assert json.loads(data) == {"cmd": "status", "extra": 42} + + +def test_decode_roundtrip() -> None: + msg: IpcMessage = {"cmd": "stop"} + assert decode_message(encode_message(msg)) == msg + + +def test_decode_strips_newline() -> None: + assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop" + + +def test_decode_raises_on_bad_json() -> None: + with pytest.raises(ValueError, match="Invalid IPC message"): + decode_message(b"not json\n") + + +def test_decode_raises_on_empty_line() -> None: + with pytest.raises(ValueError): + decode_message(b"\n") + + +# ── is_unix_socket ──────────────────────────────────────────────────────────── + +def test_is_unix_socket_matches_platform() -> None: + if sys.platform == "win32": + assert not is_unix_socket() + else: + assert is_unix_socket() + + +# ── Server + client roundtrip (Unix only) ───────────────────────────────────── + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_server_client_ping(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {"pong": True}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + resp = await IpcClient(sock_path).send({"cmd": "ping"}) + assert resp["ok"] is True + assert resp["data"]["pong"] is True + finally: + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + reader, writer = await asyncio.open_unix_connection(str(sock_path)) + writer.write(b"not valid json\n") + await writer.drain() + line = await asyncio.wait_for(reader.readline(), timeout=3.0) + resp = decode_message(line) + assert resp["ok"] is False + assert "error" in resp["data"] + finally: + try: + writer.close() + except Exception: + pass + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_handler_response_returned_to_client(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + if msg.get("cmd") == "status": + return {"ok": True, "data": {"uptime": 99.0}} + return {"ok": False, "data": {"error": "unknown"}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + resp = await IpcClient(sock_path).send({"cmd": "status"}) + assert resp["ok"] is True + assert resp["data"]["uptime"] == 99.0 + + resp2 = await IpcClient(sock_path).send({"cmd": "bogus"}) + assert resp2["ok"] is False + finally: + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_client_raises_when_no_server(sock_path: Path) -> None: + client = IpcClient(sock_path) + with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)): + await client.send({"cmd": "ping"}) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_socket_file_chmod_600(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {}} + + server = IpcServer(sock_path, handler) + await server.start() + try: + mode = oct(sock_path.stat().st_mode & 0o777) + assert mode == oct(0o600), f"Expected 0o600, got {mode}" + finally: + await server.stop() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test") +async def test_stop_removes_socket_file(sock_path: Path) -> None: + async def handler(msg: IpcMessage) -> IpcResponse: + return {"ok": True, "data": {}} + + server = IpcServer(sock_path, handler) + await server.start() + assert sock_path.exists() + await server.stop() + assert not sock_path.exists()