digital-employee-platform/backend/app/services/conversation_service.py
root a0c2586487 feat(backend): B3 Task 3.1 - 对话与消息基础 API
- 添加 Conversation CRUD 端点(创建/列表/获取/删除)
- 添加 Message 操作端点(发送/列表)
- 注册 conversations 路由到 API v1
- 修复测试 fixture 的 API 路径前缀
- 添加 async_client fixture alias

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-07 02:24:06 +08:00

81 lines
2.3 KiB
Python

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import Conversation, Message, MessageRole
async def create_conversation(
    session: AsyncSession,
    tenant_id: str,
    employee_id: str,
    user_id: str,
    title: str | None = None,
) -> Conversation:
    """Create, commit and return a new conversation.

    Args:
        session: Async database session used for the insert.
        tenant_id: Tenant that owns the conversation.
        employee_id: Employee the conversation is with.
        user_id: User who started the conversation.
        title: Optional human-readable title.

    Returns:
        The persisted ``Conversation``, refreshed from the database after commit.
    """
    fields = {
        "tenant_id": tenant_id,
        "employee_id": employee_id,
        "user_id": user_id,
        "title": title,
    }
    new_conversation = Conversation(**fields)
    session.add(new_conversation)
    await session.commit()
    # Re-read the row so any database-populated attributes are present on the instance.
    await session.refresh(new_conversation)
    return new_conversation
async def get_conversation(
    session: AsyncSession,
    conversation_id: str,
    include_messages: bool = False,
    *,
    tenant_id: str | None = None,
) -> Conversation:
    """Fetch a single conversation by id, optionally scoped to a tenant.

    Args:
        session: Async database session used for the query.
        conversation_id: Primary-key id of the conversation.
        include_messages: When true, eagerly load ``Conversation.messages``
            with a separate SELECT (``selectinload``).
        tenant_id: When given, only a conversation owned by this tenant is
            returned; a conversation belonging to another tenant is reported
            as not found. Defaults to ``None`` (no tenant filter) so existing
            callers keep their current behavior.

    Returns:
        The matching ``Conversation``.

    Raises:
        HTTPException: 404 if no conversation matches.
    """
    query = select(Conversation).where(Conversation.id == conversation_id)
    if tenant_id is not None:
        # Scope the lookup to the caller's tenant so ids from other tenants
        # are indistinguishable from missing ids.
        query = query.where(Conversation.tenant_id == tenant_id)
    if include_messages:
        query = query.options(selectinload(Conversation.messages))
    result = await session.execute(query)
    conversation = result.scalar_one_or_none()
    if conversation is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conversation
async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
    """Return all conversations owned by ``tenant_id``.

    No ORDER BY is applied, so the row order is whatever the database returns.
    """
    tenant_filter = Conversation.tenant_id == tenant_id
    stmt = select(Conversation).where(tenant_filter)
    rows = await session.execute(stmt)
    conversations = rows.scalars().all()
    return list(conversations)
async def delete_conversation(
    session: AsyncSession,
    conversation_id: str,
    *,
    tenant_id: str | None = None,
) -> None:
    """Delete a conversation by id and commit.

    Args:
        session: Async database session used for the delete.
        conversation_id: Primary-key id of the conversation to delete.
        tenant_id: When given, the conversation must belong to this tenant;
            otherwise it is treated as not found and nothing is deleted.
            Defaults to ``None`` (no tenant check) for backward compatibility.

    Raises:
        HTTPException: 404 if the conversation does not exist (or, when
            ``tenant_id`` is given, belongs to a different tenant).
    """
    conversation = await get_conversation(session, conversation_id)
    if tenant_id is not None and conversation.tenant_id != tenant_id:
        # Report a cross-tenant id exactly like a missing one so callers
        # cannot probe for other tenants' conversations.
        raise HTTPException(status_code=404, detail="Conversation not found")
    await session.delete(conversation)
    await session.commit()
async def create_message(
    session: AsyncSession,
    conversation_id: str,
    role: MessageRole,
    content: str,
    token_count: int | None = None,
    sources: str | None = None,
) -> Message:
    """Append a message to a conversation, commit, and return it.

    Args:
        session: Async database session used for the insert.
        conversation_id: Id of the conversation the message belongs to.
        role: Author role of the message.
        content: Message text.
        token_count: Optional token count recorded with the message.
        sources: Optional sources payload, stored as a string.

    Returns:
        The persisted ``Message``, refreshed from the database after commit.
    """
    fields = {
        "conversation_id": conversation_id,
        "role": role,
        "content": content,
        "token_count": token_count,
        "sources": sources,
    }
    new_message = Message(**fields)
    session.add(new_message)
    await session.commit()
    # Re-read the row so any database-populated attributes are present on the instance.
    await session.refresh(new_message)
    return new_message
async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
    """Return the messages of ``conversation_id``, oldest first by ``created_at``."""
    in_conversation = Message.conversation_id == conversation_id
    stmt = select(Message).where(in_conversation).order_by(Message.created_at)
    rows = await session.execute(stmt)
    messages = rows.scalars().all()
    return list(messages)