feat(backend): B3 Task 3.2 - SSE 流式响应实现
- 添加 stream_chat 生成器函数处理 SSE 事件流 - 实现 message_start / token / message_end 事件格式 - 添加 messages/stream SSE 端点 - 构建 LLM 消息列表(system prompt + 历史 + 用户消息) - 持久化用户和 assistant 消息到数据库 - 添加 SSE 测试用 mock Provider Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
a0c2586487
commit
1d7946d98a
@ -1,4 +1,5 @@
|
|||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_db
|
from app.api.deps import get_db
|
||||||
@ -53,7 +54,6 @@ async def send_message(
|
|||||||
data: MessageCreate,
|
data: MessageCreate,
|
||||||
session: AsyncSession = Depends(get_db),
|
session: AsyncSession = Depends(get_db),
|
||||||
) -> MessageResponse:
|
) -> MessageResponse:
|
||||||
# For now, just create a user message (LLM response will come in Task 3.2)
|
|
||||||
await conversation_service.get_conversation(session, conversation_id)
|
await conversation_service.get_conversation(session, conversation_id)
|
||||||
msg = await conversation_service.create_message(
|
msg = await conversation_service.create_message(
|
||||||
session, conversation_id, role="user", content=data.content
|
session, conversation_id, role="user", content=data.content
|
||||||
@ -61,6 +61,19 @@ async def send_message(
|
|||||||
return MessageResponse.from_model(msg)
|
return MessageResponse.from_model(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{conversation_id}/messages/stream")
|
||||||
|
async def send_message_stream(
|
||||||
|
conversation_id: str,
|
||||||
|
data: MessageCreate,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""SSE 流式响应端点"""
|
||||||
|
return StreamingResponse(
|
||||||
|
conversation_service.stream_chat(session, conversation_id, data.content),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{conversation_id}/messages", response_model=list[MessageResponse])
|
@router.get("/{conversation_id}/messages", response_model=list[MessageResponse])
|
||||||
async def list_messages(
|
async def list_messages(
|
||||||
conversation_id: str, session: AsyncSession = Depends(get_db)
|
conversation_id: str, session: AsyncSession = Depends(get_db)
|
||||||
|
|||||||
@ -1,9 +1,14 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.models import Conversation, Message, MessageRole
|
from app.models import Conversation, DigitalEmployee, Message, MessageRole
|
||||||
|
from app.providers.base import LLMMessage
|
||||||
|
from app.services.employee_service import get_employee
|
||||||
|
from app.services.tenant_service import get_provider_for_tenant
|
||||||
|
|
||||||
|
|
||||||
async def create_conversation(
|
async def create_conversation(
|
||||||
@ -79,3 +84,66 @@ async def list_messages(session: AsyncSession, conversation_id: str) -> list[Mes
|
|||||||
.order_by(Message.created_at)
|
.order_by(Message.created_at)
|
||||||
)
|
)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def build_messages_for_llm(
|
||||||
|
session: AsyncSession,
|
||||||
|
conversation_id: str,
|
||||||
|
user_content: str,
|
||||||
|
) -> tuple[list[LLMMessage], DigitalEmployee]:
|
||||||
|
"""构建发送给 LLM 的消息列表,返回 (messages, employee)"""
|
||||||
|
conversation = await get_conversation(session, conversation_id, include_messages=True)
|
||||||
|
employee = await get_employee(session, conversation.employee_id)
|
||||||
|
|
||||||
|
messages: list[LLMMessage] = []
|
||||||
|
|
||||||
|
# System prompt
|
||||||
|
if employee.system_prompt:
|
||||||
|
messages.append(LLMMessage(role="system", content=employee.system_prompt))
|
||||||
|
|
||||||
|
# 历史消息(按 max_context_messages 裁剪)
|
||||||
|
history = conversation.messages or []
|
||||||
|
max_context = employee.max_context_messages
|
||||||
|
if len(history) > max_context:
|
||||||
|
history = history[-max_context:]
|
||||||
|
|
||||||
|
for msg in history:
|
||||||
|
messages.append(LLMMessage(role=msg.role.value, content=msg.content))
|
||||||
|
|
||||||
|
# 当前用户消息
|
||||||
|
messages.append(LLMMessage(role="user", content=user_content))
|
||||||
|
|
||||||
|
return messages, employee
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
session: AsyncSession,
|
||||||
|
conversation_id: str,
|
||||||
|
user_content: str,
|
||||||
|
):
|
||||||
|
"""流式对话生成器,yield SSE 事件字符串"""
|
||||||
|
# 构建消息
|
||||||
|
messages, employee = await build_messages_for_llm(session, conversation_id, user_content)
|
||||||
|
|
||||||
|
# 保存用户消息
|
||||||
|
await create_message(session, conversation_id, MessageRole.user, user_content)
|
||||||
|
|
||||||
|
# 获取 Provider
|
||||||
|
provider = await get_provider_for_tenant(session, employee.tenant_id)
|
||||||
|
|
||||||
|
# 发送 message_start 事件
|
||||||
|
yield f"data: {json.dumps({'type': 'message_start'})}\n\n"
|
||||||
|
|
||||||
|
# 流式生成
|
||||||
|
full_content = ""
|
||||||
|
token_count = 0
|
||||||
|
async for token in provider.chat_stream(messages):
|
||||||
|
full_content += token
|
||||||
|
token_count += 1
|
||||||
|
yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
|
||||||
|
|
||||||
|
# 保存 assistant 消息
|
||||||
|
await create_message(session, conversation_id, MessageRole.assistant, full_content, token_count=token_count)
|
||||||
|
|
||||||
|
# 发送 message_end 事件
|
||||||
|
yield f"data: {json.dumps({'type': 'message_end', 'token_count': token_count})}\n\n"
|
||||||
@ -8,7 +8,17 @@ from app.main import app
|
|||||||
async def tenant_id(async_client: AsyncClient) -> str:
|
async def tenant_id(async_client: AsyncClient) -> str:
|
||||||
"""Create a tenant and return its ID."""
|
"""Create a tenant and return its ID."""
|
||||||
resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"})
|
resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"})
|
||||||
return resp.json()["id"]
|
tid = resp.json()["id"]
|
||||||
|
# Configure tenant with LLM provider (mock settings)
|
||||||
|
await async_client.put(
|
||||||
|
f"/api/v1/tenants/{tid}/config",
|
||||||
|
json={
|
||||||
|
"llm_provider": "openai",
|
||||||
|
"llm_api_key": "test-api-key",
|
||||||
|
"llm_model": "gpt-4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return tid
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -186,3 +196,101 @@ class TestMessageOperations:
|
|||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
# The response may or may not include messages depending on schema
|
# The response may or may not include messages depending on schema
|
||||||
# For now, we just verify the conversation is accessible
|
# For now, we just verify the conversation is accessible
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEStreaming:
|
||||||
|
"""Test SSE streaming message responses."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str:
|
||||||
|
"""Create a conversation and return its ID."""
|
||||||
|
resp = await async_client.post(
|
||||||
|
"/api/v1/conversations",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"employee_id": employee_id,
|
||||||
|
"user_id": "user-001",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return resp.json()["id"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_stream_events(self, async_client: AsyncClient, conversation_id: str):
|
||||||
|
"""Test SSE stream returns properly formatted events."""
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
|
# Mock provider that yields tokens
|
||||||
|
async def mock_stream(messages):
|
||||||
|
for token in ["Hello", " ", "world", "!"]:
|
||||||
|
yield token
|
||||||
|
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.chat_stream = mock_stream
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.conversation_service.get_provider_for_tenant",
|
||||||
|
AsyncMock(return_value=mock_provider)
|
||||||
|
):
|
||||||
|
async with async_client.stream(
|
||||||
|
"POST",
|
||||||
|
f"/api/v1/conversations/{conversation_id}/messages/stream",
|
||||||
|
json={"content": "Hello"},
|
||||||
|
) as response:
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
events.append(line[6:]) # Remove "data: " prefix
|
||||||
|
|
||||||
|
# Should have at least message_start, some tokens, and message_end
|
||||||
|
assert len(events) >= 2
|
||||||
|
|
||||||
|
# First event should be message_start
|
||||||
|
import json
|
||||||
|
first_event = json.loads(events[0])
|
||||||
|
assert first_event["type"] == "message_start"
|
||||||
|
|
||||||
|
# Last event should be message_end
|
||||||
|
last_event = json.loads(events[-1])
|
||||||
|
assert last_event["type"] == "message_end"
|
||||||
|
assert "token_count" in last_event
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_stream_persists_messages(
|
||||||
|
self, async_client: AsyncClient, conversation_id: str
|
||||||
|
):
|
||||||
|
"""Test that SSE stream persists both user and assistant messages."""
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
|
# Mock provider that yields tokens
|
||||||
|
async def mock_stream(messages):
|
||||||
|
for token in ["Test", " ", "response"]:
|
||||||
|
yield token
|
||||||
|
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.chat_stream = mock_stream
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.conversation_service.get_provider_for_tenant",
|
||||||
|
AsyncMock(return_value=mock_provider)
|
||||||
|
):
|
||||||
|
async with async_client.stream(
|
||||||
|
"POST",
|
||||||
|
f"/api/v1/conversations/{conversation_id}/messages/stream",
|
||||||
|
json={"content": "Test question"},
|
||||||
|
) as response:
|
||||||
|
# Consume the stream
|
||||||
|
async for _ in response.aiter_lines():
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify messages were persisted
|
||||||
|
resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
|
||||||
|
messages = resp.json()
|
||||||
|
|
||||||
|
# Should have user message and assistant message
|
||||||
|
assert len(messages) >= 2
|
||||||
|
assert messages[0]["role"] == "user"
|
||||||
|
assert messages[0]["content"] == "Test question"
|
||||||
|
assert messages[1]["role"] == "assistant"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user