feat(daemon): add IPC transport module
Newline-delimited JSON over Unix socket (macOS/Linux, chmod 600, UID-checked via SO_PEERCRED/getpeereid) with TCP loopback fallback on Windows. Port written to ~/.pyra/daemon.port for Windows clients. Sync send_command() wrapper for CLI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,241 @@
|
|||||||
|
"""IPC transport for the Pyra daemon.
|
||||||
|
|
||||||
|
Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked).
|
||||||
|
Windows: TCP loopback on an OS-assigned port; actual port written to
|
||||||
|
~/.pyra/daemon.port so clients can connect without knowing it ahead
|
||||||
|
of time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
|
||||||
|
# ── Protocol types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
IpcMessage = dict[str, Any] # must have "cmd" key
|
||||||
|
IpcResponse = dict[str, Any] # must have "ok" and "data" keys
|
||||||
|
|
||||||
|
|
||||||
|
# ── Encode / decode ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def encode_message(msg: IpcMessage) -> bytes:
|
||||||
|
return (json.dumps(msg) + "\n").encode()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_message(line: bytes) -> IpcMessage:
|
||||||
|
try:
|
||||||
|
return json.loads(line.rstrip(b"\n"))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Invalid IPC message: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── Address helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def is_unix_socket() -> bool:
|
||||||
|
return sys.platform != "win32"
|
||||||
|
|
||||||
|
|
||||||
|
def get_socket_path(cfg_socket_path: str) -> Path:
|
||||||
|
"""Expand ~ and return the Unix socket path."""
|
||||||
|
return Path(cfg_socket_path).expanduser()
|
||||||
|
|
||||||
|
|
||||||
|
def get_port_file_path() -> Path:
|
||||||
|
from pyra.utils.paths import pyra_home
|
||||||
|
return pyra_home() / "daemon.port"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_windows_port() -> int | None:
|
||||||
|
port_file = get_port_file_path()
|
||||||
|
try:
|
||||||
|
return int(port_file.read_text().strip())
|
||||||
|
except (FileNotFoundError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Server ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class IpcServer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
address: Path | tuple[str, int],
|
||||||
|
handler: Callable[[IpcMessage], Awaitable[IpcResponse]],
|
||||||
|
) -> None:
|
||||||
|
self._address = address
|
||||||
|
self._handler = handler
|
||||||
|
self._server: asyncio.Server | None = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if is_unix_socket():
|
||||||
|
assert isinstance(self._address, Path)
|
||||||
|
sock_path = self._address
|
||||||
|
if sock_path.exists():
|
||||||
|
sock_path.unlink()
|
||||||
|
self._server = await asyncio.start_unix_server(
|
||||||
|
self._handle_client, path=str(sock_path)
|
||||||
|
)
|
||||||
|
os.chmod(sock_path, 0o600)
|
||||||
|
else:
|
||||||
|
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
|
||||||
|
self._server = await asyncio.start_server(
|
||||||
|
self._handle_client, host=host, port=port
|
||||||
|
)
|
||||||
|
actual_port = self._server.sockets[0].getsockname()[1]
|
||||||
|
port_file = get_port_file_path()
|
||||||
|
port_file.write_text(str(actual_port))
|
||||||
|
|
||||||
|
await self._server.start_serving()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
if self._server is not None:
|
||||||
|
self._server.close()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._server.wait_closed(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
if is_unix_socket() and isinstance(self._address, Path):
|
||||||
|
try:
|
||||||
|
self._address.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
port_file = get_port_file_path()
|
||||||
|
try:
|
||||||
|
port_file.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _handle_client(
|
||||||
|
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
if is_unix_socket() and not self._check_peer_uid(writer):
|
||||||
|
writer.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout=5.0)
|
||||||
|
if not line:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = decode_message(line)
|
||||||
|
except ValueError:
|
||||||
|
resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}}
|
||||||
|
else:
|
||||||
|
resp = await self._handler(msg)
|
||||||
|
|
||||||
|
writer.write(encode_message(resp))
|
||||||
|
await writer.drain()
|
||||||
|
except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool:
|
||||||
|
"""Return True if the peer's UID matches ours. Falls back to True on error."""
|
||||||
|
try:
|
||||||
|
peer_uid = _get_peer_uid(writer)
|
||||||
|
if peer_uid is None:
|
||||||
|
return True # can't determine — allow (socket perms are the guard)
|
||||||
|
return peer_uid == os.getuid()
|
||||||
|
except Exception:
|
||||||
|
return True # don't crash the server on unexpected errors
|
||||||
|
|
||||||
|
|
||||||
|
# ── Client ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class IpcClient:
|
||||||
|
def __init__(self, address: Path | tuple[str, int]) -> None:
|
||||||
|
self._address = address
|
||||||
|
|
||||||
|
async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse:
|
||||||
|
if is_unix_socket():
|
||||||
|
assert isinstance(self._address, Path)
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_unix_connection(str(self._address)), timeout=timeout
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0)
|
||||||
|
reader, writer = await asyncio.wait_for(
|
||||||
|
asyncio.open_connection(host, port), timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.write(encode_message(msg))
|
||||||
|
await writer.drain()
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout=timeout)
|
||||||
|
return decode_message(line)
|
||||||
|
finally:
|
||||||
|
writer.close()
|
||||||
|
try:
|
||||||
|
await writer.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def send_command(
|
||||||
|
address: Path | tuple[str, int],
|
||||||
|
msg: IpcMessage,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
) -> IpcResponse:
|
||||||
|
"""Synchronous wrapper around IpcClient.send() for CLI callers."""
|
||||||
|
return asyncio.run(IpcClient(address).send(msg, timeout=timeout))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Peer UID detection ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None:
|
||||||
|
"""Return the connecting peer's UID, or None if unavailable."""
|
||||||
|
try:
|
||||||
|
sock = writer.get_extra_info("socket")
|
||||||
|
if sock is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sys.platform == "linux":
|
||||||
|
# SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; }
|
||||||
|
SO_PEERCRED = 17
|
||||||
|
cred = sock.getsockopt(
|
||||||
|
socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i")
|
||||||
|
)
|
||||||
|
_pid, uid, _gid = struct.unpack("3i", cred)
|
||||||
|
return uid
|
||||||
|
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
return _macos_peer_uid(sock.fileno())
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def socket_module(): # lazy import to avoid top-level import on non-Unix
|
||||||
|
import socket
|
||||||
|
return socket
|
||||||
|
|
||||||
|
|
||||||
|
def _macos_peer_uid(fd: int) -> int | None:
|
||||||
|
"""Use getpeereid(2) via ctypes to retrieve the peer UID on macOS."""
|
||||||
|
import ctypes
|
||||||
|
import ctypes.util
|
||||||
|
|
||||||
|
libc_name = ctypes.util.find_library("c")
|
||||||
|
if not libc_name:
|
||||||
|
return None
|
||||||
|
libc = ctypes.CDLL(libc_name)
|
||||||
|
|
||||||
|
euid = ctypes.c_uint32(0)
|
||||||
|
egid = ctypes.c_uint32(0)
|
||||||
|
if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0:
|
||||||
|
return None
|
||||||
|
return euid.value
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
"""Unit tests for the IPC layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pyra.daemon.ipc import (
|
||||||
|
IpcClient,
|
||||||
|
IpcMessage,
|
||||||
|
IpcResponse,
|
||||||
|
IpcServer,
|
||||||
|
decode_message,
|
||||||
|
encode_message,
|
||||||
|
is_unix_socket,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sock_path():
|
||||||
|
"""Short socket path that fits within macOS's 104-char AF_UNIX limit."""
|
||||||
|
with tempfile.TemporaryDirectory(dir="/tmp") as d:
|
||||||
|
yield Path(d) / "t.sock"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Protocol encode / decode ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_encode_appends_newline() -> None:
|
||||||
|
data = encode_message({"cmd": "ping"})
|
||||||
|
assert data.endswith(b"\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_is_valid_json() -> None:
|
||||||
|
import json
|
||||||
|
data = encode_message({"cmd": "status", "extra": 42})
|
||||||
|
assert json.loads(data) == {"cmd": "status", "extra": 42}
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_roundtrip() -> None:
|
||||||
|
msg: IpcMessage = {"cmd": "stop"}
|
||||||
|
assert decode_message(encode_message(msg)) == msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_strips_newline() -> None:
|
||||||
|
assert decode_message(b'{"cmd": "stop"}\n')["cmd"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_raises_on_bad_json() -> None:
|
||||||
|
with pytest.raises(ValueError, match="Invalid IPC message"):
|
||||||
|
decode_message(b"not json\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_raises_on_empty_line() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
decode_message(b"\n")
|
||||||
|
|
||||||
|
|
||||||
|
# ── is_unix_socket ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_is_unix_socket_matches_platform() -> None:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
assert not is_unix_socket()
|
||||||
|
else:
|
||||||
|
assert is_unix_socket()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Server + client roundtrip (Unix only) ─────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_server_client_ping(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {"pong": True}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
resp = await IpcClient(sock_path).send({"cmd": "ping"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert resp["data"]["pong"] is True
|
||||||
|
finally:
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_server_echoes_error_for_bad_json(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
reader, writer = await asyncio.open_unix_connection(str(sock_path))
|
||||||
|
writer.write(b"not valid json\n")
|
||||||
|
await writer.drain()
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout=3.0)
|
||||||
|
resp = decode_message(line)
|
||||||
|
assert resp["ok"] is False
|
||||||
|
assert "error" in resp["data"]
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_handler_response_returned_to_client(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
if msg.get("cmd") == "status":
|
||||||
|
return {"ok": True, "data": {"uptime": 99.0}}
|
||||||
|
return {"ok": False, "data": {"error": "unknown"}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
resp = await IpcClient(sock_path).send({"cmd": "status"})
|
||||||
|
assert resp["ok"] is True
|
||||||
|
assert resp["data"]["uptime"] == 99.0
|
||||||
|
|
||||||
|
resp2 = await IpcClient(sock_path).send({"cmd": "bogus"})
|
||||||
|
assert resp2["ok"] is False
|
||||||
|
finally:
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_client_raises_when_no_server(sock_path: Path) -> None:
|
||||||
|
client = IpcClient(sock_path)
|
||||||
|
with pytest.raises((ConnectionRefusedError, FileNotFoundError, OSError)):
|
||||||
|
await client.send({"cmd": "ping"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_socket_file_chmod_600(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
try:
|
||||||
|
mode = oct(sock_path.stat().st_mode & 0o777)
|
||||||
|
assert mode == oct(0o600), f"Expected 0o600, got {mode}"
|
||||||
|
finally:
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="Unix socket test")
|
||||||
|
async def test_stop_removes_socket_file(sock_path: Path) -> None:
|
||||||
|
async def handler(msg: IpcMessage) -> IpcResponse:
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
server = IpcServer(sock_path, handler)
|
||||||
|
await server.start()
|
||||||
|
assert sock_path.exists()
|
||||||
|
await server.stop()
|
||||||
|
assert not sock_path.exists()
|
||||||
Reference in New Issue
Block a user