"""Async data-access helpers for conversations and their messages."""

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models import Conversation, Message, MessageRole


async def create_conversation(
    session: AsyncSession,
    tenant_id: str,
    employee_id: str,
    user_id: str,
    title: str | None = None,
) -> Conversation:
    """Insert a new conversation, commit it, and return it.

    Args:
        session: Async session used for the insert. It is committed here, so
            this call cannot be composed into a larger caller-managed
            transaction.
        tenant_id: Tenant that owns the conversation.
        employee_id: Stored on the conversation as ``employee_id``.
        user_id: Stored on the conversation as ``user_id``.
        title: Optional conversation title.

    Returns:
        The persisted ``Conversation``, re-loaded from the database after the
        commit (e.g. so database-assigned columns such as ``id`` are
        available).
    """
    conversation = Conversation(
        tenant_id=tenant_id,
        employee_id=employee_id,
        user_id=user_id,
        title=title,
    )
    session.add(conversation)
    await session.commit()
    await session.refresh(conversation)
    return conversation


async def get_conversation(
    session: AsyncSession,
    conversation_id: str,
    include_messages: bool = False,
    tenant_id: str | None = None,
) -> Conversation:
    """Fetch a single conversation by id.

    Args:
        session: Async session to query with.
        conversation_id: Primary key of the conversation.
        include_messages: When true, eagerly load ``Conversation.messages``.
        tenant_id: Optional tenant scope. When given, a conversation that
            belongs to a different tenant is reported as not found instead
            of being returned across tenant boundaries. ``None`` (the
            default) keeps the original id-only lookup.

    Returns:
        The matching ``Conversation``.

    Raises:
        HTTPException: 404 ("Conversation not found") if no row matches.
    """
    query = select(Conversation).where(Conversation.id == conversation_id)
    if tenant_id is not None:
        query = query.where(Conversation.tenant_id == tenant_id)
    if include_messages:
        # Lazy-loading a relationship through an AsyncSession outside an
        # awaited call is not supported, so callers that need
        # ``conversation.messages`` must request it up front.
        query = query.options(selectinload(Conversation.messages))
    result = await session.execute(query)
    conversation = result.scalar_one_or_none()
    if conversation is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conversation


async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
    """Return every conversation owned by ``tenant_id``.

    No ORDER BY is applied, so the order of the returned list is whatever
    the database yields.
    """
    result = await session.execute(
        select(Conversation).where(Conversation.tenant_id == tenant_id)
    )
    return list(result.scalars().all())


async def delete_conversation(
    session: AsyncSession,
    conversation_id: str,
    tenant_id: str | None = None,
) -> None:
    """Delete a conversation by id and commit.

    Args:
        session: Async session used for the delete. It is committed here.
        conversation_id: Primary key of the conversation to delete.
        tenant_id: Optional tenant scope, passed through to
            ``get_conversation``. When given, conversations owned by another
            tenant cannot be deleted through this call.

    Raises:
        HTTPException: 404 if the conversation does not exist (or is outside
            ``tenant_id`` when one is supplied).
    """
    conversation = await get_conversation(session, conversation_id, tenant_id=tenant_id)
    # AsyncSession.delete is awaitable because cascading along unloaded
    # relationships may need to emit queries.
    await session.delete(conversation)
    await session.commit()


async def create_message(
    session: AsyncSession,
    conversation_id: str,
    role: MessageRole,
    content: str,
    token_count: int | None = None,
    sources: str | None = None,
) -> Message:
    """Insert a message into a conversation, commit it, and return it.

    Args:
        session: Async session used for the insert. It is committed here.
        conversation_id: Id of the conversation the message belongs to. It is
            not checked for existence here.
        role: Author role of the message.
        content: Message body text.
        token_count: Optional token count stored with the message.
        sources: Optional sources value, stored as-is. Presumably a
            serialized payload; TODO confirm the expected format with callers.

    Returns:
        The persisted ``Message``, re-loaded from the database after the
        commit.
    """
    message = Message(
        conversation_id=conversation_id,
        role=role,
        content=content,
        token_count=token_count,
        sources=sources,
    )
    session.add(message)
    await session.commit()
    await session.refresh(message)
    return message


async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
    """Return a conversation's messages ordered by ``created_at`` ascending.

    An unknown ``conversation_id`` yields an empty list rather than a 404.
    """
    result = await session.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    return list(result.scalars().all())