root c62156af53 feat(backend): 数字员工平台 B1+B2 批次实现
B1: 项目脚手架 + 数据模型 + 租户管理
- Task 1.1: FastAPI 项目脚手架、SQLite + async SQLAlchemy
- Task 1.2: 7 个数据模型 (Tenant, TenantConfig, DigitalEmployee, Conversation, Message, KnowledgeBase, Document)
- Task 1.3: 租户 CRUD API + LLM 配置(含 API Key AES 加密)

B2: 数字员工配置 + LLM Provider 抽象层
- Task 2.1: 数字员工 CRUD API(关联知识库)
- Task 2.2: BaseLLMProvider 抽象接口 + OpenAI/Qwen Provider
- Task 2.3: Provider 动态实例化 + test-provider 端点

验证: 26 个测试全部通过

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-06 11:29:48 +08:00

197 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM Provider 抽象层测试"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.providers.base import LLMMessage, LLMResponse
@pytest.mark.asyncio
async def test_base_provider_interface():
    """Check that BaseLLMProvider exposes the chat, chat_stream and embed methods."""
    from app.providers.base import BaseLLMProvider

    # The provider interface must declare all three entry points.
    for method_name in ("chat", "chat_stream", "embed"):
        assert hasattr(BaseLLMProvider, method_name)
@pytest.mark.asyncio
async def test_llm_message_dataclass():
    """Check that LLMMessage keeps the role and content it was constructed with."""
    message = LLMMessage(role="user", content="Hello")

    assert (message.role, message.content) == ("user", "Hello")
@pytest.mark.asyncio
async def test_llm_response_dataclass():
    """Check that LLMResponse keeps its content, model name and usage mapping."""
    usage = {"total_tokens": 10}
    response = LLMResponse(content="Hi there!", model="gpt-4", usage=usage)

    assert (response.content, response.model) == ("Hi there!", "gpt-4")
    assert response.usage["total_tokens"] == 10
@pytest.mark.asyncio
async def test_openai_provider_chat():
    """Exercise OpenAIProvider.chat against a mocked SDK client."""
    from app.providers.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-test", model="gpt-4")

    # Fake completion exposing choices[0].message.content, usage and model,
    # i.e. the attributes the provider is expected to read back.
    fake_message = MagicMock(content="Test response")
    fake_completion = MagicMock()
    fake_completion.choices = [MagicMock(message=fake_message)]
    fake_completion.usage = MagicMock(
        total_tokens=100, prompt_tokens=50, completion_tokens=50
    )
    fake_completion.model = "gpt-4"

    # Replace the provider's client so no real request is made.
    with patch.object(provider, "client") as fake_client:
        fake_client.chat.completions.create = AsyncMock(return_value=fake_completion)
        result = await provider.chat([LLMMessage(role="user", content="Hi")])

    assert result.content == "Test response"
    assert result.model == "gpt-4"
@pytest.mark.asyncio
async def test_openai_provider_chat_stream():
    """Exercise OpenAIProvider.chat_stream against a mocked streaming SDK response."""
    from app.providers.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-test", model="gpt-4")

    async def fake_stream():
        # Every chunk shares one class-level `choices` list whose delta
        # carries the text "token"; two chunks are produced in total.
        class FakeChunk:
            choices = [MagicMock(delta=MagicMock(content="token"))]

        for _ in range(2):
            yield FakeChunk()

    with patch.object(provider, "client") as fake_client:
        fake_client.chat.completions.create = AsyncMock(return_value=fake_stream())
        # Drain the stream while the client is still patched.
        collected = [
            piece
            async for piece in provider.chat_stream(
                [LLMMessage(role="user", content="Hi")]
            )
        ]

    assert collected
@pytest.mark.asyncio
async def test_openai_provider_embed():
    """Exercise OpenAIProvider.embed against a mocked embeddings endpoint."""
    from app.providers.openai_provider import OpenAIProvider

    provider = OpenAIProvider(api_key="sk-test", model="text-embedding-3-small")

    # One 3-dimensional embedding for the single input string.
    fake_result = MagicMock()
    fake_result.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
    fake_result.usage = MagicMock(total_tokens=10)

    with patch.object(provider, "client") as fake_client:
        fake_client.embeddings.create = AsyncMock(return_value=fake_result)
        vectors = await provider.embed(["Hello world"])

    assert len(vectors) == 1
    assert len(vectors[0]) == 3
@pytest.mark.asyncio
async def test_qwen_provider_chat():
    """Exercise QwenProvider.chat (OpenAI-SDK-compatible endpoint) with a mocked client."""
    from app.providers.qwen_provider import QwenProvider

    provider = QwenProvider(
        api_key="sk-qwen-test",
        model="qwen-turbo",
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    # Only choices and usage are set explicitly; any other attribute the
    # provider reads (e.g. .model) comes back as an auto-created MagicMock.
    fake_completion = MagicMock()
    fake_completion.choices = [MagicMock(message=MagicMock(content="Qwen response"))]
    fake_completion.usage = MagicMock(total_tokens=80)

    with patch.object(provider, "client") as fake_client:
        fake_client.chat.completions.create = AsyncMock(return_value=fake_completion)
        answer = await provider.chat([LLMMessage(role="user", content="你好")])

    assert answer.content == "Qwen response"
@pytest.mark.asyncio
async def test_provider_registry():
    """Check that get_provider maps each provider key to its concrete class."""
    from app.providers import get_provider
    from app.providers.openai_provider import OpenAIProvider
    from app.providers.qwen_provider import QwenProvider

    # Compare against the real class objects rather than their __name__
    # strings: a same-named class from another module would pass a name check.
    # Exact type identity (not isinstance) is deliberate. QwenProvider may be
    # built on OpenAIProvider (it targets an OpenAI-compatible API), which is
    # an assumption to confirm. If it is, isinstance could not catch "openai"
    # being mis-mapped to the Qwen class.
    openai_provider = get_provider("openai", api_key="sk-test", model="gpt-4")
    assert type(openai_provider) is OpenAIProvider

    qwen_provider = get_provider("qwen", api_key="sk-qwen", model="qwen-turbo")
    assert type(qwen_provider) is QwenProvider
@pytest.mark.asyncio
async def test_provider_registry_invalid():
    """Check that get_provider rejects an unregistered provider key."""
    from app.providers import get_provider

    unknown_kind = "invalid_provider"
    # The registry must raise ValueError with a message matching "Unknown provider".
    with pytest.raises(ValueError, match="Unknown provider"):
        get_provider(unknown_kind, api_key="sk-test", model="model")
# ============= Task 2.3: Provider 集成到租户配置 =============
# AsyncClient is quoted because this module never imports it; an unquoted
# annotation is evaluated at definition time and would raise NameError.
@pytest.mark.asyncio
async def test_get_provider_for_tenant(client: "AsyncClient"):
    """Check that get_provider_for_tenant builds a provider from the tenant's config.

    Creates a tenant over the HTTP API and stores an OpenAI LLM config for it.
    It then resolves the provider through the service layer using a
    TestSession from the test conftest.
    """
    from app.providers.openai_provider import OpenAIProvider
    from app.services.tenant_service import get_provider_for_tenant
    from tests.conftest import TestSession

    # Create the tenant. Fail here with the response body, rather than with
    # a KeyError on "id", if creation is rejected.
    tenant_resp = await client.post(
        "/api/v1/tenants",
        json={"name": "Provider Test", "slug": "provider-test"},
    )
    assert 200 <= tenant_resp.status_code < 300, tenant_resp.text
    tenant_id = tenant_resp.json()["id"]

    # Store the LLM configuration the provider should be built from. A
    # silently failed update would make the assertions below misleading.
    config_resp = await client.put(
        f"/api/v1/tenants/{tenant_id}/config",
        json={
            "llm_provider": "openai",
            "llm_api_key": "sk-test-provider-key",
            "llm_model": "gpt-4o",
        },
    )
    assert 200 <= config_resp.status_code < 300, config_resp.text

    # Resolve the provider through the service layer.
    async with TestSession() as session:
        provider = await get_provider_for_tenant(session, tenant_id)

    assert type(provider) is OpenAIProvider
    assert provider.model == "gpt-4o"
# AsyncClient is quoted because this module never imports it; an unquoted
# annotation is evaluated at definition time and would raise NameError.
@pytest.mark.asyncio
async def test_test_provider_endpoint(client: "AsyncClient"):
    """Check that POST /tenants/{id}/test-provider reports the configured provider.

    NOTE(review): the endpoint implementation is not visible here. Depending
    on it, this may attempt a real LLM call with the fake key configured
    below. The test only asserts that a "status" field is present, not its
    value.
    """
    # Create the tenant. Fail here with the response body, rather than with
    # a KeyError on "id", if creation is rejected.
    tenant_resp = await client.post(
        "/api/v1/tenants",
        json={"name": "Provider Endpoint Test", "slug": "provider-endpoint-test"},
    )
    assert 200 <= tenant_resp.status_code < 300, tenant_resp.text
    tenant_id = tenant_resp.json()["id"]

    # Store the LLM configuration the endpoint should report back.
    config_resp = await client.put(
        f"/api/v1/tenants/{tenant_id}/config",
        json={
            "llm_provider": "openai",
            "llm_api_key": "sk-test-endpoint",
            "llm_model": "gpt-4",
        },
    )
    assert 200 <= config_resp.status_code < 300, config_resp.text

    # Call the provider test endpoint.
    response = await client.post(f"/api/v1/tenants/{tenant_id}/test-provider")
    assert response.status_code == 200
    data = response.json()
    assert data["provider"] == "openai"
    assert data["model"] == "gpt-4"
    assert "status" in data