"""Interactive chat loop for Pyra.

Wires the plugin registry, built-in memory/planning tools, conversation
history and the prompt session together, and drives the model through
litellm (with an optional tool-use loop).
"""
from __future__ import annotations

import litellm
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import FileHistory

from pyra.chat.history import ConversationHistory
from pyra.chat.renderer import (
    console,
    render_error,
    render_info,
    render_injection_warning,
    render_streaming_response,
    render_system,
    render_text_response,
)
from pyra.chat.planner import TaskPlanner
from pyra.config.manager import load_config
from pyra.config.schema import PyraConfig
from pyra.memory.reader import list_memories, lookup_memories, read_memory
from pyra.memory.writer import write_memory
from pyra.plugins.base import Tool
from pyra.plugins.executor import ToolExecutor
from pyra.plugins.registry import PluginRegistry
from pyra.security.injection import scan_response
from pyra.setup.providers import get_provider
from pyra.setup.wizard import fetch_loaded_models
from pyra.utils.paths import pyra_home

# prompt_toolkit input history is persisted under the Pyra home directory.
_HISTORY_FILE = pyra_home() / ".chat_history"

# Built-in slash commands and the help text shown by /help.
_STATIC_COMMANDS = {
    "/quit": "Exit Pyra",
    "/exit": "Exit Pyra",
    "/clear": "Clear conversation history",
    "/memory list": "List memory files",
    "/config": "Open configuration TUI",
    "/help": "Show available slash commands",
}

# Upper bound on model <-> tool round trips for a single user turn.
_MAX_TOOL_ITERATIONS = 10


def _handle_memory_lookup(query: str) -> str:
    """Tool handler: search the memory index and format matches as a bullet list.

    Returns a "no entries" message when ``lookup_memories`` yields nothing,
    otherwise one ``- file: summary (keywords: ...)`` line per result.
    """
    results = lookup_memories(query)
    if not results:
        return f"No memory entries found matching '{query}'."
    lines = [
        f"- {r['file']}: {r['summary']} (keywords: {', '.join(r['keywords'])})"
        for r in results
    ]
    return "\n".join(lines)


def _handle_memory_read(file: str) -> str:
    """Tool handler: return the content of a memory file, or an error string.

    ``FileNotFoundError`` and ``PermissionError`` from ``read_memory`` are
    reported back to the model as text instead of being raised.
    """
    try:
        return read_memory(file)
    except (FileNotFoundError, PermissionError) as exc:
        return f"Error: {exc}"


def _handle_memory_write(file: str, content: str, summary: str, keywords: list) -> str:
    """Tool handler: write a memory file and report the outcome as text.

    ``keywords`` is declared as an array in the tool schema, but the value
    comes from model-generated arguments. A plain string is therefore split
    on commas; passing it to ``list()`` directly would break it into single
    characters and silently index garbage keywords.

    ``ValueError`` and ``PermissionError`` from ``write_memory`` are returned
    as an error string rather than raised.
    """
    if isinstance(keywords, str):
        keyword_list = [part.strip() for part in keywords.split(",") if part.strip()]
    else:
        keyword_list = list(keywords)
    try:
        write_memory(file, content, summary=summary, keywords=keyword_list)
        return f"Memory saved: {file}"
    except (ValueError, PermissionError) as exc:
        return f"Error: {exc}"


def _register_builtin_tools(registry: PluginRegistry, planner: TaskPlanner) -> None:
    """Register the planner tool and the three memory tools on ``registry``."""
    registry.register_builtin(Tool(
        name="plan_and_execute",
        description=(
            "Decompose a multi-step task into sequential steps and execute each with "
            "a focused sub-agent. Use only for long tasks with 3 or more sequential "
            "steps that need verification between them — for simple 2-step tasks, do "
            "them directly. Specify 'agent' per step to route to a specialized agent."
        ),
        parameters={
            "type": "object",
            "properties": {
                "task": {"type": "string", "description": "Overall task description."},
                "steps": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "description": {"type": "string", "description": "What this step should accomplish."},
                            "agent": {"type": "string", "description": "Optional agent name to handle this step."},
                        },
                        "required": ["description"],
                    },
                    "minItems": 1,
                    "description": "Ordered steps. Each step optionally routes to a named agent.",
                },
            },
            "required": ["task", "steps"],
        },
        handler=planner.make_tool_handler(),
        requires_approval=False,
    ))

    registry.register_builtin(Tool(
        name="memory_lookup",
        description=(
            "Search the memory index by keyword or topic. "
            "Always call this BEFORE memory_write to check whether a matching entry already exists."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Keyword or topic to search for."},
            },
            "required": ["query"],
        },
        handler=_handle_memory_lookup,
        requires_approval=False,
    ))

    registry.register_builtin(Tool(
        name="memory_read",
        description="Read the full content of a memory file by its relative path (e.g. 'user/profile.md'). Use memory_lookup first to find the correct path.",
        parameters={
            "type": "object",
            "properties": {
                "file": {"type": "string", "description": "Relative path to the memory file."},
            },
            "required": ["file"],
        },
        handler=_handle_memory_read,
        requires_approval=False,
    ))

    registry.register_builtin(Tool(
        name="memory_write",
        description=(
            "Write or overwrite a memory file. Always call memory_lookup first to avoid duplicates. "
            "If an existing file covers the same topic, read it first and merge the content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "file": {"type": "string", "description": "Relative path, e.g. 'user/profile.md' or 'knowledge/python_tips.md'."},
                "content": {"type": "string", "description": "Full Markdown content to write."},
                "summary": {"type": "string", "description": "One-sentence summary of what this memory file stores."},
                "keywords": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Keywords for index lookup (3–8 terms).",
                },
            },
            "required": ["file", "content", "summary", "keywords"],
        },
        handler=_handle_memory_write,
        requires_approval=False,
    ))


def _rollback_history(history: ConversationHistory, mark: int) -> None:
    """Discard everything appended to ``history`` since length ``mark``.

    Always removes at least one message (the behaviour of the previous
    single ``pop()``), then keeps popping until the store is back to the
    length it had before the failed turn started. This also drops any
    tool-call / tool-result messages added by the tool loop before the
    failure, which would otherwise be left dangling in later requests.

    NOTE(review): relies on the private ``_messages`` container exposing
    ``len()`` and ``pop()`` — confirm against ConversationHistory.
    """
    messages = history._messages
    if messages:
        messages.pop()
    while len(messages) > mark:
        messages.pop()


def start_chat() -> None:
    """Run the interactive chat REPL until the user quits.

    Loads the configuration (returning early with an error message if the
    file is missing), loads plugins, registers the built-in tools, then
    reads prompts in a loop: slash commands are handled locally, anything
    else is sent to the model via ``_call_ai`` and the response is scanned
    with ``scan_response``.
    """
    try:
        cfg = load_config()
    except FileNotFoundError as exc:
        render_error(str(exc))
        return

    registry = PluginRegistry.instance()
    registry.load_all(pyra_home() / "plugins", cfg.plugins.enabled)
    executor = ToolExecutor(registry, console)
    planner = TaskPlanner(cfg, registry, executor)
    _register_builtin_tools(registry, planner)

    history = ConversationHistory(cfg, registry)
    plugin_slash = registry.get_slash_commands()
    all_commands = list(_STATIC_COMMANDS) + list(plugin_slash)

    session: PromptSession = PromptSession(
        history=FileHistory(str(_HISTORY_FILE)),
        completer=WordCompleter(all_commands, sentence=True),
        complete_while_typing=False,
        multiline=False,
    )

    provider = get_provider(cfg.ai.provider_id)
    render_system(
        f"[bold cyan]Pyra[/bold cyan] | {provider.display_name} | {cfg.ai.model}\n"
        "[dim]Type /help for commands, /quit to exit.[/dim]"
    )

    # For local providers, tell the user up front if the configured model
    # is not among the models reported as loaded.
    if provider.group == "Local":
        loaded = fetch_loaded_models(provider)
        if not loaded:
            render_info(f"No model currently loaded in {provider.display_name}.")
        elif cfg.ai.model not in loaded:
            render_info(
                f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
                f"Loaded: {', '.join(loaded)}"
            )

    # Mutable flags shared with _call_ai; it flips "use_tools" off when the
    # model rejects function calling.
    _flags: dict = {"use_tools": True}

    while True:
        try:
            user_input = session.prompt("› ").strip()
        except (KeyboardInterrupt, EOFError):
            console.print("\n[dim]Goodbye.[/dim]")
            break

        if not user_input:
            continue

        if user_input in ("/quit", "/exit"):
            console.print("[dim]Goodbye.[/dim]")
            break

        if user_input == "/clear":
            history.clear()
            render_info("Conversation cleared.")
            continue

        if user_input == "/help":
            _show_help(plugin_slash)
            continue

        if user_input == "/memory list":
            _show_memory_list()
            continue

        if user_input == "/config":
            from pyra.config.tui import launch_config_tui

            launch_config_tui()
            try:
                cfg = load_config()
            except FileNotFoundError as exc:
                # Keep running with the previously loaded config, but say so
                # instead of failing silently.
                render_error(f"Could not reload configuration: {exc}")
            continue

        if user_input in plugin_slash:
            try:
                plugin_slash[user_input]()
            except Exception as exc:
                render_error(f"Plugin command error: {exc}")
            continue

        if user_input.startswith("/"):
            render_error(f"Unknown command: {user_input!r}. Type /help for commands.")
            continue

        # Remember where this turn starts so a failure can be fully undone.
        turn_start = len(history._messages)
        history.add_user(user_input)
        try:
            response_text = _call_ai(cfg, history, registry, executor, _flags)
        except Exception as exc:
            render_error(f"AI error: {exc}")
            _rollback_history(history, turn_start)
            continue

        history.add_assistant(response_text)
        warnings = scan_response(response_text)
        if warnings:
            render_injection_warning(warnings)


def _call_ai(
    cfg: PyraConfig,
    history: ConversationHistory,
    registry: PluginRegistry,
    executor: ToolExecutor,
    flags: dict | None = None,
) -> str:
    """Send the conversation to the model and return the rendered response.

    Without tools (none registered, or ``flags["use_tools"]`` is false) the
    response is streamed. Otherwise a non-streaming tool-use loop runs for at
    most ``_MAX_TOOL_ITERATIONS`` rounds: each round's tool calls are executed
    through ``executor`` and their results appended to ``history``.

    If litellm raises ``BadRequestError`` during the tool loop, tools are
    disabled for later calls (via ``flags``, when given) and the request is
    retried as a plain streaming completion.

    Returns whatever ``render_streaming_response`` / ``render_text_response``
    return for the final answer.
    """
    from pyra.vault.reader import get_key

    provider = get_provider(cfg.ai.provider_id)
    api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local"

    base_kwargs: dict = {
        "model": f"{provider.litellm_prefix}{cfg.ai.model}",
        "api_key": api_key,
    }
    # An explicit base URL in the config takes precedence over the provider's.
    effective_base_url = cfg.ai.base_url or provider.base_url
    if effective_base_url:
        base_kwargs["api_base"] = effective_base_url

    litellm.suppress_debug_info = True

    tools = registry.get_all_tools()
    tools_spec = [
        {
            "type": "function",
            "function": {
                "name": t.name,
                "description": t.description,
                "parameters": t.parameters,
            },
        }
        for t in tools
    ]

    # No tools active, or provider known not to support function calling
    use_tools = flags is None or flags.get("use_tools", True)
    if not tools_spec or not use_tools:
        stream = litellm.completion(
            **base_kwargs,
            messages=history.build_for_api(),
            stream=True,
        )
        return render_streaming_response(stream)

    # Plugin tool-use loop (non-streaming for tool calls, renders final response)
    try:
        for _iteration in range(_MAX_TOOL_ITERATIONS):
            response = litellm.completion(
                **base_kwargs,
                messages=history.build_for_api(),
                tools=tools_spec,
                tool_choice="auto",
                stream=False,
            )
            message = response.choices[0].message

            if not message.tool_calls:
                return render_text_response(message.content or "")

            history.add_tool_call_message(message)
            results = executor.execute_tool_call_batch(message.tool_calls)
            for r in results:
                history.add_tool_result(r["tool_call_id"], r["result"])

        return render_text_response("Error: tool-use loop exceeded maximum iterations.")

    except litellm.BadRequestError:
        # Treated as "model cannot do function calling": remember that for
        # subsequent turns and fall back to a plain streaming completion.
        if flags is not None:
            flags["use_tools"] = False
        render_info("This model does not support function calling — tools disabled.")
        stream = litellm.completion(
            **base_kwargs,
            messages=history.build_for_api(),
            stream=True,
        )
        return render_streaming_response(stream)


def _show_help(plugin_slash: dict) -> None:
    """Print the built-in slash commands and the names of plugin commands."""
    lines = ["[bold]Slash commands:[/bold]"]
    for cmd, desc in _STATIC_COMMANDS.items():
        lines.append(f"  [cyan]{cmd:<20}[/cyan] {desc}")
    if plugin_slash:
        lines.append("[bold]Plugin commands:[/bold]")
        for cmd in sorted(plugin_slash):
            lines.append(f"  [cyan]{cmd:<20}[/cyan]")
    console.print("\n".join(lines))


def _show_memory_list() -> None:
    """Print name, category and modification date of every memory file."""
    memories = list_memories()
    if not memories:
        render_info("No memory files found.")
        return
    for m in memories:
        mtime = m.modified.strftime("%Y-%m-%d")
        console.print(f"  [cyan]{m.name:<40}[/cyan] {m.category:<12} {mtime}")