Files
Pyra/src/pyra/chat/session.py
T
curo1305 1cbb40ac93 chore(chat): tighten tool descriptions to reduce AI selection confusion
- plan_and_execute: restrict to 3+ step tasks; prevents over-triggering on simple requests
- memory_read: hint to call memory_lookup first to find the correct path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 23:50:21 +02:00

354 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import litellm
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import FileHistory
from pyra.chat.history import ConversationHistory
from pyra.chat.renderer import (
console,
render_error,
render_info,
render_injection_warning,
render_streaming_response,
render_system,
render_text_response,
)
from pyra.chat.planner import TaskPlanner
from pyra.config.manager import load_config
from pyra.config.schema import PyraConfig
from pyra.memory.reader import list_memories, lookup_memories, read_memory
from pyra.memory.writer import write_memory
from pyra.plugins.base import Tool
from pyra.plugins.executor import ToolExecutor
from pyra.plugins.registry import PluginRegistry
from pyra.security.injection import scan_response
from pyra.setup.providers import get_provider
from pyra.setup.wizard import fetch_loaded_models
from pyra.utils.paths import pyra_home
# Input history for the interactive prompt, kept under the Pyra home directory.
_HISTORY_FILE = pyra_home() / ".chat_history"

# Built-in slash commands mapped to their one-line help text.
# Insertion order matters: it is the order used for completion and for /help.
_STATIC_COMMANDS = {
    "/quit": "Exit Pyra",
    "/exit": "Exit Pyra",
    "/clear": "Clear conversation history",
    "/memory list": "List memory files",
    "/config": "Open configuration TUI",
    "/help": "Show available slash commands",
}
def _handle_memory_lookup(query: str) -> str:
    """Tool handler: search the memory index and format the hits as a bullet list.

    Args:
        query: Keyword or topic passed through to ``lookup_memories``.

    Returns:
        One ``- file: summary (keywords: ...)`` line per match, joined with
        newlines, or a "no entries" message when nothing matched.
    """
    matches = lookup_memories(query)
    if not matches:
        return f"No memory entries found matching '{query}'."

    formatted = []
    for entry in matches:
        name = entry["file"]
        summary = entry["summary"]
        keywords = ", ".join(entry["keywords"])
        formatted.append(f"- {name}: {summary} (keywords: {keywords})")
    return "\n".join(formatted)
def _handle_memory_read(file: str) -> str:
    """Tool handler: return the content of a memory file.

    Args:
        file: Relative path of the memory file, passed to ``read_memory``.

    Returns:
        The value returned by ``read_memory``, or an ``"Error: ..."`` string
        when the file is missing or not accessible.
    """
    try:
        content = read_memory(file)
    except (FileNotFoundError, PermissionError) as exc:
        # Report the problem to the model as text instead of raising.
        return f"Error: {exc}"
    return content
def _handle_memory_write(file: str, content: str, summary: str, keywords: list) -> str:
    """Tool handler: write (or overwrite) a memory file and its index metadata.

    Args:
        file: Relative path of the memory file to write.
        content: Full content to store.
        summary: One-sentence summary stored alongside the file.
        keywords: Keywords for index lookup. Normally a list of strings; a
            single comma-separated string is also accepted because models
            sometimes send one instead of a JSON array.

    Returns:
        ``"Memory saved: <file>"`` on success, or an ``"Error: ..."`` string
        when ``write_memory`` raises ``ValueError`` or ``PermissionError``.
    """
    if isinstance(keywords, str):
        # list("a, b") would explode the string into single characters;
        # split on commas so each keyword stays intact.
        keyword_list = [part.strip() for part in keywords.split(",") if part.strip()]
    else:
        keyword_list = list(keywords)

    try:
        write_memory(file, content, summary=summary, keywords=keyword_list)
    except (ValueError, PermissionError) as exc:
        return f"Error: {exc}"
    return f"Memory saved: {file}"
def _register_builtin_tools(registry: PluginRegistry, planner: TaskPlanner) -> None:
    """Register the planner tool and the three memory tools on ``registry``.

    None of these tools requires user approval.
    """
    registry.register_builtin(Tool(
        name="plan_and_execute",
        description=(
            "Decompose a multi-step task into sequential steps and execute each with "
            "a focused sub-agent. Use only for long tasks with 3 or more sequential "
            "steps that need verification between them — for simple 2-step tasks, do "
            "them directly. Specify 'agent' per step to route to a specialized agent."
        ),
        parameters={
            "type": "object",
            "properties": {
                "task": {"type": "string", "description": "Overall task description."},
                "steps": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "description": {"type": "string", "description": "What this step should accomplish."},
                            "agent": {"type": "string", "description": "Optional agent name to handle this step."},
                        },
                        "required": ["description"],
                    },
                    "minItems": 1,
                    "description": "Ordered steps. Each step optionally routes to a named agent.",
                },
            },
            "required": ["task", "steps"],
        },
        handler=planner.make_tool_handler(),
        requires_approval=False,
    ))
    registry.register_builtin(Tool(
        name="memory_lookup",
        description=(
            "Search the memory index by keyword or topic. "
            "Always call this BEFORE memory_write to check whether a matching entry already exists."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Keyword or topic to search for."},
            },
            "required": ["query"],
        },
        handler=_handle_memory_lookup,
        requires_approval=False,
    ))
    registry.register_builtin(Tool(
        name="memory_read",
        description="Read the full content of a memory file by its relative path (e.g. 'user/profile.md'). Use memory_lookup first to find the correct path.",
        parameters={
            "type": "object",
            "properties": {
                "file": {"type": "string", "description": "Relative path to the memory file."},
            },
            "required": ["file"],
        },
        handler=_handle_memory_read,
        requires_approval=False,
    ))
    registry.register_builtin(Tool(
        name="memory_write",
        description=(
            "Write or overwrite a memory file. Always call memory_lookup first to avoid duplicates. "
            "If an existing file covers the same topic, read it first and merge the content."
        ),
        parameters={
            "type": "object",
            "properties": {
                "file": {"type": "string", "description": "Relative path, e.g. 'user/profile.md' or 'knowledge/python_tips.md'."},
                "content": {"type": "string", "description": "Full Markdown content to write."},
                "summary": {"type": "string", "description": "One-sentence summary of what this memory file stores."},
                "keywords": {
                    "type": "array",
                    "items": {"type": "string"},
                    # A keyword count range; "38 terms" would push the model
                    # toward an absurdly long keyword list.
                    "description": "Keywords for index lookup (3-8 terms).",
                },
            },
            "required": ["file", "content", "summary", "keywords"],
        },
        handler=_handle_memory_write,
        requires_approval=False,
    ))


def _report_local_model_status(cfg: PyraConfig, provider) -> None:
    """For providers in the "Local" group, tell the user if the configured model is not loaded."""
    if provider.group != "Local":
        return
    loaded = fetch_loaded_models(provider)
    if not loaded:
        render_info(f"No model currently loaded in {provider.display_name}.")
    elif cfg.ai.model not in loaded:
        render_info(
            f"Model '{cfg.ai.model}' not loaded in {provider.display_name}. "
            f"Loaded: {', '.join(loaded)}"
        )


def _rollback_failed_turn(history: ConversationHistory, turn_start: int) -> None:
    """Remove everything a failed turn appended to the conversation history.

    ``turn_start`` is the history length recorded before the user message was
    added. A failure in the middle of a tool-use loop leaves tool-call and
    tool-result messages after the user message; removing only the newest
    entry would keep orphaned tool messages in the conversation.

    NOTE(review): assumes ConversationHistory appends every message to
    ``_messages`` in order, which the previous single ``pop()`` relied on too.
    """
    messages = history._messages
    if messages:
        # Always drop the newest entry, as the previous implementation did.
        messages.pop()
    while len(messages) > turn_start:
        messages.pop()


def start_chat() -> None:
    """Run the interactive chat loop.

    Loads the configuration and plugins, registers the built-in tools, prints
    the session banner, then reads prompts until the user enters ``/quit`` or
    ``/exit`` or interrupts the prompt (Ctrl-C / EOF). Slash commands are
    handled locally; any other input is sent to the model via ``_call_ai``.

    Returns early, after rendering the error, when the configuration file
    cannot be found.
    """
    try:
        cfg = load_config()
    except FileNotFoundError as exc:
        render_error(str(exc))
        return

    registry = PluginRegistry.instance()
    registry.load_all(pyra_home() / "plugins", cfg.plugins.enabled)
    executor = ToolExecutor(registry, console)
    planner = TaskPlanner(cfg, registry, executor)
    _register_builtin_tools(registry, planner)

    history = ConversationHistory(cfg, registry)
    plugin_slash = registry.get_slash_commands()
    all_commands = list(_STATIC_COMMANDS) + list(plugin_slash)
    session: PromptSession = PromptSession(
        history=FileHistory(str(_HISTORY_FILE)),
        completer=WordCompleter(all_commands, sentence=True),
        complete_while_typing=False,
        multiline=False,
    )

    provider = get_provider(cfg.ai.provider_id)
    render_system(
        f"[bold cyan]Pyra[/bold cyan] | {provider.display_name} | {cfg.ai.model}\n"
        "[dim]Type /help for commands, /quit to exit.[/dim]"
    )
    _report_local_model_status(cfg, provider)

    # Session-wide switches shared with _call_ai, which clears "use_tools"
    # when the model rejects function calling.
    flags: dict = {"use_tools": True}

    while True:
        try:
            user_input = session.prompt(" ").strip()
        except (KeyboardInterrupt, EOFError):
            console.print("\n[dim]Goodbye.[/dim]")
            break

        if not user_input:
            continue
        if user_input in ("/quit", "/exit"):
            console.print("[dim]Goodbye.[/dim]")
            break
        if user_input == "/clear":
            history.clear()
            render_info("Conversation cleared.")
            continue
        if user_input == "/help":
            _show_help(plugin_slash)
            continue
        if user_input == "/memory list":
            _show_memory_list()
            continue
        if user_input == "/config":
            from pyra.config.tui import launch_config_tui

            launch_config_tui()
            try:
                reloaded = load_config()
            except FileNotFoundError as exc:
                # Keep the previous configuration, but say so instead of
                # failing silently.
                render_error(f"Could not reload configuration: {exc}")
            else:
                previous_target = (cfg.ai.provider_id, cfg.ai.model)
                if (reloaded.ai.provider_id, reloaded.ai.model) != previous_target:
                    # A different provider/model may support function calling
                    # even if the previous one did not.
                    flags["use_tools"] = True
                cfg = reloaded
            # NOTE(review): history and planner were built from the earlier
            # config object and are not rebuilt here.
            continue
        if user_input in plugin_slash:
            try:
                plugin_slash[user_input]()
            except Exception as exc:
                render_error(f"Plugin command error: {exc}")
            continue
        if user_input.startswith("/"):
            render_error(f"Unknown command: {user_input!r}. Type /help for commands.")
            continue

        turn_start = len(history._messages)
        history.add_user(user_input)
        try:
            response_text = _call_ai(cfg, history, registry, executor, flags)
        except Exception as exc:
            render_error(f"AI error: {exc}")
            _rollback_failed_turn(history, turn_start)
            continue

        history.add_assistant(response_text)
        findings = scan_response(response_text)
        if findings:
            render_injection_warning(findings)
def _call_ai(
    cfg: PyraConfig,
    history: ConversationHistory,
    registry: PluginRegistry,
    executor: ToolExecutor,
    flags: dict | None = None,
    *,
    max_tool_iterations: int = 10,
) -> str:
    """Send the conversation to the configured model and render its reply.

    When tools are registered and enabled, runs a non-streaming tool-use loop:
    each round's tool calls are executed and their results appended to
    ``history`` until the model answers without tool calls. Otherwise the
    reply is streamed without tools.

    Args:
        cfg: Active configuration (provider id, model, optional base URL).
        history: Conversation history; tool-call and tool-result messages are
            appended to it during the tool loop.
        registry: Source of the tools offered to the model.
        executor: Executes the tool calls requested by the model.
        flags: Optional mutable session flags. ``flags["use_tools"]`` is read
            to decide whether to offer tools, and set to ``False`` when the
            model rejects the first tool-enabled request.
        max_tool_iterations: Upper bound on tool-loop rounds.

    Returns:
        The value returned by the renderer for the final reply.

    Raises:
        litellm.BadRequestError: When the request is rejected for a reason
            other than missing function-calling support: a context-window
            overflow, or any bad request after a tool-enabled request has
            already succeeded in this call.
    """
    from pyra.vault.reader import get_key

    provider = get_provider(cfg.ai.provider_id)
    api_key = get_key(cfg.ai.provider_id) if provider.requires_key else "local"
    base_kwargs: dict = {
        "model": f"{provider.litellm_prefix}{cfg.ai.model}",
        "api_key": api_key,
    }
    effective_base_url = cfg.ai.base_url or provider.base_url
    if effective_base_url:
        base_kwargs["api_base"] = effective_base_url
    litellm.suppress_debug_info = True

    def stream_plain_reply() -> str:
        """Stream a reply for the current history without offering tools."""
        stream = litellm.completion(
            **base_kwargs,
            messages=history.build_for_api(),
            stream=True,
        )
        return render_streaming_response(stream)

    tools_spec = [
        {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            },
        }
        for tool in registry.get_all_tools()
    ]

    # No tools active, or provider known not to support function calling.
    use_tools = flags is None or flags.get("use_tools", True)
    if not tools_spec or not use_tools:
        return stream_plain_reply()

    # Tool-use loop (non-streaming for tool calls, renders the final response).
    tools_accepted = False
    try:
        for _ in range(max_tool_iterations):
            response = litellm.completion(
                **base_kwargs,
                messages=history.build_for_api(),
                tools=tools_spec,
                tool_choice="auto",
                stream=False,
            )
            # The provider accepted a request carrying tools, so a later bad
            # request cannot mean "function calling unsupported".
            tools_accepted = True
            message = response.choices[0].message
            if not message.tool_calls:
                return render_text_response(message.content or "")
            history.add_tool_call_message(message)
            results = executor.execute_tool_call_batch(message.tool_calls)
            for result in results:
                history.add_tool_result(result["tool_call_id"], result["result"])
        return render_text_response("Error: tool-use loop exceeded maximum iterations.")
    except litellm.ContextWindowExceededError:
        # Subclass of BadRequestError; an oversized prompt says nothing about
        # tool support, so do not disable tools for it.
        raise
    except litellm.BadRequestError:
        if tools_accepted:
            raise
        if flags is not None:
            flags["use_tools"] = False
        render_info("This model does not support function calling — tools disabled.")
        return stream_plain_reply()
def _show_help(plugin_slash: dict) -> None:
    """Print the built-in slash commands, followed by any plugin commands.

    Args:
        plugin_slash: Plugin slash commands keyed by command name. Only the
            names are shown, in sorted order.
    """
    output = ["[bold]Slash commands:[/bold]"]
    output.extend(
        f" [cyan]{name:<20}[/cyan] {help_text}"
        for name, help_text in _STATIC_COMMANDS.items()
    )
    if plugin_slash:
        output.append("[bold]Plugin commands:[/bold]")
        output.extend(f" [cyan]{name:<20}[/cyan]" for name in sorted(plugin_slash))
    console.print("\n".join(output))
def _show_memory_list() -> None:
    """Print one line per memory file: name, category and modification date."""
    entries = list_memories()
    if entries:
        for entry in entries:
            stamp = entry.modified.strftime("%Y-%m-%d")
            console.print(f" [cyan]{entry.name:<40}[/cyan] {entry.category:<12} {stamp}")
        return
    render_info("No memory files found.")