"""LLM Provider 抽象层测试""" import pytest from unittest.mock import AsyncMock, MagicMock, patch from app.providers.base import LLMMessage, LLMResponse @pytest.mark.asyncio async def test_base_provider_interface(): """验证 BaseLLMProvider 接口定义""" from app.providers.base import BaseLLMProvider # 接口应该定义 chat, chat_stream, embed 方法 assert hasattr(BaseLLMProvider, "chat") assert hasattr(BaseLLMProvider, "chat_stream") assert hasattr(BaseLLMProvider, "embed") @pytest.mark.asyncio async def test_llm_message_dataclass(): """验证 LLMMessage 数据类""" msg = LLMMessage(role="user", content="Hello") assert msg.role == "user" assert msg.content == "Hello" @pytest.mark.asyncio async def test_llm_response_dataclass(): """验证 LLMResponse 数据类""" resp = LLMResponse(content="Hi there!", model="gpt-4", usage={"total_tokens": 10}) assert resp.content == "Hi there!" assert resp.model == "gpt-4" assert resp.usage["total_tokens"] == 10 @pytest.mark.asyncio async def test_openai_provider_chat(): """测试 OpenAI Provider chat 方法(mock)""" from app.providers.openai_provider import OpenAIProvider provider = OpenAIProvider(api_key="sk-test", model="gpt-4") with patch.object(provider, "client") as mock_client: mock_response = MagicMock() mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))] mock_response.usage = MagicMock(total_tokens=100, prompt_tokens=50, completion_tokens=50) mock_response.model = "gpt-4" mock_client.chat.completions.create = AsyncMock(return_value=mock_response) result = await provider.chat([LLMMessage(role="user", content="Hi")]) assert result.content == "Test response" assert result.model == "gpt-4" @pytest.mark.asyncio async def test_openai_provider_chat_stream(): """测试 OpenAI Provider chat_stream 方法(mock)""" from app.providers.openai_provider import OpenAIProvider provider = OpenAIProvider(api_key="sk-test", model="gpt-4") async def mock_stream(): class MockChunk: choices = [MagicMock(delta=MagicMock(content="token"))] yield MockChunk() yield MockChunk() with patch.object(provider, "client") as mock_client: mock_client.chat.completions.create = AsyncMock(return_value=mock_stream()) tokens = [] async for chunk in provider.chat_stream([LLMMessage(role="user", content="Hi")]): tokens.append(chunk) assert len(tokens) >= 1 @pytest.mark.asyncio async def test_openai_provider_embed(): """测试 OpenAI Provider embed 方法(mock)""" from app.providers.openai_provider import OpenAIProvider provider = OpenAIProvider(api_key="sk-test", model="text-embedding-3-small") with patch.object(provider, "client") as mock_client: mock_response = MagicMock() mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] mock_response.usage = MagicMock(total_tokens=10) mock_client.embeddings.create = AsyncMock(return_value=mock_response) result = await provider.embed(["Hello world"]) assert len(result) == 1 assert len(result[0]) == 3 @pytest.mark.asyncio async def test_qwen_provider_chat(): """测试 Qwen Provider chat 方法(mock,兼容 OpenAI SDK)""" from app.providers.qwen_provider import QwenProvider provider = QwenProvider( api_key="sk-qwen-test", model="qwen-turbo", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", ) with patch.object(provider, "client") as mock_client: mock_response = MagicMock() mock_response.choices = [MagicMock(message=MagicMock(content="Qwen response"))] mock_response.usage = MagicMock(total_tokens=80) mock_client.chat.completions.create = AsyncMock(return_value=mock_response) result = await provider.chat([LLMMessage(role="user", content="你好")]) assert result.content == "Qwen response" @pytest.mark.asyncio async def test_provider_registry(): """测试 Provider 注册表""" from app.providers import get_provider # 根据 provider 类型返回正确的实例 openai_provider = get_provider("openai", api_key="sk-test", model="gpt-4") assert openai_provider.__class__.__name__ == "OpenAIProvider" qwen_provider = get_provider("qwen", api_key="sk-qwen", model="qwen-turbo") assert qwen_provider.__class__.__name__ == "QwenProvider" @pytest.mark.asyncio async def test_provider_registry_invalid(): """测试 Provider 注册表无效类型""" from app.providers import get_provider with pytest.raises(ValueError, match="Unknown provider"): get_provider("invalid_provider", api_key="sk-test", model="model") # ============= Task 2.3: Provider 集成到租户配置 ============= @pytest.mark.asyncio async def test_get_provider_for_tenant(client: AsyncClient): """测试根据租户配置获取 Provider""" from app.services.tenant_service import get_provider_for_tenant from tests.conftest import TestSession # 创建租户并配置 tenant_resp = await client.post( "/api/v1/tenants", json={"name": "Provider Test", "slug": "provider-test"}, ) tenant_id = tenant_resp.json()["id"] await client.put( f"/api/v1/tenants/{tenant_id}/config", json={ "llm_provider": "openai", "llm_api_key": "sk-test-provider-key", "llm_model": "gpt-4o", }, ) # 获取 Provider async with TestSession() as session: provider = await get_provider_for_tenant(session, tenant_id) assert provider.__class__.__name__ == "OpenAIProvider" assert provider.model == "gpt-4o" @pytest.mark.asyncio async def test_test_provider_endpoint(client: AsyncClient): """测试 Provider 测试端点""" # 创建租户并配置 tenant_resp = await client.post( "/api/v1/tenants", json={"name": "Provider Endpoint Test", "slug": "provider-endpoint-test"}, ) tenant_id = tenant_resp.json()["id"] await client.put( f"/api/v1/tenants/{tenant_id}/config", json={ "llm_provider": "openai", "llm_api_key": "sk-test-endpoint", "llm_model": "gpt-4", }, ) # 测试 Provider response = await client.post(f"/api/v1/tenants/{tenant_id}/test-provider") assert response.status_code == 200 data = response.json() assert data["provider"] == "openai" assert data["model"] == "gpt-4" assert "status" in data