- 添加 Conversation CRUD 端点(创建/列表/获取/删除) - 添加 Message 操作端点(发送/列表) - 注册 conversations 路由到 API v1 - 修复测试 fixture 的 API 路径前缀 - 添加 async_client fixture alias Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
from fastapi import HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.models import Conversation, Message, MessageRole
|
|
|
|
|
|
async def create_conversation(
    session: AsyncSession,
    tenant_id: str,
    employee_id: str,
    user_id: str,
    title: str | None = None,
) -> Conversation:
    """Insert a new ``Conversation`` row, commit it, and return the instance.

    Args:
        session: Async SQLAlchemy session used for the insert and commit.
        tenant_id: Stored on the conversation's ``tenant_id`` attribute.
        employee_id: Stored on the conversation's ``employee_id`` attribute.
        user_id: Stored on the conversation's ``user_id`` attribute.
        title: Optional title, stored as given (may be ``None``).

    Returns:
        The persisted ``Conversation``, refreshed from the database after
        the commit.
    """
    attrs = dict(
        tenant_id=tenant_id,
        employee_id=employee_id,
        user_id=user_id,
        title=title,
    )
    new_conversation = Conversation(**attrs)
    session.add(new_conversation)
    await session.commit()
    # Reload the row after commit so any values assigned by the database
    # on insert are present on the returned object.
    await session.refresh(new_conversation)
    return new_conversation
|
|
|
|
|
|
async def get_conversation(
    session: AsyncSession, conversation_id: str, include_messages: bool = False
) -> Conversation:
    """Load one ``Conversation`` by its ``id``.

    Args:
        session: Async SQLAlchemy session used for the query.
        conversation_id: Value matched against ``Conversation.id``.
        include_messages: When true, ``Conversation.messages`` is loaded
            eagerly via ``selectinload`` in the same call.

    Returns:
        The matching ``Conversation``.

    Raises:
        HTTPException: status 404 when no row matches ``conversation_id``.

    Note:
        The lookup filters on ``id`` only; no tenant restriction is applied
        here.
    """
    stmt = select(Conversation).where(Conversation.id == conversation_id)
    stmt = stmt.options(selectinload(Conversation.messages)) if include_messages else stmt

    found = (await session.execute(stmt)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Conversation not found")
|
|
|
|
|
|
async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
    """Return all conversations whose ``tenant_id`` equals ``tenant_id``.

    Args:
        session: Async SQLAlchemy session used for the query.
        tenant_id: Tenant to filter on.

    Returns:
        A list of ``Conversation`` objects. No ORDER BY is applied, so the
        order is whatever the database yields.
    """
    stmt = select(Conversation).where(Conversation.tenant_id == tenant_id)
    rows = await session.execute(stmt)
    return list(rows.scalars())
|
|
|
|
|
|
async def delete_conversation(session: AsyncSession, conversation_id: str) -> None:
    """Delete the conversation with the given id and commit.

    Args:
        session: Async SQLAlchemy session used for the lookup, delete and commit.
        conversation_id: Id of the conversation to remove.

    Raises:
        HTTPException: status 404, propagated from ``get_conversation`` when
            the conversation does not exist.

    Note:
        Whether related messages are removed too depends on the relationship
        / foreign-key cascade configuration, which is not visible here —
        verify against the model definitions.
    """
    target = await get_conversation(session, conversation_id)
    await session.delete(target)
    await session.commit()
|
|
|
|
|
|
async def create_message(
    session: AsyncSession,
    conversation_id: str,
    role: MessageRole,
    content: str,
    token_count: int | None = None,
    sources: str | None = None,
) -> Message:
    """Insert a new ``Message`` into a conversation, commit, and return it.

    Args:
        session: Async SQLAlchemy session used for the insert and commit.
        conversation_id: Stored on the message's ``conversation_id`` attribute.
        role: ``MessageRole`` value for the message.
        content: Message text.
        token_count: Optional token count, stored as given.
        sources: Optional sources string, stored as given.

    Returns:
        The persisted ``Message``, refreshed from the database after commit.

    Note:
        This function does not itself check that ``conversation_id`` refers
        to an existing conversation; presumably a DB constraint or the
        caller enforces that — confirm.
    """
    attrs = dict(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        sources=sources,
    )
    msg = Message(**attrs)
    session.add(msg)
    await session.commit()
    # Reload after commit so database-assigned values are on the instance.
    await session.refresh(msg)
    return msg
|
|
|
|
|
|
async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
    """Return the messages of one conversation, sorted by ``created_at``.

    Args:
        session: Async SQLAlchemy session used for the query.
        conversation_id: Conversation whose messages are returned.

    Returns:
        A list of ``Message`` objects in ascending ``created_at`` order.
        Rows with equal ``created_at`` have no further tie-breaker.
    """
    belongs_to_conversation = Message.conversation_id == conversation_id
    stmt = select(Message).where(belongs_to_conversation).order_by(Message.created_at)
    rows = await session.execute(stmt)
    return list(rows.scalars())