diff --git a/src/pyra/setup/providers.py b/src/pyra/setup/providers.py new file mode 100644 index 0000000..6e97cfd --- /dev/null +++ b/src/pyra/setup/providers.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class Provider: + id: str + display_name: str + requires_key: bool + default_model: str + litellm_prefix: str + base_url: str | None = None + key_env_var: str | None = None + connectivity_check: str | None = None + group: str = "Cloud" + + +PROVIDERS: list[Provider] = [ + # ── Local ──────────────────────────────────────────────────────────────── + Provider( + id="lmstudio", + display_name="LM Studio (local)", + requires_key=False, + default_model="local-model", + litellm_prefix="openai/", + base_url="http://localhost:1234/v1", + connectivity_check="http://localhost:1234/v1/models", + group="Local", + ), + Provider( + id="ollama", + display_name="Ollama (local)", + requires_key=False, + default_model="llama3", + litellm_prefix="ollama/", + base_url="http://localhost:11434", + connectivity_check="http://localhost:11434/api/tags", + group="Local", + ), + Provider( + id="llamacpp", + display_name="llama.cpp server (local)", + requires_key=False, + default_model="local-model", + litellm_prefix="openai/", + base_url="http://localhost:8080/v1", + connectivity_check="http://localhost:8080/v1/models", + group="Local", + ), + # ── Cloud ──────────────────────────────────────────────────────────────── + Provider( + id="anthropic", + display_name="Anthropic (Claude)", + requires_key=True, + default_model="claude-sonnet-4-6", + litellm_prefix="anthropic/", + key_env_var="ANTHROPIC_API_KEY", + group="Cloud", + ), + Provider( + id="openai", + display_name="OpenAI (GPT)", + requires_key=True, + default_model="gpt-4o", + litellm_prefix="openai/", + key_env_var="OPENAI_API_KEY", + group="Cloud", + ), + Provider( + id="gemini", + display_name="Google (Gemini)", + requires_key=True, + default_model="gemini/gemini-2.0-flash", + litellm_prefix="gemini/", + key_env_var="GEMINI_API_KEY", + group="Cloud", + ), + Provider( + id="deepseek", + display_name="DeepSeek", + requires_key=True, + default_model="deepseek/deepseek-chat", + litellm_prefix="deepseek/", + key_env_var="DEEPSEEK_API_KEY", + group="Cloud", + ), + Provider( + id="qwen", + display_name="Qwen (Alibaba)", + requires_key=True, + default_model="openai/qwen-plus", + litellm_prefix="openai/", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + key_env_var="DASHSCOPE_API_KEY", + group="Cloud", + ), +] + +PROVIDERS_BY_ID: dict[str, Provider] = {p.id: p for p in PROVIDERS} + + +def get_provider(provider_id: str) -> Provider: + if provider_id not in PROVIDERS_BY_ID: + raise KeyError(f"Unknown provider: {provider_id!r}") + return PROVIDERS_BY_ID[provider_id] diff --git a/src/pyra/setup/wizard.py b/src/pyra/setup/wizard.py new file mode 100644 index 0000000..3cc1b6f --- /dev/null +++ b/src/pyra/setup/wizard.py @@ -0,0 +1,139 @@ +import httpx +import questionary +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + +from pyra.config.manager import save_config +from pyra.config.schema import ProviderConfig, PyraConfig +from pyra.setup.providers import PROVIDERS, Provider, get_provider + +console = Console() + + +def run_setup() -> None: + console.print(Panel( + Text("Welcome to Pyra Setup", justify="center", style="bold cyan"), + subtitle="Personal AI Assistant", + border_style="cyan", + )) + console.print() + + provider = _choose_provider() + model = _choose_model(provider) + + if provider.requires_key: + _collect_api_key(provider) + + _test_connection(provider, model) + + cfg = PyraConfig( + ai=ProviderConfig( + provider_id=provider.id, + model=model, + base_url=provider.base_url, + ) + ) + save_config(cfg) + + console.print() + console.print(Panel( + f"[green]Setup complete![/green]\n\n" + f"Provider: [bold]{provider.display_name}[/bold]\n" + f"Model: [bold]{model}[/bold]\n\n" + "Run [bold cyan]pyra chat[/bold cyan] to start talking.", + border_style="green", + )) + + +def _choose_provider() -> Provider: + local = [p for p in PROVIDERS if p.group == "Local"] + cloud = [p for p in PROVIDERS if p.group == "Cloud"] + + choices = ( + [questionary.Choice("── Local ──────────────────", disabled=True)] + + [questionary.Choice(p.display_name, value=p.id) for p in local] + + [questionary.Choice("── Cloud ──────────────────", disabled=True)] + + [questionary.Choice(p.display_name, value=p.id) for p in cloud] + ) + + provider_id = questionary.select( + "Choose your AI provider:", + choices=choices, + ).ask() + + if provider_id is None: + raise SystemExit(0) + + provider = get_provider(provider_id) + + if provider.connectivity_check: + _check_local_server(provider) + + return provider + + +def _check_local_server(provider: Provider) -> None: + console.print(f" Checking connection to [bold]{provider.display_name}[/bold]...", end=" ") + try: + resp = httpx.get(provider.connectivity_check, timeout=3.0) + resp.raise_for_status() + console.print("[green]✓[/green]") + except Exception: + console.print("[yellow]✗ (server not reachable)[/yellow]") + console.print( + f" [yellow]Warning:[/yellow] Could not reach {provider.base_url}.\n" + f" Make sure {provider.display_name} is running before using Pyra." + ) + + +def _choose_model(provider: Provider) -> str: + model = questionary.text( + "Model name:", + default=provider.default_model, + ).ask() + if model is None: + raise SystemExit(0) + return model.strip() + + +def _collect_api_key(provider: Provider) -> None: + from pyra.vault.writer import set_key + + console.print( + f"\n [dim]API key will be stored in the encrypted vault — never in config.yaml[/dim]" + ) + key = questionary.password(f"Enter your {provider.display_name} API key:").ask() + if key is None: + raise SystemExit(0) + key = key.strip() + if not key: + console.print("[red]No key entered — skipping.[/red]") + return + set_key(provider.id, key) + console.print(" [green]✓ Key stored in vault[/green]") + + +def _test_connection(provider: Provider, model: str) -> None: + from pyra.vault.reader import get_key + + console.print("\n Running test call...", end=" ") + try: + import litellm + + api_key = get_key(provider.id) if provider.requires_key else "no-key" + kwargs: dict = { + "model": f"{provider.litellm_prefix}{model}", + "messages": [{"role": "user", "content": "Reply with exactly: OK"}], + "max_tokens": 10, + } + if provider.base_url: + kwargs["api_base"] = provider.base_url + if api_key and api_key != "no-key": + kwargs["api_key"] = api_key + + litellm.completion(**kwargs) + console.print("[green]✓ Connection OK[/green]") + except Exception as exc: + console.print(f"[yellow]✗ Test call failed: {exc}[/yellow]") + console.print(" [dim]You can still proceed — check your config with 'pyra setup' again.[/dim]")