"""IPC transport for the Pyra daemon. Linux/macOS: Unix domain socket at ~/.pyra/daemon.sock (chmod 600, UID-checked). Windows: TCP loopback on an OS-assigned port; actual port written to ~/.pyra/daemon.port so clients can connect without knowing it ahead of time. """ from __future__ import annotations import asyncio import json import os import struct import sys from pathlib import Path from typing import Any, Awaitable, Callable # ── Protocol types ──────────────────────────────────────────────────────────── IpcMessage = dict[str, Any] # must have "cmd" key IpcResponse = dict[str, Any] # must have "ok" and "data" keys # ── Encode / decode ─────────────────────────────────────────────────────────── def encode_message(msg: IpcMessage) -> bytes: return (json.dumps(msg) + "\n").encode() def decode_message(line: bytes) -> IpcMessage: try: return json.loads(line.rstrip(b"\n")) except json.JSONDecodeError as exc: raise ValueError(f"Invalid IPC message: {exc}") from exc # ── Address helpers ─────────────────────────────────────────────────────────── def is_unix_socket() -> bool: return sys.platform != "win32" def get_socket_path(cfg_socket_path: str) -> Path: """Expand ~ and return the Unix socket path.""" return Path(cfg_socket_path).expanduser() def get_port_file_path() -> Path: from pyra.utils.paths import pyra_home return pyra_home() / "daemon.port" def _read_windows_port() -> int | None: port_file = get_port_file_path() try: return int(port_file.read_text().strip()) except (FileNotFoundError, ValueError): return None # ── Server ──────────────────────────────────────────────────────────────────── class IpcServer: def __init__( self, address: Path | tuple[str, int], handler: Callable[[IpcMessage], Awaitable[IpcResponse]], ) -> None: self._address = address self._handler = handler self._server: asyncio.Server | None = None async def start(self) -> None: if is_unix_socket(): assert isinstance(self._address, Path) sock_path = self._address if sock_path.exists(): sock_path.unlink() self._server = await asyncio.start_unix_server( self._handle_client, path=str(sock_path) ) os.chmod(sock_path, 0o600) else: host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0) self._server = await asyncio.start_server( self._handle_client, host=host, port=port ) actual_port = self._server.sockets[0].getsockname()[1] port_file = get_port_file_path() port_file.write_text(str(actual_port)) await self._server.start_serving() async def stop(self) -> None: if self._server is not None: self._server.close() try: await asyncio.wait_for(self._server.wait_closed(), timeout=5.0) except asyncio.TimeoutError: pass if is_unix_socket() and isinstance(self._address, Path): try: self._address.unlink() except FileNotFoundError: pass else: port_file = get_port_file_path() try: port_file.unlink() except FileNotFoundError: pass async def _handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: try: if is_unix_socket() and not self._check_peer_uid(writer): writer.close() return line = await asyncio.wait_for(reader.readline(), timeout=5.0) if not line: return try: msg = decode_message(line) except ValueError: resp: IpcResponse = {"ok": False, "data": {"error": "invalid JSON"}} else: resp = await self._handler(msg) writer.write(encode_message(resp)) await writer.drain() except (asyncio.TimeoutError, ConnectionResetError, BrokenPipeError): pass finally: try: writer.close() await writer.wait_closed() except Exception: pass def _check_peer_uid(self, writer: asyncio.StreamWriter) -> bool: """Return True if the peer's UID matches ours. Falls back to True on error.""" try: peer_uid = _get_peer_uid(writer) if peer_uid is None: return True # can't determine — allow (socket perms are the guard) return peer_uid == os.getuid() except Exception: return True # don't crash the server on unexpected errors # ── Client ──────────────────────────────────────────────────────────────────── class IpcClient: def __init__(self, address: Path | tuple[str, int]) -> None: self._address = address async def send(self, msg: IpcMessage, timeout: float = 5.0) -> IpcResponse: if is_unix_socket(): assert isinstance(self._address, Path) reader, writer = await asyncio.wait_for( asyncio.open_unix_connection(str(self._address)), timeout=timeout ) else: host, port = self._address if isinstance(self._address, tuple) else ("127.0.0.1", 0) reader, writer = await asyncio.wait_for( asyncio.open_connection(host, port), timeout=timeout ) try: writer.write(encode_message(msg)) await writer.drain() line = await asyncio.wait_for(reader.readline(), timeout=timeout) return decode_message(line) finally: writer.close() try: await writer.wait_closed() except Exception: pass def send_command( address: Path | tuple[str, int], msg: IpcMessage, timeout: float = 5.0, ) -> IpcResponse: """Synchronous wrapper around IpcClient.send() for CLI callers.""" return asyncio.run(IpcClient(address).send(msg, timeout=timeout)) # ── Peer UID detection ──────────────────────────────────────────────────────── def _get_peer_uid(writer: asyncio.StreamWriter) -> int | None: """Return the connecting peer's UID, or None if unavailable.""" try: sock = writer.get_extra_info("socket") if sock is None: return None if sys.platform == "linux": # SO_PEERCRED: struct { pid_t pid; uid_t uid; gid_t gid; } SO_PEERCRED = 17 cred = sock.getsockopt( socket_module().SOL_SOCKET, SO_PEERCRED, struct.calcsize("3i") ) _pid, uid, _gid = struct.unpack("3i", cred) return uid if sys.platform == "darwin": return _macos_peer_uid(sock.fileno()) except Exception: pass return None def socket_module(): # lazy import to avoid top-level import on non-Unix import socket return socket def _macos_peer_uid(fd: int) -> int | None: """Use getpeereid(2) via ctypes to retrieve the peer UID on macOS.""" import ctypes import ctypes.util libc_name = ctypes.util.find_library("c") if not libc_name: return None libc = ctypes.CDLL(libc_name) euid = ctypes.c_uint32(0) egid = ctypes.c_uint32(0) if libc.getpeereid(fd, ctypes.byref(euid), ctypes.byref(egid)) != 0: return None return euid.value