from __future__ import annotations from typing import TYPE_CHECKING, Any import litellm from rich.panel import Panel from pyra.chat.renderer import ( console, render_error, render_info, render_streaming_response, render_text_response, ) from pyra.setup.providers import get_provider from pyra.vault.reader import get_key if TYPE_CHECKING: from pyra.config.schema import PyraConfig from pyra.plugins.executor import ToolExecutor from pyra.plugins.registry import PluginRegistry _STEP_SYSTEM_BASE = """\ You are Pyra, executing one step of a multi-step plan. Security constraints: - You cannot access ~/.pyra/vault/ — it is physically blocked by the application. - You cannot execute shell commands — use the provided tools instead. - You cannot read or modify files outside ~/.pyra/memory/ directly. Work only on the assigned step. Use available tools if needed. Clearly describe what you accomplished when finished. """ _VERIFY_SYSTEM = ( "You evaluate task step outcomes. " "Reply only with the single word SUCCESS or FAILURE." ) class TaskPlanner: def __init__(self, cfg: PyraConfig, registry: PluginRegistry, executor: ToolExecutor) -> None: self._cfg = cfg self._registry = registry self._executor = executor def make_tool_handler(self): def handle(task: str, steps: list) -> str: return self._run_plan(task, steps) return handle def _run_plan(self, task: str, steps: list) -> str: normalised = [ s if isinstance(s, dict) else {"description": s} for s in steps ] if not self._ask_plan_approval(task, normalised): return "Plan declined by user." previous_results: list[str] = [] summaries: list[str] = [] n = len(normalised) for i, step in enumerate(normalised): desc = step.get("description", f"Step {i + 1}") agent_name = step.get("agent") label = f" [{agent_name}]" if agent_name else "" render_info(f"[Plan] Step {i + 1}/{n}{label}: {desc}") try: output = self._execute_step(desc, agent_name, task, previous_results, n) except Exception as exc: render_error(f"[Plan] Step {i + 1} error: {exc}") return f"Plan failed at step {i + 1} ({desc}): {exc}" if not self._verify_step(desc, output): render_error(f"[Plan] Step {i + 1} failed verification.") return ( f"Plan failed at step {i + 1} ({desc}): " f"output did not pass verification.\n{output[:500]}" ) summary = output[:800].strip() previous_results.append(summary) summaries.append(f"Step {i + 1} ({desc}): {summary}") render_info(f"[Plan] Step {i + 1} ✓") render_info("[Plan] All steps completed successfully.") body = "\n\n".join(summaries) result = f"Plan completed successfully.\n\n{body}" return result[:3900] def _execute_step( self, desc: str, agent_name: str | None, task: str, previous_results: list[str], total: int, ) -> str: step_num = len(previous_results) + 1 agent_info = self._registry.get_agent(agent_name) if agent_name else None if agent_info: agent_spec, agent_tools = agent_info system_prompt = agent_spec.system_prompt tools = agent_tools else: system_prompt = _STEP_SYSTEM_BASE tools = [t for t in self._registry.get_all_tools() if t.name != "plan_and_execute"] messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": self._step_user_msg(task, step_num, total, desc, previous_results)}, ] tools_spec = [ { "type": "function", "function": { "name": t.name, "description": t.description, "parameters": t.parameters, }, } for t in tools ] base_kw = self._base_kwargs() litellm.suppress_debug_info = True if not tools_spec: stream = litellm.completion(**base_kw, messages=messages, stream=True) return render_streaming_response(stream) for _ in range(5): resp = litellm.completion( **base_kw, messages=messages, tools=tools_spec, tool_choice="auto", stream=False, ) msg = resp.choices[0].message if not msg.tool_calls: text = msg.content or "" render_text_response(text) return text messages.append({ "role": "assistant", "content": msg.content, "tool_calls": [ { "id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}, } for tc in msg.tool_calls ], }) results = self._executor.execute_tool_call_batch(msg.tool_calls) for r in results: messages.append({"role": "tool", "tool_call_id": r["tool_call_id"], "content": r["result"]}) return "Step exceeded maximum tool iterations." def _verify_step(self, desc: str, output: str) -> bool: try: resp = litellm.completion( **self._base_kwargs(), messages=[ {"role": "system", "content": _VERIFY_SYSTEM}, {"role": "user", "content": f"Step: {desc}\n\nOutput:\n{output[:1000]}"}, ], stream=False, ) text = (resp.choices[0].message.content or "").upper() return "SUCCESS" in text except Exception: return True def _base_kwargs(self) -> dict: provider = get_provider(self._cfg.ai.provider_id) api_key = get_key(self._cfg.ai.provider_id) if provider.requires_key else "local" kw: dict = { "model": f"{provider.litellm_prefix}{self._cfg.ai.model}", "api_key": api_key, } if self._cfg.ai.base_url: kw["api_base"] = self._cfg.ai.base_url return kw def _step_user_msg( self, task: str, step_num: int, total: int, desc: str, previous_results: list[str], ) -> str: lines = [f"Overall task: {task}", "", f"Step {step_num}/{total}: {desc}"] if previous_results: lines += ["", "Results from previous steps:"] for i, r in enumerate(previous_results, 1): lines.append(f" Step {i}: {r}") return "\n".join(lines) def _ask_plan_approval(self, task: str, steps: list[dict]) -> bool: lines = [f"[bold]Task:[/bold] {task}", "", "[bold]Steps:[/bold]"] for i, step in enumerate(steps, 1): desc = step.get("description", "") agent = step.get("agent", "") suffix = f" [dim][{agent}][/dim]" if agent else "" lines.append(f" {i}. {desc}{suffix}") console.print(Panel( "\n".join(lines), title="[bold cyan]Pyra — Multi-Step Plan[/bold cyan]", border_style="cyan", )) try: answer = console.input("[bold]Execute this plan?[/bold] [dim][y/N][/dim] ").strip().lower() except (KeyboardInterrupt, EOFError): return False return answer == "y"