- 添加 stream_chat 生成器函数处理 SSE 事件流 - 实现 message_start / token / message_end 事件格式 - 添加 messages/stream SSE 端点 - 构建 LLM 消息列表(system prompt + 历史 + 用户消息) - 持久化用户和 assistant 消息到数据库 - 添加 SSE 测试用 mock Provider Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
import json
|
||
|
||
from fastapi import HTTPException
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.models import Conversation, DigitalEmployee, Message, MessageRole
|
||
from app.providers.base import LLMMessage
|
||
from app.services.employee_service import get_employee
|
||
from app.services.tenant_service import get_provider_for_tenant
|
||
|
||
|
||
async def create_conversation(
    session: AsyncSession,
    tenant_id: str,
    employee_id: str,
    user_id: str,
    title: str | None = None,
) -> Conversation:
    """Insert a new conversation and return it.

    The session is committed, then the instance is refreshed so its attributes
    reflect what the database actually stored.
    """
    fields = {
        "tenant_id": tenant_id,
        "employee_id": employee_id,
        "user_id": user_id,
        "title": title,
    }
    new_conversation = Conversation(**fields)

    session.add(new_conversation)
    await session.commit()
    await session.refresh(new_conversation)
    return new_conversation
|
||
|
||
|
||
async def get_conversation(
    session: AsyncSession, conversation_id: str, include_messages: bool = False
) -> Conversation:
    """Fetch one conversation by primary key.

    Args:
        session: Active async database session.
        conversation_id: Id of the conversation to load.
        include_messages: When true, eagerly load ``Conversation.messages``
            with a ``selectinload`` so they can be accessed without lazy IO.

    Raises:
        HTTPException: 404 when no conversation has this id.
    """
    stmt = select(Conversation).where(Conversation.id == conversation_id)
    stmt = stmt.options(selectinload(Conversation.messages)) if include_messages else stmt

    found = (await session.execute(stmt)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Conversation not found")
|
||
|
||
|
||
async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
    """Return all conversations belonging to ``tenant_id``.

    No ``ORDER BY`` is applied, so rows come back in database order.
    """
    belongs_to_tenant = Conversation.tenant_id == tenant_id
    rows = await session.execute(select(Conversation).where(belongs_to_tenant))
    scalars = rows.scalars()
    return list(scalars.all())
|
||
|
||
|
||
async def delete_conversation(session: AsyncSession, conversation_id: str) -> None:
    """Delete a conversation by id and commit the session.

    Raises:
        HTTPException: 404 (from ``get_conversation``) if the id does not exist.
    """
    # NOTE(review): whether the conversation's messages are removed too depends on
    # the relationship/foreign-key cascade configuration, which is not visible here.
    doomed = await get_conversation(session, conversation_id)
    await session.delete(doomed)
    await session.commit()
|
||
|
||
|
||
async def create_message(
    session: AsyncSession,
    conversation_id: str,
    role: MessageRole,
    content: str,
    token_count: int | None = None,
    sources: str | None = None,
) -> Message:
    """Append one message to a conversation, commit, and return the refreshed row.

    Args:
        session: Active async database session (committed by this call).
        conversation_id: Conversation the message belongs to.
        role: Author role of the message.
        content: Message text.
        token_count: Optional count stored alongside the message.
        sources: Optional serialized source/citation data (stored as-is).
    """
    row = Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        sources=sources,
    )
    session.add(row)

    # Commit first, then refresh so the returned object reflects the stored row.
    await session.commit()
    await session.refresh(row)
    return row
|
||
|
||
|
||
async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
    """Return every message of a conversation, oldest first (by ``created_at``)."""
    in_conversation = Message.conversation_id == conversation_id
    stmt = select(Message).where(in_conversation).order_by(Message.created_at)

    scalars = (await session.execute(stmt)).scalars()
    return list(scalars.all())
|
||
|
||
|
||
async def build_messages_for_llm(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
) -> tuple[list[LLMMessage], DigitalEmployee]:
    """Build the message list to send to the LLM and return ``(messages, employee)``.

    The list is assembled in this order:

    1. the employee's system prompt, if it is non-empty;
    2. the conversation's stored messages, oldest first (by ``created_at``),
       trimmed to the most recent ``employee.max_context_messages`` entries;
    3. ``user_content`` as a final ``"user"`` message. It is taken from the
       argument, not from the database.

    Trimming rules for ``max_context_messages``:

    - a positive value keeps at most that many history messages;
    - zero or a negative value keeps none;
    - ``None`` disables trimming.

    Raises:
        HTTPException: 404 from ``get_conversation`` if the conversation is missing.
    """
    conversation = await get_conversation(session, conversation_id)
    employee = await get_employee(session, conversation.employee_id)

    messages: list[LLMMessage] = []

    if employee.system_prompt:
        messages.append(LLMMessage(role="system", content=employee.system_prompt))

    # Load history through list_messages so it is explicitly ordered by
    # created_at. The order of an eagerly loaded relationship collection is
    # unspecified unless the relationship itself declares order_by.
    history = await list_messages(session, conversation_id)

    max_context = employee.max_context_messages
    if max_context is not None:
        # history[-0:] is the WHOLE list (since -0 == 0), so a non-positive limit
        # must be special-cased. Otherwise a limit of 0 would send everything.
        history = history[-max_context:] if max_context > 0 else []

    messages.extend(LLMMessage(role=msg.role.value, content=msg.content) for msg in history)

    # The current user turn is appended last.
    messages.append(LLMMessage(role="user", content=user_content))

    return messages, employee
|
||
|
||
|
||
def _sse_event(payload: dict) -> str:
    """Serialize ``payload`` as JSON inside one SSE ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
):
    """Run one chat turn and yield its Server-Sent-Events frames as strings.

    Events are yielded in this order:

    - ``{"type": "message_start"}``;
    - ``{"type": "token", "content": <chunk>}`` for every chunk produced by
      ``provider.chat_stream``;
    - ``{"type": "message_end", "token_count": <n>}``.

    The user message is persisted before ``message_start``. The assistant reply
    (the concatenated chunks) is persisted after the provider stream ends, just
    before ``message_end``.
    """
    messages, employee = await build_messages_for_llm(session, conversation_id, user_content)

    # Resolve the provider BEFORE committing anything, for two reasons:
    # - a failed lookup no longer leaves an orphaned user message with no reply;
    # - employee.tenant_id is read before create_message commits. With
    #   expire_on_commit=True (SQLAlchemy's default), reading it after the commit
    #   would require implicit IO, which AsyncSession does not allow.
    provider = await get_provider_for_tenant(session, employee.tenant_id)

    await create_message(session, conversation_id, MessageRole.user, user_content)

    yield _sse_event({"type": "message_start"})

    # Collect chunks in a list and join once, instead of quadratic `+=` on a str.
    chunks: list[str] = []
    async for token in provider.chat_stream(messages):
        chunks.append(token)
        yield _sse_event({"type": "token", "content": token})

    # NOTE(review): this counts streamed chunks, not model tokens. It only equals
    # the token count if the provider yields exactly one token per chunk.
    # NOTE(review): there is no error handling around the stream. If the provider
    # raises, or the generator is closed early (e.g. client disconnect), no
    # assistant message is stored and no end/error event is sent.
    token_count = len(chunks)
    await create_message(
        session,
        conversation_id,
        MessageRole.assistant,
        "".join(chunks),
        token_count=token_count,
    )

    yield _sse_event({"type": "message_end", "token_count": token_count})