B1: 项目脚手架 + 数据模型 + 租户管理
- Task 1.1: FastAPI 项目脚手架、SQLite + async SQLAlchemy
- Task 1.2: 7 个数据模型 (Tenant, TenantConfig, DigitalEmployee, Conversation, Message, KnowledgeBase, Document)
- Task 1.3: 租户 CRUD API + LLM 配置(含 API Key AES 加密)

B2: 数字员工配置 + LLM Provider 抽象层
- Task 2.1: 数字员工 CRUD API(关联知识库)
- Task 2.2: BaseLLMProvider 抽象接口 + OpenAI/Qwen Provider
- Task 2.3: Provider 动态实例化 + test-provider 端点

验证: 26 个测试全部通过

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
197 lines
6.5 KiB
Python
"""LLM Provider 抽象层测试"""
|
||
import pytest
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from app.providers.base import LLMMessage, LLMResponse
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_base_provider_interface():
    """Check that BaseLLMProvider exposes the chat, chat_stream and embed members."""
    from app.providers.base import BaseLLMProvider

    # The abstract interface must declare all three entry points; checked in
    # the same order as before so the first missing one is reported.
    for member_name in ("chat", "chat_stream", "embed"):
        assert hasattr(BaseLLMProvider, member_name)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_message_dataclass():
    """Check that LLMMessage keeps the role and content it was built with."""
    message = LLMMessage(role="user", content="Hello")

    assert (message.role, message.content) == ("user", "Hello")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_response_dataclass():
    """Check that LLMResponse keeps its content, model name and usage mapping."""
    token_usage = {"total_tokens": 10}
    response = LLMResponse(content="Hi there!", model="gpt-4", usage=token_usage)

    assert response.content == "Hi there!"
    assert response.model == "gpt-4"
    assert response.usage["total_tokens"] == 10
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_openai_provider_chat():
    """Check OpenAIProvider.chat against a mocked chat-completions client.

    The SDK client is replaced so no network call is made; the fake response
    mimics the attributes the provider reads (choices[0].message.content,
    usage, model).
    """
    from app.providers.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-test", model="gpt-4")

    # Fake completion object, assembled before the client is patched.
    fake_message = MagicMock(content="Test response")
    fake_completion = MagicMock()
    fake_completion.choices = [MagicMock(message=fake_message)]
    fake_completion.usage = MagicMock(
        total_tokens=100, prompt_tokens=50, completion_tokens=50
    )
    fake_completion.model = "gpt-4"

    with patch.object(provider, "client") as client_mock:
        client_mock.chat.completions.create = AsyncMock(return_value=fake_completion)

        reply = await provider.chat([LLMMessage(role="user", content="Hi")])

        assert reply.content == "Test response"
        assert reply.model == "gpt-4"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_openai_provider_chat_stream():
    """Check that OpenAIProvider.chat_stream yields output for a mocked stream.

    The patched client's `chat.completions.create` resolves to an async
    generator that produces two chunks whose choices[0].delta.content is "token".
    """
    from app.providers.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-test", model="gpt-4")

    class _FakeChunk:
        # Shared by both yielded chunks; mimics a streaming delta chunk.
        choices = [MagicMock(delta=MagicMock(content="token"))]

    async def fake_stream():
        for _ in range(2):
            yield _FakeChunk()

    with patch.object(provider, "client") as client_mock:
        client_mock.chat.completions.create = AsyncMock(return_value=fake_stream())

        prompt = [LLMMessage(role="user", content="Hi")]
        received = [piece async for piece in provider.chat_stream(prompt)]

        # NOTE(review): the fake stream emits two chunks, but this only checks
        # that something came out; tighten it once the exact item type yielded
        # by chat_stream is confirmed.
        assert received
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_openai_provider_embed():
    """Check that OpenAIProvider.embed returns the mocked embedding vector."""
    from app.providers.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-test", model="text-embedding-3-small")

    # One fake embedding of dimension 3, plus a usage object.
    fake_embeddings = MagicMock()
    fake_embeddings.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
    fake_embeddings.usage = MagicMock(total_tokens=10)

    with patch.object(provider, "client") as client_mock:
        client_mock.embeddings.create = AsyncMock(return_value=fake_embeddings)

        vectors = await provider.embed(["Hello world"])

        assert len(vectors) == 1
        assert len(vectors[0]) == 3
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_qwen_provider_chat():
    """Check QwenProvider.chat against a mocked OpenAI-compatible client."""
    from app.providers.qwen_provider import QwenProvider

    qwen_settings = {
        "api_key": "sk-qwen-test",
        "model": "qwen-turbo",
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    }
    provider = QwenProvider(**qwen_settings)

    with patch.object(provider, "client") as client_mock:
        fake_completion = MagicMock()
        fake_completion.choices = [MagicMock(message=MagicMock(content="Qwen response"))]
        fake_completion.usage = MagicMock(total_tokens=80)
        client_mock.chat.completions.create = AsyncMock(return_value=fake_completion)

        reply = await provider.chat([LLMMessage(role="user", content="你好")])

        assert reply.content == "Qwen response"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_provider_registry():
    """Check that get_provider returns the provider class registered for each key."""
    from app.providers import get_provider

    # (provider key, api key, model, expected class name), checked in order.
    cases = (
        ("openai", "sk-test", "gpt-4", "OpenAIProvider"),
        ("qwen", "sk-qwen", "qwen-turbo", "QwenProvider"),
    )
    for provider_key, secret, model_name, expected_class in cases:
        instance = get_provider(provider_key, api_key=secret, model=model_name)
        assert instance.__class__.__name__ == expected_class
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_provider_registry_invalid():
    """Check that get_provider raises ValueError ("Unknown provider") for unknown keys."""
    from app.providers import get_provider

    credentials = dict(api_key="sk-test", model="model")

    with pytest.raises(ValueError, match="Unknown provider"):
        get_provider("invalid_provider", **credentials)
|
||
|
||
|
||
# ============= Task 2.3: Provider 集成到租户配置 =============
|
||
|
||
|
||
# `AsyncClient` is not among this module's imports, and annotations are
# evaluated when the function is defined. An unquoted annotation therefore
# raises NameError as soon as pytest imports this module. The string form keeps
# the hint without evaluating it.
# TODO: import it at module level (presumably `from httpx import AsyncClient`).
@pytest.mark.asyncio
async def test_get_provider_for_tenant(client: "AsyncClient"):
    """Check that get_provider_for_tenant builds the provider from a tenant's config.

    Creates a tenant through the API and stores an OpenAI LLM config for it.
    It then resolves the provider directly through the service layer using a
    fresh test session.

    Args:
        client: HTTP test client fixture (presumably defined in tests/conftest.py).
    """
    from app.services.tenant_service import get_provider_for_tenant
    from tests.conftest import TestSession

    # Create the tenant. Fail fast with the response body so a setup failure
    # is not reported as a confusing KeyError on "id".
    tenant_resp = await client.post(
        "/api/v1/tenants",
        json={"name": "Provider Test", "slug": "provider-test"},
    )
    assert 200 <= tenant_resp.status_code < 300, tenant_resp.text
    tenant_id = tenant_resp.json()["id"]

    # Store the LLM configuration; a rejected update would otherwise make the
    # provider assertions below fail for a misleading reason.
    config_resp = await client.put(
        f"/api/v1/tenants/{tenant_id}/config",
        json={
            "llm_provider": "openai",
            "llm_api_key": "sk-test-provider-key",
            "llm_model": "gpt-4o",
        },
    )
    assert 200 <= config_resp.status_code < 300, config_resp.text

    # Resolve the provider through the service layer.
    async with TestSession() as session:
        provider = await get_provider_for_tenant(session, tenant_id)
        assert provider.__class__.__name__ == "OpenAIProvider"
        assert provider.model == "gpt-4o"
|
||
|
||
|
||
# `AsyncClient` is not among this module's imports, and annotations are
# evaluated when the function is defined. An unquoted annotation therefore
# raises NameError as soon as pytest imports this module. The string form keeps
# the hint without evaluating it.
# TODO: import it at module level (presumably `from httpx import AsyncClient`).
@pytest.mark.asyncio
async def test_test_provider_endpoint(client: "AsyncClient"):
    """Check that POST /tenants/{id}/test-provider reports the configured provider.

    Args:
        client: HTTP test client fixture (presumably defined in tests/conftest.py).
    """
    # Create the tenant. Fail fast with the response body if setup breaks.
    tenant_resp = await client.post(
        "/api/v1/tenants",
        json={"name": "Provider Endpoint Test", "slug": "provider-endpoint-test"},
    )
    assert 200 <= tenant_resp.status_code < 300, tenant_resp.text
    tenant_id = tenant_resp.json()["id"]

    config_resp = await client.put(
        f"/api/v1/tenants/{tenant_id}/config",
        json={
            "llm_provider": "openai",
            "llm_api_key": "sk-test-endpoint",
            "llm_model": "gpt-4",
        },
    )
    assert 200 <= config_resp.status_code < 300, config_resp.text

    # NOTE(review): nothing is mocked here. If the endpoint really calls the LLM
    # with the fake key, this test may depend on network access. That is why
    # only the presence of "status" (not its value) is asserted. Consider
    # patching the provider to make the test hermetic.
    response = await client.post(f"/api/v1/tenants/{tenant_id}/test-provider")
    assert response.status_code == 200
    data = response.json()
    assert data["provider"] == "openai"
    assert data["model"] == "gpt-4"
    assert "status" in data
|