"""Tests for POST /chat."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from tests.conftest import ANTHROPIC_CONFIG, LMSTUDIO_CONFIG, OLLAMA_CONFIG

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Dotted patch targets: the config loader as referenced by the chat router,
# and the ``chat`` method of each provider class.
_LOAD_CONFIG = "app.routers.chat.load_ai_config"
_PROVIDER_CHAT = "app.providers.openai_compat.OpenAICompatProvider.chat"
_ANTHROPIC_CHAT = "app.providers.anthropic_provider.AnthropicProvider.chat"

# Request payload fixtures: a bare user turn, and one preceded by a system turn.
MESSAGES = [{"role": "user", "content": "Hello"}]
SYSTEM_MESSAGES = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hello"},
]


def _mock_chat_response(content="ok", input_tokens=10, output_tokens=5):
    """Return an ``AsyncMock`` whose awaited result is the reply triple.

    The triple is ``(content, input_tokens, output_tokens)``.
    """
    reply = (content, input_tokens, output_tokens)
    return AsyncMock(return_value=reply)


# ---------------------------------------------------------------------------
# Success: each provider
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_chat_lmstudio_success(ai_client):
    """LM Studio config: 200 with reply, provider, model and token counts."""
    config_patch = patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG)
    chat_patch = patch(
        _PROVIDER_CHAT, new=_mock_chat_response("lmstudio reply")
    )
    with config_patch, chat_patch:
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 200
    body = resp.json()
    assert body["content"] == "lmstudio reply"
    assert body["provider"] == "lmstudio"
    assert body["model"] == "test-model"
    assert body["input_tokens"] == 10
    assert body["output_tokens"] == 5


@pytest.mark.asyncio
async def test_chat_ollama_success(ai_client):
    """Ollama config: 200 with the mocked reply and ``provider == "ollama"``."""
    config_patch = patch(_LOAD_CONFIG, return_value=OLLAMA_CONFIG)
    chat_patch = patch(_PROVIDER_CHAT, new=_mock_chat_response("ollama reply"))
    with config_patch, chat_patch:
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 200
    body = resp.json()
    assert body["content"] == "ollama reply"
    assert body["provider"] == "ollama"


@pytest.mark.asyncio
async def test_chat_anthropic_success(ai_client):
    """Anthropic config: 200 with the mocked reply and matching provider."""
    config_patch = patch(_LOAD_CONFIG, return_value=ANTHROPIC_CONFIG)
    chat_patch = patch(
        _ANTHROPIC_CHAT, new=_mock_chat_response("anthropic reply")
    )
    with config_patch, chat_patch:
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 200
    body = resp.json()
    assert body["content"] == "anthropic reply"
    assert body["provider"] == "anthropic"


# ---------------------------------------------------------------------------
# response_format
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_response_format_json_strips_fences(ai_client):
    """With ``response_format="json"`` the Markdown fence is removed."""
    raw_reply = "```json\n{\"key\": \"value\"}\n```"
    payload = {"messages": MESSAGES, "response_format": "json"}
    config_patch = patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG)
    chat_patch = patch(_PROVIDER_CHAT, new=_mock_chat_response(raw_reply))
    with config_patch, chat_patch:
        resp = await ai_client.post("/chat", json=payload)

    assert resp.status_code == 200
    assert resp.json()["content"] == '{"key": "value"}'


@pytest.mark.asyncio
async def test_response_format_text_preserves_fences(ai_client):
    """With ``response_format="text"`` the fence markers stay in the content."""
    raw_reply = "```python\nprint('hi')\n```"
    payload = {"messages": MESSAGES, "response_format": "text"}
    config_patch = patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG)
    chat_patch = patch(_PROVIDER_CHAT, new=_mock_chat_response(raw_reply))
    with config_patch, chat_patch:
        resp = await ai_client.post("/chat", json=payload)

    assert resp.status_code == 200
    assert "```" in resp.json()["content"]


# ---------------------------------------------------------------------------
# Validation errors
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_chat_missing_messages_returns_422(ai_client):
    """A body without a ``messages`` key is rejected with 422."""
    empty_body = {}
    resp = await ai_client.post("/chat", json=empty_body)
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_empty_messages_returns_422(ai_client):
    """An empty ``messages`` list is rejected with 422."""
    no_turns = {"messages": []}
    resp = await ai_client.post("/chat", json=no_turns)
    assert resp.status_code == 422


# ---------------------------------------------------------------------------
# Provider errors
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_chat_connection_error_returns_502(ai_client):
    """A ``ProviderConnectionError`` raised by the provider yields a 502."""
    from app.providers.openai_compat import ProviderConnectionError

    with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(
        _PROVIDER_CHAT, side_effect=ProviderConnectionError("refused")
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_chat_timeout_returns_504(ai_client):
    """A provider call outlasting ``timeout_seconds`` yields a 504."""

    async def _slow(*_args, **_kwargs):
        # Sleeps far longer than the 0.01 s timeout configured below.
        # NOTE(review): presumably the router cancels the call once
        # ``timeout_seconds`` elapses; that code is not visible here.
        await asyncio.sleep(100)

    short_timeout_config = {**LMSTUDIO_CONFIG, "timeout_seconds": 0.01}
    with patch(_LOAD_CONFIG, return_value=short_timeout_config), patch(
        _PROVIDER_CHAT, new=_slow
    ):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 504


@pytest.mark.asyncio
async def test_chat_unknown_provider_returns_503(ai_client):
    """A config naming an unrecognised provider yields a 503."""
    bad_config = {**LMSTUDIO_CONFIG, "provider": "unknown-llm"}
    with patch(_LOAD_CONFIG, return_value=bad_config):
        resp = await ai_client.post("/chat", json={"messages": MESSAGES})

    assert resp.status_code == 503


# ---------------------------------------------------------------------------
# Anthropic system message extraction
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_anthropic_system_message_extracted(ai_client):
    """System-role messages must not appear in the messages sent to Anthropic.

    ``anthropic.AsyncAnthropic.messages`` is replaced with a stand-in whose
    ``create`` records the keyword arguments it receives; the recorded
    ``messages`` list is then checked for ``system`` roles.
    """
    captured_kwargs: dict = {}

    async def _fake_create(**kwargs):
        # Record what would have gone to the SDK and return an object that
        # exposes ``content[0].text`` and ``usage.input_tokens/output_tokens``.
        captured_kwargs.update(kwargs)
        mock_resp = MagicMock()
        mock_resp.content = [MagicMock(text="ok")]
        mock_resp.usage = MagicMock(input_tokens=5, output_tokens=2)
        return mock_resp

    with patch(_LOAD_CONFIG, return_value=ANTHROPIC_CONFIG), patch(
        "anthropic.AsyncAnthropic.messages",
        new_callable=lambda: type(
            "Messages",
            (),
            {"create": staticmethod(AsyncMock(side_effect=_fake_create))},
        ),
    ):
        # The response itself is deliberately not inspected.
        await ai_client.post("/chat", json={"messages": SYSTEM_MESSAGES})

    # Whether the call succeeded or not, no system role should reach the
    # messages list.
    # NOTE(review): if ``_fake_create`` is never invoked this check is
    # skipped and the test passes vacuously; consider asserting that the
    # capture happened once the patch target is confirmed to be hit.
    if "messages" in captured_kwargs:
        roles = [m["role"] for m in captured_kwargs["messages"]]
        assert "system" not in roles


# ---------------------------------------------------------------------------
# Parameter forwarding
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_max_tokens_and_temperature_forwarded(ai_client):
    """``max_tokens`` and ``temperature`` from the request reach the provider."""
    captured: dict = {}

    async def _capture(*args, **kwargs):
        # ``patch(..., new=<plain function>)`` stores this function on the
        # class, so when it is looked up on a provider instance Python binds
        # it and passes that instance as an extra leading positional
        # argument. A fixed ``(messages, max_tokens, temperature)`` signature
        # would then mis-bind or raise TypeError. Accept any call shape:
        # prefer keywords, otherwise take the trailing positionals
        # ``(..., max_tokens, temperature)``.
        positional = list(args)
        if "temperature" in kwargs:
            temperature = kwargs["temperature"]
        else:
            temperature = positional.pop()
        if "max_tokens" in kwargs:
            max_tokens = kwargs["max_tokens"]
        else:
            max_tokens = positional.pop()
        captured["max_tokens"] = max_tokens
        captured["temperature"] = temperature
        return ("ok", 1, 1)

    with patch(_LOAD_CONFIG, return_value=LMSTUDIO_CONFIG), patch(
        _PROVIDER_CHAT, new=_capture
    ):
        resp = await ai_client.post(
            "/chat",
            json={"messages": MESSAGES, "max_tokens": 512, "temperature": 0.7},
        )

    assert resp.status_code == 200
    assert captured["max_tokens"] == 512
    assert captured["temperature"] == pytest.approx(0.7)