feat(backend): B3 Task 3.2 - SSE 流式响应实现

- 添加 stream_chat 生成器函数处理 SSE 事件流
- 实现 message_start / token / message_end 事件格式
- 添加 messages/stream SSE 端点
- 构建 LLM 消息列表(system prompt + 历史 + 用户消息)
- 持久化用户和 assistant 消息到数据库
- 添加 SSE 测试用 mock Provider

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
root 2026-05-07 02:33:12 +08:00
parent a0c2586487
commit 1d7946d98a
3 changed files with 193 additions and 4 deletions

View File

@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db from app.api.deps import get_db
@ -53,7 +54,6 @@ async def send_message(
data: MessageCreate, data: MessageCreate,
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
) -> MessageResponse: ) -> MessageResponse:
# For now, just create a user message (LLM response will come in Task 3.2)
await conversation_service.get_conversation(session, conversation_id) await conversation_service.get_conversation(session, conversation_id)
msg = await conversation_service.create_message( msg = await conversation_service.create_message(
session, conversation_id, role="user", content=data.content session, conversation_id, role="user", content=data.content
@ -61,6 +61,19 @@ async def send_message(
return MessageResponse.from_model(msg) return MessageResponse.from_model(msg)
@router.post("/{conversation_id}/messages/stream")
async def send_message_stream(
conversation_id: str,
data: MessageCreate,
session: AsyncSession = Depends(get_db),
) -> StreamingResponse:
"""SSE 流式响应端点"""
return StreamingResponse(
conversation_service.stream_chat(session, conversation_id, data.content),
media_type="text/event-stream",
)
@router.get("/{conversation_id}/messages", response_model=list[MessageResponse]) @router.get("/{conversation_id}/messages", response_model=list[MessageResponse])
async def list_messages( async def list_messages(
conversation_id: str, session: AsyncSession = Depends(get_db) conversation_id: str, session: AsyncSession = Depends(get_db)

View File

@ -1,9 +1,14 @@
import json
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.models import Conversation, Message, MessageRole from app.models import Conversation, DigitalEmployee, Message, MessageRole
from app.providers.base import LLMMessage
from app.services.employee_service import get_employee
from app.services.tenant_service import get_provider_for_tenant
async def create_conversation( async def create_conversation(
@ -78,4 +83,67 @@ async def list_messages(session: AsyncSession, conversation_id: str) -> list[Mes
.where(Message.conversation_id == conversation_id) .where(Message.conversation_id == conversation_id)
.order_by(Message.created_at) .order_by(Message.created_at)
) )
return list(result.scalars().all()) return list(result.scalars().all())
async def build_messages_for_llm(
session: AsyncSession,
conversation_id: str,
user_content: str,
) -> tuple[list[LLMMessage], DigitalEmployee]:
"""构建发送给 LLM 的消息列表,返回 (messages, employee)"""
conversation = await get_conversation(session, conversation_id, include_messages=True)
employee = await get_employee(session, conversation.employee_id)
messages: list[LLMMessage] = []
# System prompt
if employee.system_prompt:
messages.append(LLMMessage(role="system", content=employee.system_prompt))
# 历史消息(按 max_context_messages 裁剪)
history = conversation.messages or []
max_context = employee.max_context_messages
if len(history) > max_context:
history = history[-max_context:]
for msg in history:
messages.append(LLMMessage(role=msg.role.value, content=msg.content))
# 当前用户消息
messages.append(LLMMessage(role="user", content=user_content))
return messages, employee
async def stream_chat(
session: AsyncSession,
conversation_id: str,
user_content: str,
):
"""流式对话生成器yield SSE 事件字符串"""
# 构建消息
messages, employee = await build_messages_for_llm(session, conversation_id, user_content)
# 保存用户消息
await create_message(session, conversation_id, MessageRole.user, user_content)
# 获取 Provider
provider = await get_provider_for_tenant(session, employee.tenant_id)
# 发送 message_start 事件
yield f"data: {json.dumps({'type': 'message_start'})}\n\n"
# 流式生成
full_content = ""
token_count = 0
async for token in provider.chat_stream(messages):
full_content += token
token_count += 1
yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
# 保存 assistant 消息
await create_message(session, conversation_id, MessageRole.assistant, full_content, token_count=token_count)
# 发送 message_end 事件
yield f"data: {json.dumps({'type': 'message_end', 'token_count': token_count})}\n\n"

View File

@ -8,7 +8,17 @@ from app.main import app
async def tenant_id(async_client: AsyncClient) -> str: async def tenant_id(async_client: AsyncClient) -> str:
"""Create a tenant and return its ID.""" """Create a tenant and return its ID."""
resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"}) resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"})
return resp.json()["id"] tid = resp.json()["id"]
# Configure tenant with LLM provider (mock settings)
await async_client.put(
f"/api/v1/tenants/{tid}/config",
json={
"llm_provider": "openai",
"llm_api_key": "test-api-key",
"llm_model": "gpt-4",
},
)
return tid
@pytest.fixture @pytest.fixture
@ -186,3 +196,101 @@ class TestMessageOperations:
assert resp.status_code == 200 assert resp.status_code == 200
# The response may or may not include messages depending on schema # The response may or may not include messages depending on schema
# For now, we just verify the conversation is accessible # For now, we just verify the conversation is accessible
class TestSSEStreaming:
"""Test SSE streaming message responses."""
@pytest.fixture
async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str:
"""Create a conversation and return its ID."""
resp = await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-001",
},
)
return resp.json()["id"]
@pytest.mark.asyncio
async def test_sse_stream_events(self, async_client: AsyncClient, conversation_id: str):
"""Test SSE stream returns properly formatted events."""
from unittest.mock import AsyncMock, patch, MagicMock
# Mock provider that yields tokens
async def mock_stream(messages):
for token in ["Hello", " ", "world", "!"]:
yield token
mock_provider = MagicMock()
mock_provider.chat_stream = mock_stream
with patch(
"app.services.conversation_service.get_provider_for_tenant",
AsyncMock(return_value=mock_provider)
):
async with async_client.stream(
"POST",
f"/api/v1/conversations/{conversation_id}/messages/stream",
json={"content": "Hello"},
) as response:
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
events = []
async for line in response.aiter_lines():
if line.startswith("data: "):
events.append(line[6:]) # Remove "data: " prefix
# Should have at least message_start, some tokens, and message_end
assert len(events) >= 2
# First event should be message_start
import json
first_event = json.loads(events[0])
assert first_event["type"] == "message_start"
# Last event should be message_end
last_event = json.loads(events[-1])
assert last_event["type"] == "message_end"
assert "token_count" in last_event
@pytest.mark.asyncio
async def test_sse_stream_persists_messages(
self, async_client: AsyncClient, conversation_id: str
):
"""Test that SSE stream persists both user and assistant messages."""
from unittest.mock import AsyncMock, patch, MagicMock
# Mock provider that yields tokens
async def mock_stream(messages):
for token in ["Test", " ", "response"]:
yield token
mock_provider = MagicMock()
mock_provider.chat_stream = mock_stream
with patch(
"app.services.conversation_service.get_provider_for_tenant",
AsyncMock(return_value=mock_provider)
):
async with async_client.stream(
"POST",
f"/api/v1/conversations/{conversation_id}/messages/stream",
json={"content": "Test question"},
) as response:
# Consume the stream
async for _ in response.aiter_lines():
pass
# Verify messages were persisted
resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
messages = resp.json()
# Should have user message and assistant message
assert len(messages) >= 2
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "Test question"
assert messages[1]["role"] == "assistant"