import json
from collections.abc import AsyncIterator

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models import Conversation, DigitalEmployee, Message, MessageRole
from app.providers.base import LLMMessage
from app.services.employee_service import get_employee
from app.services.tenant_service import get_provider_for_tenant


async def create_conversation(
    session: AsyncSession,
    tenant_id: str,
    employee_id: str,
    user_id: str,
    title: str | None = None,
) -> Conversation:
    """Create and persist a new conversation.

    The row is committed immediately and then refreshed from the session
    before being returned.

    Args:
        session: Async database session used for the insert.
        tenant_id: Value stored in the conversation's ``tenant_id`` field.
        employee_id: Value stored in the conversation's ``employee_id`` field.
        user_id: Value stored in the conversation's ``user_id`` field.
        title: Optional conversation title.

    Returns:
        The persisted ``Conversation`` instance.
    """
    conversation = Conversation(
        tenant_id=tenant_id,
        employee_id=employee_id,
        user_id=user_id,
        title=title,
    )
    session.add(conversation)
    await session.commit()
    await session.refresh(conversation)
    return conversation


async def get_conversation(
    session: AsyncSession, conversation_id: str, include_messages: bool = False
) -> Conversation:
    """Fetch a single conversation by id.

    Args:
        session: Async database session used for the query.
        conversation_id: Primary key of the conversation to load.
        include_messages: When true, ``Conversation.messages`` is loaded with
            ``selectinload`` as part of this call.

    Returns:
        The matching ``Conversation``.

    Raises:
        HTTPException: 404 when no conversation has the given id.
    """
    # NOTE(review): the lookup is by id only and is not scoped to a tenant;
    # confirm that callers enforce tenant isolation before calling this.
    query = select(Conversation).where(Conversation.id == conversation_id)
    if include_messages:
        query = query.options(selectinload(Conversation.messages))

    result = await session.execute(query)
    conversation = result.scalar_one_or_none()
    if conversation is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conversation


async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
    """Return every conversation whose ``tenant_id`` matches.

    No explicit ordering is applied to the query.
    """
    result = await session.execute(
        select(Conversation).where(Conversation.tenant_id == tenant_id)
    )
    return list(result.scalars().all())


async def delete_conversation(session: AsyncSession, conversation_id: str) -> None:
    """Delete a conversation and commit.

    Raises:
        HTTPException: 404 (from ``get_conversation``) when the conversation
            does not exist.
    """
    conversation = await get_conversation(session, conversation_id)
    await session.delete(conversation)
    await session.commit()


async def create_message(
    session: AsyncSession,
    conversation_id: str,
    role: MessageRole,
    content: str,
    token_count: int | None = None,
    sources: str | None = None,
) -> Message:
    """Create and persist a message belonging to a conversation.

    The row is committed immediately and then refreshed from the session
    before being returned.

    Args:
        session: Async database session used for the insert.
        conversation_id: Id of the conversation the message belongs to.
        role: Role of the message author.
        content: Message text.
        token_count: Optional count stored alongside the message.
        sources: Optional string stored in the message's ``sources`` field.

    Returns:
        The persisted ``Message`` instance.
    """
    message = Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        sources=sources,
    )
    session.add(message)
    await session.commit()
    await session.refresh(message)
    return message


async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
    """Return the messages of a conversation ordered by ``created_at``."""
    result = await session.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    return list(result.scalars().all())


async def build_messages_for_llm(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
) -> tuple[list[LLMMessage], DigitalEmployee]:
    """Build the message list to send to the LLM.

    The list consists of the employee's system prompt (when set), the most
    recent ``employee.max_context_messages`` history messages, and finally
    the current user message.

    Args:
        session: Async database session used to load the conversation and
            the employee.
        conversation_id: Id of the conversation whose history is used.
        user_content: Text of the user message being sent now.

    Returns:
        A ``(messages, employee)`` tuple.

    Raises:
        HTTPException: 404 (from ``get_conversation``) when the conversation
            does not exist.
    """
    conversation = await get_conversation(session, conversation_id, include_messages=True)
    employee = await get_employee(session, conversation.employee_id)

    messages: list[LLMMessage] = []

    # System prompt
    if employee.system_prompt:
        messages.append(LLMMessage(role="system", content=employee.system_prompt))

    # History, trimmed to the last max_context_messages entries.
    # NOTE(review): the order of conversation.messages comes from the
    # relationship definition, which is not visible here - confirm it is
    # chronological, otherwise "most recent" trimming is meaningless.
    history = list(conversation.messages or [])
    max_context = employee.max_context_messages
    if max_context is not None:  # None is treated as "no limit"
        keep = max(max_context, 0)
        # Slice from an explicit start index: history[-0:] would return the
        # whole list, so a limit of 0 must not be expressed as a negative slice.
        history = history[max(len(history) - keep, 0):]

    messages.extend(
        LLMMessage(role=msg.role.value, content=msg.content) for msg in history
    )

    # Current user message
    messages.append(LLMMessage(role="user", content=user_content))

    return messages, employee


def _sse_event(payload: dict) -> str:
    """Serialize *payload* as a single SSE ``data:`` event string."""
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
) -> AsyncIterator[str]:
    """Stream a chat reply as SSE event strings.

    Yields a ``message_start`` event, one ``token`` event per chunk produced
    by the provider, and a ``message_end`` event carrying the chunk count.
    The user message is persisted before streaming starts and the assistant
    message is persisted after the provider stream is exhausted.

    Args:
        session: Async database session used for all reads and writes.
        conversation_id: Id of the conversation being continued.
        user_content: Text of the user message.

    Yields:
        SSE-formatted strings (``data: <json>`` followed by a blank line).

    Raises:
        HTTPException: 404 (from ``get_conversation``) when the conversation
            does not exist.
    """
    # Build the prompt first: build_messages_for_llm appends user_content
    # itself, so the user message must not be in the stored history yet.
    messages, employee = await build_messages_for_llm(session, conversation_id, user_content)

    # Persist the user message
    await create_message(session, conversation_id, MessageRole.user, user_content)

    # Resolve the provider for the employee's tenant
    provider = await get_provider_for_tenant(session, employee.tenant_id)

    # Emit the message_start event
    yield _sse_event({'type': 'message_start'})

    # Stream the generation, collecting chunks in a list so the full reply
    # is assembled once instead of by repeated string concatenation.
    chunks: list[str] = []
    async for token in provider.chat_stream(messages):
        chunks.append(token)
        yield _sse_event({'type': 'token', 'content': token})

    full_content = "".join(chunks)
    # token_count is the number of streamed chunks, one per item yielded by
    # the provider.
    token_count = len(chunks)

    # Persist the assistant message
    await create_message(
        session,
        conversation_id,
        MessageRole.assistant,
        full_content,
        token_count=token_count,
    )

    # Emit the message_end event
    yield _sse_event({'type': 'message_end', 'token_count': token_count})