1cbb40ac93
- plan_and_execute: restrict to 3+ step tasks; prevents over-triggering on simple requests - memory_read: hint to call memory_lookup first to find the correct path Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
354 lines
12 KiB
Python
354 lines
12 KiB
Python
from __future__ import annotations
|
||
|
||
import litellm
|
||
from prompt_toolkit import PromptSession
|
||
from prompt_toolkit.completion import WordCompleter
|
||
from prompt_toolkit.history import FileHistory
|
||
|
||
from pyra.chat.history import ConversationHistory
|
||
from pyra.chat.renderer import (
|
||
console,
|
||
render_error,
|
||
render_info,
|
||
render_injection_warning,
|
||
render_streaming_response,
|
||
render_system,
|
||
render_text_response,
|
||
)
|
||
from pyra.chat.planner import TaskPlanner
|
||
from pyra.config.manager import load_config
|
||
from pyra.config.schema import PyraConfig
|
||
from pyra.memory.reader import list_memories, lookup_memories, read_memory
|
||
from pyra.memory.writer import write_memory
|
||
from pyra.plugins.base import Tool
|
||
from pyra.plugins.executor import ToolExecutor
|
||
from pyra.plugins.registry import PluginRegistry
|
||
from pyra.security.injection import scan_response
|
||
from pyra.setup.providers import get_provider
|
||
from pyra.setup.wizard import fetch_loaded_models
|
||
from pyra.utils.paths import pyra_home
|
||
|
||
# File under the Pyra home directory that backs the prompt's input history
# (passed to prompt_toolkit's FileHistory in start_chat).
_HISTORY_FILE = pyra_home() / ".chat_history"
|
||
|
||
# Built-in slash commands mapped to the help text shown by /help.
# Insertion order is the order used for completion and for the help listing.
_STATIC_COMMANDS: dict[str, str] = {
    # Leaving the session
    "/quit": "Exit Pyra",
    "/exit": "Exit Pyra",
    # Conversation and memory
    "/clear": "Clear conversation history",
    "/memory list": "List memory files",
    # Configuration and help
    "/config": "Open configuration TUI",
    "/help": "Show available slash commands",
}
|
||
|
||
|
||
def _handle_memory_lookup(query: str) -> str:
    """Tool handler: search the memory index and format the hits as text.

    Args:
        query: Keyword or topic passed straight to ``lookup_memories``.

    Returns:
        One ``- file: summary (keywords: ...)`` line per hit, joined with
        newlines, or a "no entries" message when the lookup returns nothing.
    """
    matches = lookup_memories(query)
    if matches:
        return "\n".join(
            f"- {entry['file']}: {entry['summary']} (keywords: {', '.join(entry['keywords'])})"
            for entry in matches
        )
    return f"No memory entries found matching '{query}'."
|
||
|
||
|
||
def _handle_memory_read(file: str) -> str:
    """Tool handler: return the content of one memory file.

    Args:
        file: Path handed unchanged to ``read_memory``.

    Returns:
        Whatever ``read_memory`` returns, or an ``Error: ...`` string when it
        raises ``FileNotFoundError`` or ``PermissionError`` (the failure is
        reported as text instead of propagating).
    """
    try:
        content = read_memory(file)
    except (FileNotFoundError, PermissionError) as err:
        return f"Error: {err}"
    return content
|
||
|
||
|
||
def _handle_memory_write(file: str, content: str, summary: str, keywords: list) -> str:
    """Tool handler: store a memory file together with its index metadata.

    Args:
        file: Target path handed to ``write_memory``.
        content: Text to store.
        summary: Summary forwarded as the ``summary`` keyword argument.
        keywords: Iterable of keywords; copied into a fresh list before being
            forwarded.

    Returns:
        ``Memory saved: <file>`` on success, or an ``Error: ...`` string when
        ``write_memory`` raises ``ValueError`` or ``PermissionError``.
    """
    try:
        write_memory(file, content, summary=summary, keywords=list(keywords))
    except (ValueError, PermissionError) as err:
        return f"Error: {err}"
    return f"Memory saved: {file}"
|
||
|
||
|
||
def start_chat() -> None:
    """Run the interactive chat loop until the user quits.

    Loads the configuration, loads plugins, registers the built-in tools and
    then reads prompts in a loop.  Slash commands are handled locally; any
    other input is appended to the conversation history and sent to the model
    through ``_call_ai``.  If the configuration file is missing, the error is
    rendered and the function returns without starting the loop.
    """
    try:
        cfg = load_config()
    except FileNotFoundError as exc:
        render_error(str(exc))
        return

    registry = PluginRegistry.instance()
    registry.load_all(pyra_home() / "plugins", cfg.plugins.enabled)
    executor = ToolExecutor(registry, console)
    planner = TaskPlanner(cfg, registry, executor)
    _register_builtin_tools(registry, planner)

    history = ConversationHistory(cfg, registry)

    plugin_slash = registry.get_slash_commands()
    all_commands = list(_STATIC_COMMANDS) + list(plugin_slash)
    session: PromptSession = PromptSession(
        history=FileHistory(str(_HISTORY_FILE)),
        completer=WordCompleter(all_commands, sentence=True),
        complete_while_typing=False,
        multiline=False,
    )

    provider = get_provider(cfg.ai.provider_id)
    render_system(
        f"[bold cyan]Pyra[/bold cyan] | {provider.display_name} | {cfg.ai.model}\n"
        "[dim]Type /help for commands, /quit to exit.[/dim]"
    )

    # For local providers, tell the user up front when the configured model
    # is not among the models the provider reports as loaded.
    if provider.group == "Local":
        loaded = fetch_loaded_models(provider)
        if not loaded:
            render_info(f"No model currently loaded in {provider.display_name}.")
        elif cfg.ai.model not in loaded:
            render_info(
                f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
                f"Loaded: {', '.join(loaded)}"
            )

    # Mutable on purpose: _call_ai flips "use_tools" to False when the model
    # rejects tool requests, and the change must persist across turns.
    _flags: dict = {"use_tools": True}

    while True:
        try:
            user_input = session.prompt("› ").strip()
        except (KeyboardInterrupt, EOFError):
            console.print("\n[dim]Goodbye.[/dim]")
            break

        if not user_input:
            continue

        if user_input in ("/quit", "/exit"):
            console.print("[dim]Goodbye.[/dim]")
            break

        if user_input == "/clear":
            history.clear()
            render_info("Conversation cleared.")
            continue

        if user_input == "/help":
            _show_help(plugin_slash)
            continue

        if user_input == "/memory list":
            _show_memory_list()
            continue

        if user_input == "/config":
            # Imported lazily so the TUI module is only loaded when requested.
            from pyra.config.tui import launch_config_tui

            launch_config_tui()
            try:
                cfg = load_config()
            except FileNotFoundError as exc:
                # Keep the previous configuration, but report the failure the
                # same way the initial load does instead of hiding it.
                render_error(str(exc))
            # NOTE(review): `history` and `planner` were built from the earlier
            # cfg and are not rebuilt here; only _call_ai sees the new one.
            continue

        if user_input in plugin_slash:
            try:
                plugin_slash[user_input]()
            except Exception as exc:
                # A failing plugin command must not take the chat loop down.
                render_error(f"Plugin command error: {exc}")
            continue

        if user_input.startswith("/"):
            render_error(f"Unknown command: {user_input!r}. Type /help for commands.")
            continue

        history.add_user(user_input)
        # Position of the message just added.  _call_ai may append tool-call
        # and tool-result messages after it, so a failed turn has to be rolled
        # back to this point rather than by popping a single message.
        turn_start = max(len(history._messages) - 1, 0)

        try:
            response_text = _call_ai(cfg, history, registry, executor, _flags)
        except Exception as exc:
            render_error(f"AI error: {exc}")
            _rollback_turn(history, turn_start)
            continue

        history.add_assistant(response_text)

        warnings = scan_response(response_text)
        if warnings:
            render_injection_warning(warnings)


def _rollback_turn(history: ConversationHistory, turn_start: int) -> None:
    """Drop every message at index ``turn_start`` or later from the history."""
    messages = history._messages
    # Only len() and pop() are used, the same operations the history already
    # had to support for the previous single-message rollback.
    while len(messages) > turn_start:
        messages.pop()


def _register_builtin_tools(registry: PluginRegistry, planner: TaskPlanner) -> None:
    """Register the planner tool and the three memory tools as built-ins.

    All four tools are registered with ``requires_approval=False``.
    """
    registry.register_builtin(Tool(
        name="plan_and_execute",
        description=(
            "Decompose a multi-step task into sequential steps and execute each with "
            "a focused sub-agent. Use only for long tasks with 3 or more sequential "
            "steps that need verification between them — for simple 2-step tasks, do "
            "them directly. Specify 'agent' per step to route to a specialized agent."
        ),
        parameters={
            "type": "object",
            "properties": {
                "task": {"type": "string", "description": "Overall task description."},
                "steps": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "description": {"type": "string", "description": "What this step should accomplish."},
                            "agent": {"type": "string", "description": "Optional agent name to handle this step."},
                        },
                        "required": ["description"],
                    },
                    "minItems": 1,
                    "description": "Ordered steps. Each step optionally routes to a named agent.",
                },
            },
            "required": ["task", "steps"],
        },
        handler=planner.make_tool_handler(),
        requires_approval=False,
    ))
    registry.register_builtin(Tool(
        name="memory_lookup",
        description=(
            "Search the memory index by keyword or topic. "
            "Always call this BEFORE memory_write to check whether a matching entry already exists."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Keyword or topic to search for."},
            },
            "required": ["query"],
        },
        handler=_handle_memory_lookup,
        requires_approval=False,
    ))
    registry.register_builtin(Tool(
        name="memory_read",
        description="Read the full content of a memory file by its relative path (e.g. 'user/profile.md'). Use memory_lookup first to find the correct path.",
        parameters={
            "type": "object",
            "properties": {
                "file": {"type": "string", "description": "Relative path to the memory file."},
            },
            "required": ["file"],
        },
        handler=_handle_memory_read,
        requires_approval=False,
    ))
    registry.register_builtin(Tool(
        name="memory_write",
        description=(
            "Write or overwrite a memory file. Always call memory_lookup first to avoid duplicates. "
            "If an existing file covers the same topic, read it first and merge the content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "file": {"type": "string", "description": "Relative path, e.g. 'user/profile.md' or 'knowledge/python_tips.md'."},
                "content": {"type": "string", "description": "Full Markdown content to write."},
                "summary": {"type": "string", "description": "One-sentence summary of what this memory file stores."},
                "keywords": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Keywords for index lookup (3–8 terms).",
                },
            },
            "required": ["file", "content", "summary", "keywords"],
        },
        handler=_handle_memory_write,
        requires_approval=False,
    ))
|
||
|
||
|
||
def _call_ai(
    cfg: PyraConfig,
    history: ConversationHistory,
    registry: PluginRegistry,
    executor: ToolExecutor,
    flags: dict | None = None,
) -> str:
    """Send the conversation to the model, render the reply and return its text.

    Without tools (none registered, or ``flags["use_tools"]`` is false) the
    reply is streamed.  With tools, the model is queried non-streaming in a
    loop of at most 10 rounds: each round's tool calls are executed and their
    results appended to ``history`` until the model answers without tool calls.

    Args:
        cfg: Configuration providing provider id, model name and base URL.
        history: Conversation history; tool-call and tool-result messages are
            appended to it during the tool loop.
        registry: Source of the tools advertised to the model.
        executor: Runs the tool calls requested by the model.
        flags: Optional mutable dict.  ``"use_tools"`` is read from it and is
            set to ``False`` when a tool request raises ``BadRequestError``.

    Returns:
        The value returned by the renderer for the final response.
    """
    from pyra.vault.reader import get_key

    provider = get_provider(cfg.ai.provider_id)
    if provider.requires_key:
        api_key = get_key(cfg.ai.provider_id)
    else:
        # Providers without a key still get a placeholder value.
        api_key = "local"

    request_kwargs: dict = {
        "model": f"{provider.litellm_prefix}{cfg.ai.model}",
        "api_key": api_key,
    }
    # A base URL from the config takes precedence over the provider default.
    api_base = cfg.ai.base_url or provider.base_url
    if api_base:
        request_kwargs["api_base"] = api_base

    litellm.suppress_debug_info = True

    tools_spec = [
        {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            },
        }
        for tool in registry.get_all_tools()
    ]

    def _stream_plain() -> str:
        """Request a streamed completion without tools and render it."""
        stream = litellm.completion(
            **request_kwargs,
            messages=history.build_for_api(),
            stream=True,
        )
        return render_streaming_response(stream)

    # No tools active, or tools were switched off earlier in the session.
    tools_enabled = True if flags is None else flags.get("use_tools", True)
    if not (tools_spec and tools_enabled):
        return _stream_plain()

    # Tool-use loop (non-streaming for tool calls, renders the final response).
    try:
        for _ in range(10):
            completion = litellm.completion(
                **request_kwargs,
                messages=history.build_for_api(),
                tools=tools_spec,
                tool_choice="auto",
                stream=False,
            )
            reply = completion.choices[0].message

            if not reply.tool_calls:
                return render_text_response(reply.content or "")

            history.add_tool_call_message(reply)
            for outcome in executor.execute_tool_call_batch(reply.tool_calls):
                history.add_tool_result(outcome["tool_call_id"], outcome["result"])

        return render_text_response("Error: tool-use loop exceeded maximum iterations.")

    except litellm.BadRequestError:
        # NOTE(review): every BadRequestError raised inside the loop is taken
        # to mean "no function-calling support", including one raised after
        # earlier rounds succeeded — confirm this is intended.
        if flags is not None:
            flags["use_tools"] = False
        render_info("This model does not support function calling — tools disabled.")
        return _stream_plain()
|
||
|
||
|
||
def _show_help(plugin_slash: dict) -> None:
    """Print the built-in slash commands, then any plugin commands.

    Args:
        plugin_slash: Mapping keyed by plugin command name; only the keys are
            used, listed in sorted order.  The plugin section is omitted when
            the mapping is empty.
    """
    out = ["[bold]Slash commands:[/bold]"]
    out.extend(
        f" [cyan]{name:<20}[/cyan] {text}"
        for name, text in _STATIC_COMMANDS.items()
    )
    if plugin_slash:
        out.append("[bold]Plugin commands:[/bold]")
        out.extend(f" [cyan]{name:<20}[/cyan]" for name in sorted(plugin_slash))
    console.print("\n".join(out))
|
||
|
||
|
||
def _show_memory_list() -> None:
    """Print one line per memory file: name, category and modification date.

    Shows an informational message instead when ``list_memories`` returns
    nothing.
    """
    entries = list_memories()
    if entries:
        for entry in entries:
            stamp = entry.modified.strftime("%Y-%m-%d")
            console.print(f" [cyan]{entry.name:<40}[/cyan] {entry.category:<12} {stamp}")
    else:
        render_info("No memory files found.")
|