digital-employee-platform/backend/app/services/conversation_service.py
root 1d7946d98a feat(backend): B3 Task 3.2 - SSE 流式响应实现
- 添加 stream_chat 生成器函数处理 SSE 事件流
- 实现 message_start / token / message_end 事件格式
- 添加 messages/stream SSE 端点
- 构建 LLM 消息列表(system prompt + 历史 + 用户消息)
- 持久化用户和 assistant 消息到数据库
- 添加 SSE 测试用 mock Provider

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-07 02:33:12 +08:00

149 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import Conversation, DigitalEmployee, Message, MessageRole
from app.providers.base import LLMMessage
from app.services.employee_service import get_employee
from app.services.tenant_service import get_provider_for_tenant
async def create_conversation(
    session: AsyncSession,
    tenant_id: str,
    employee_id: str,
    user_id: str,
    title: str | None = None,
) -> Conversation:
    """Create, persist and return a new conversation.

    The row is committed immediately and then refreshed, so columns populated
    by the database are available on the returned object.

    Args:
        session: Active async database session.
        tenant_id: Tenant that owns the conversation.
        employee_id: Digital employee the conversation is held with.
        user_id: User who owns the conversation.
        title: Optional conversation title.
    """
    new_conversation = Conversation(
        tenant_id=tenant_id, employee_id=employee_id, user_id=user_id, title=title
    )
    session.add(new_conversation)
    await session.commit()
    # Reload the row so any database-populated columns are visible to callers.
    await session.refresh(new_conversation)
    return new_conversation
async def get_conversation(
    session: AsyncSession, conversation_id: str, include_messages: bool = False
) -> Conversation:
    """Fetch a single conversation or raise a 404.

    Args:
        session: Active async database session.
        conversation_id: Value matched against ``Conversation.id``.
        include_messages: When True, eager-load ``Conversation.messages`` via
            ``selectinload`` so the relationship can be read without further
            lazy loading.

    Raises:
        HTTPException: 404 when no conversation matches ``conversation_id``.
    """
    stmt = select(Conversation).where(Conversation.id == conversation_id)
    if include_messages:
        stmt = stmt.options(selectinload(Conversation.messages))
    found = (await session.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return found
async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
    """Return every conversation whose ``tenant_id`` matches.

    No ORDER BY is applied, so callers needing a stable order must sort.
    """
    stmt = select(Conversation).where(Conversation.tenant_id == tenant_id)
    rows = await session.execute(stmt)
    return [conv for conv in rows.scalars()]
async def delete_conversation(session: AsyncSession, conversation_id: str) -> None:
    """Delete a conversation and commit.

    Whether its messages are removed as well depends on the relationship / FK
    cascade configuration, which is not visible here.

    Raises:
        HTTPException: 404 (from ``get_conversation``) if it does not exist.
    """
    target = await get_conversation(session, conversation_id)
    await session.delete(target)
    await session.commit()
async def create_message(
    session: AsyncSession,
    conversation_id: str,
    role: MessageRole,
    content: str,
    token_count: int | None = None,
    sources: str | None = None,
) -> Message:
    """Insert one message into a conversation, commit, and return it refreshed.

    Args:
        session: Active async database session.
        conversation_id: Conversation the message belongs to.
        role: Author role of the message.
        content: Message text.
        token_count: Optional count stored with the message.
        sources: Optional string stored in ``Message.sources``; its format is
            chosen by the caller.
    """
    fields = dict(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        sources=sources,
    )
    msg = Message(**fields)
    session.add(msg)
    await session.commit()
    await session.refresh(msg)
    return msg
async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
    """Return all messages of a conversation, oldest first (by ``created_at``)."""
    stmt = (
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    scalars = (await session.execute(stmt)).scalars()
    return [m for m in scalars]
async def build_messages_for_llm(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
) -> tuple[list[LLMMessage], DigitalEmployee]:
    """Build the message list to send to the LLM.

    The list contains, in order:
      * the employee's system prompt (only when it is non-empty),
      * the newest ``employee.max_context_messages`` stored messages of the
        conversation, in ``created_at`` order,
      * ``user_content`` as a final ``user`` message.

    ``user_content`` is not persisted here.

    Args:
        session: Active async database session.
        conversation_id: Conversation whose history is included.
        user_content: Text of the user's new message.

    Returns:
        ``(messages, employee)``, where ``employee`` is the digital employee
        referenced by the conversation's ``employee_id``.

    Raises:
        HTTPException: 404 when the conversation does not exist (from
            ``get_conversation``). Errors from ``get_employee`` propagate.
    """
    conversation = await get_conversation(session, conversation_id)
    employee = await get_employee(session, conversation.employee_id)

    messages: list[LLMMessage] = []
    if employee.system_prompt:
        messages.append(LLMMessage(role="system", content=employee.system_prompt))

    # Load history through list_messages so it is explicitly ordered by
    # created_at. The tail slice below relies on chronological order, and the
    # ordering of the Conversation.messages relationship is not defined here.
    history = await list_messages(session, conversation_id)

    # Keep only the newest max_context messages. A non-positive limit means
    # "no history": slicing with -0 would otherwise return the whole list.
    max_context = employee.max_context_messages
    history = history[-max_context:] if max_context > 0 else []

    messages.extend(
        LLMMessage(role=msg.role.value, content=msg.content) for msg in history
    )

    # The current user turn always goes last.
    messages.append(LLMMessage(role="user", content=user_content))
    return messages, employee
def _sse_event(payload: dict) -> str:
    """Serialize ``payload`` as JSON and frame it as one SSE ``data:`` event."""
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
):
    """Run one chat turn and stream it as Server-Sent Events.

    Steps, in order:
      1. Build the LLM prompt from the stored history. This happens before the
         new user message is saved, so that message is not duplicated.
      2. Persist the user message.
      3. Yield a ``message_start`` event.
      4. Yield one ``token`` event per chunk from ``provider.chat_stream``.
      5. Persist the concatenated assistant reply.
      6. Yield a ``message_end`` event carrying the chunk count.

    Yields:
        str: SSE frames, each a ``data: <json>`` line followed by a blank line.
    """
    messages, employee = await build_messages_for_llm(session, conversation_id, user_content)
    # Read tenant_id before create_message commits. If the session expires
    # objects on commit, touching employee.tenant_id afterwards would trigger
    # an implicit lazy refresh, which AsyncSession does not allow.
    tenant_id = employee.tenant_id

    await create_message(session, conversation_id, MessageRole.user, user_content)

    provider = await get_provider_for_tenant(session, tenant_id)

    yield _sse_event({"type": "message_start"})

    # Collect chunks in a list and join once, instead of repeated string
    # concatenation. NOTE(review): token_count counts streamed chunks; whether
    # a chunk equals one model token depends on the provider.
    chunks: list[str] = []
    async for chunk in provider.chat_stream(messages):
        chunks.append(chunk)
        yield _sse_event({"type": "token", "content": chunk})
    token_count = len(chunks)

    # NOTE(review): if chat_stream raises or the consumer stops iterating,
    # nothing below runs. The user message stays persisted without an
    # assistant reply.
    await create_message(
        session,
        conversation_id,
        MessageRole.assistant,
        "".join(chunks),
        token_count=token_count,
    )

    yield _sse_event({"type": "message_end", "token_count": token_count})