- 添加 stream_chat 生成器函数处理 SSE 事件流
- 实现 message_start / token / message_end 事件格式
- 添加 messages/stream SSE 端点
- 构建 LLM 消息列表(system prompt + 历史 + 用户消息)
- 持久化用户和 assistant 消息到数据库
- 添加 SSE 测试用 mock Provider

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
297 lines
10 KiB
Python
297 lines
10 KiB
Python
import pytest
|
|
from httpx import AsyncClient
|
|
|
|
from app.main import app
|
|
|
|
|
|
@pytest.fixture
async def tenant_id(async_client: AsyncClient) -> str:
    """Create a tenant, configure its LLM provider, and return the tenant ID.

    The LLM config uses placeholder credentials. They are presumably never used
    for a real call, since the streaming tests in this module patch the
    provider lookup; re-check this if new tests stop patching it.

    Raises:
        httpx.HTTPStatusError: if tenant creation or the config update is
            rejected. Setup failures then surface here, instead of as a
            confusing ``KeyError`` on ``["id"]`` or as a silently
            unconfigured tenant.
    """
    resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"})
    resp.raise_for_status()
    tid = resp.json()["id"]

    # Configure tenant with LLM provider (mock settings). The response used to
    # be discarded, so a rejected config went unnoticed.
    config_resp = await async_client.put(
        f"/api/v1/tenants/{tid}/config",
        json={
            "llm_provider": "openai",
            "llm_api_key": "test-api-key",
            "llm_model": "gpt-4",
        },
    )
    config_resp.raise_for_status()

    return tid
|
|
|
|
|
|
@pytest.fixture
async def employee_id(async_client: AsyncClient, tenant_id: str) -> str:
    """Create an employee under ``tenant_id`` and return its ID.

    The employee carries a fixed system prompt ("You are a helpful
    assistant."), which the conversation endpoints can pick up.

    Raises:
        httpx.HTTPStatusError: if the employee cannot be created. The fixture
            then fails at setup, rather than with a ``KeyError`` on ``["id"]``.
    """
    resp = await async_client.post(
        "/api/v1/employees",
        json={
            "tenant_id": tenant_id,
            "name": "Test Employee",
            "role": "assistant",
            "system_prompt": "You are a helpful assistant.",
        },
    )
    resp.raise_for_status()
    return resp.json()["id"]
|
|
|
|
|
|
class TestConversationCRUD:
    """CRUD tests for the ``/api/v1/conversations`` endpoints."""

    @staticmethod
    async def _post_conversation(
        client: AsyncClient, tenant_id: str, employee_id: str, user_id: str, **extra
    ):
        """POST a conversation for *user_id* and return the raw HTTP response.

        Extra keyword arguments (such as ``title``) are merged into the JSON
        body after the three required identifiers.
        """
        payload = {"tenant_id": tenant_id, "employee_id": employee_id, "user_id": user_id}
        payload.update(extra)
        return await client.post("/api/v1/conversations", json=payload)

    @pytest.mark.asyncio
    async def test_create_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
        """Creating a conversation returns 201 and echoes the submitted fields."""
        created = await self._post_conversation(
            async_client, tenant_id, employee_id, "user-001", title="Test Conversation"
        )
        assert created.status_code == 201

        body = created.json()
        submitted = {
            "tenant_id": tenant_id,
            "employee_id": employee_id,
            "user_id": "user-001",
            "title": "Test Conversation",
        }
        for key, value in submitted.items():
            assert body[key] == value
        # Server-generated fields must be present as well.
        for key in ("id", "created_at"):
            assert key in body

    @pytest.mark.asyncio
    async def test_list_conversations(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
        """Listing by tenant returns both conversations created for it."""
        for user, title in (("user-001", "Conv 1"), ("user-002", "Conv 2")):
            await self._post_conversation(async_client, tenant_id, employee_id, user, title=title)

        listing = await async_client.get("/api/v1/conversations", params={"tenant_id": tenant_id})
        assert listing.status_code == 200
        assert len(listing.json()) == 2

    @pytest.mark.asyncio
    async def test_get_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
        """A conversation can be fetched back by its ID."""
        created = await self._post_conversation(
            async_client, tenant_id, employee_id, "user-001", title="Test Conv"
        )
        conv_id = created.json()["id"]

        fetched = await async_client.get(f"/api/v1/conversations/{conv_id}")
        assert fetched.status_code == 200
        body = fetched.json()
        assert body["id"] == conv_id
        assert body["title"] == "Test Conv"

    @pytest.mark.asyncio
    async def test_delete_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
        """Deleting returns 204, after which a GET for the same ID returns 404."""
        created = await self._post_conversation(async_client, tenant_id, employee_id, "user-001")
        conv_id = created.json()["id"]
        url = f"/api/v1/conversations/{conv_id}"

        deleted = await async_client.delete(url)
        assert deleted.status_code == 204

        # A follow-up read must no longer find it.
        assert (await async_client.get(url)).status_code == 404
|
|
|
|
|
|
class TestMessageOperations:
    """Test message operations within a conversation."""

    @pytest.fixture
    async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str:
        """Create a conversation and return its ID."""
        resp = await async_client.post(
            "/api/v1/conversations",
            json={
                "tenant_id": tenant_id,
                "employee_id": employee_id,
                "user_id": "user-001",
            },
        )
        return resp.json()["id"]

    @pytest.mark.asyncio
    async def test_send_message(self, async_client: AsyncClient, conversation_id: str):
        """Send a message to a conversation."""
        resp = await async_client.post(
            f"/api/v1/conversations/{conversation_id}/messages",
            json={"content": "Hello, how are you?"},
        )
        # For now, we expect 200 with the created user message.
        # Later with SSE streaming, this will change.
        assert resp.status_code == 200
        data = resp.json()
        assert data["role"] == "user"
        assert data["content"] == "Hello, how are you?"
        assert "id" in data
        assert data["conversation_id"] == conversation_id

    @pytest.mark.asyncio
    async def test_list_messages(self, async_client: AsyncClient, conversation_id: str):
        """List all messages in a conversation, oldest first."""
        for content in ("First message", "Second message"):
            post_resp = await async_client.post(
                f"/api/v1/conversations/{conversation_id}/messages",
                json={"content": content},
            )
            # Fail at setup if a message was not stored, not at the listing check.
            assert post_resp.status_code == 200

        resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) >= 2

        # ">= 2" allows the listing to contain more than our two posts (e.g.
        # assistant replies), so positional indexing into `data` is unsafe.
        # Compare only user-authored messages, which should come back ordered
        # by created_at.
        user_contents = [message["content"] for message in data if message["role"] == "user"]
        assert user_contents == ["First message", "Second message"]

    @pytest.mark.asyncio
    async def test_get_conversation_with_messages(
        self, async_client: AsyncClient, conversation_id: str
    ):
        """Get conversation details after a message has been added."""
        await async_client.post(
            f"/api/v1/conversations/{conversation_id}/messages",
            json={"content": "Test message"},
        )

        resp = await async_client.get(f"/api/v1/conversations/{conversation_id}")
        assert resp.status_code == 200
        # Posting a message must not disturb the conversation record itself.
        assert resp.json()["id"] == conversation_id
        # The response may or may not include messages depending on schema;
        # for now we only verify the conversation is still accessible.
|
|
|
|
|
|
class TestSSEStreaming:
    """Test SSE streaming message responses."""

    @pytest.fixture
    async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str:
        """Create a conversation and return its ID."""
        resp = await async_client.post(
            "/api/v1/conversations",
            json={
                "tenant_id": tenant_id,
                "employee_id": employee_id,
                "user_id": "user-001",
            },
        )
        return resp.json()["id"]

    @staticmethod
    def _mock_provider(tokens):
        """Return a stand-in LLM provider whose ``chat_stream`` yields *tokens*.

        ``chat_stream`` is an async generator function. This mirrors how the
        service is assumed to consume it (``async for`` over text chunks);
        confirm against ``conversation_service``. Any extra positional or
        keyword arguments are accepted and ignored, so the mock keeps working
        if the service passes options besides the message list.
        """
        from unittest.mock import MagicMock

        async def chat_stream(messages, *args, **kwargs):
            for token in tokens:
                yield token

        provider = MagicMock()
        provider.chat_stream = chat_stream
        return provider

    @staticmethod
    def _patch_provider(provider):
        """Return a patcher making the service's provider lookup resolve to *provider*."""
        from unittest.mock import AsyncMock, patch

        return patch(
            "app.services.conversation_service.get_provider_for_tenant",
            AsyncMock(return_value=provider),
        )

    @pytest.mark.asyncio
    async def test_sse_stream_events(self, async_client: AsyncClient, conversation_id: str):
        """Test SSE stream returns properly formatted events."""
        import json

        data_prefix = "data: "
        with self._patch_provider(self._mock_provider(["Hello", " ", "world", "!"])):
            async with async_client.stream(
                "POST",
                f"/api/v1/conversations/{conversation_id}/messages/stream",
                json={"content": "Hello"},
            ) as response:
                assert response.status_code == 200
                # Compare only the media type, so the test does not depend on
                # how (or whether) a charset parameter is appended to the header.
                assert response.headers["content-type"].startswith("text/event-stream")

                events = [
                    line[len(data_prefix):]
                    async for line in response.aiter_lines()
                    if line.startswith(data_prefix)
                ]

        # message_start, at least one token event, then message_end.
        assert len(events) >= 3

        payloads = [json.loads(event) for event in events]
        # Every event is expected to be a typed JSON object.
        assert all(isinstance(payload, dict) and "type" in payload for payload in payloads)

        assert payloads[0]["type"] == "message_start"
        assert payloads[-1]["type"] == "message_end"
        assert "token_count" in payloads[-1]

    @pytest.mark.asyncio
    async def test_sse_stream_persists_messages(
        self, async_client: AsyncClient, conversation_id: str
    ):
        """Test that SSE stream persists both user and assistant messages."""
        tokens = ["Test", " ", "response"]
        with self._patch_provider(self._mock_provider(tokens)):
            async with async_client.stream(
                "POST",
                f"/api/v1/conversations/{conversation_id}/messages/stream",
                json={"content": "Test question"},
            ) as response:
                assert response.status_code == 200
                # Drain the stream so the server-side generator runs to completion
                # (presumably where the assistant reply gets persisted).
                async for _ in response.aiter_lines():
                    pass

        # Verify messages were persisted.
        resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
        assert resp.status_code == 200
        messages = resp.json()

        # Should have the user message followed by the assistant message.
        assert len(messages) >= 2
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == "Test question"
        assert messages[1]["role"] == "assistant"
        # The stored reply is expected to be the streamed tokens concatenated.
        assert messages[1]["content"] == "".join(tokens)
|