From a0c2586487994837b25bdd2527e9bec4adebbc0e Mon Sep 17 00:00:00 2001 From: root Date: Thu, 7 May 2026 02:24:06 +0800 Subject: [PATCH] =?UTF-8?q?feat(backend):=20B3=20Task=203.1=20-=20?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E4=B8=8E=E6=B6=88=E6=81=AF=E5=9F=BA=E7=A1=80?= =?UTF-8?q?=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 Conversation CRUD 端点(创建/列表/获取/删除) - 添加 Message 操作端点(发送/列表) - 注册 conversations 路由到 API v1 - 修复测试 fixture 的 API 路径前缀 - 添加 async_client fixture alias Co-Authored-By: Claude Opus 4.7 --- backend/app/api/v1/__init__.py | 3 +- backend/app/api/v1/conversations.py | 70 +++++++ backend/app/schemas/conversation.py | 62 ++++++ backend/app/services/conversation_service.py | 81 ++++++++ backend/tests/conftest.py | 6 + backend/tests/test_conversations.py | 188 +++++++++++++++++++ 6 files changed, 409 insertions(+), 1 deletion(-) create mode 100644 backend/app/api/v1/conversations.py create mode 100644 backend/app/schemas/conversation.py create mode 100644 backend/app/services/conversation_service.py create mode 100644 backend/tests/test_conversations.py diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py index aff6852..b268f0a 100644 --- a/backend/app/api/v1/__init__.py +++ b/backend/app/api/v1/__init__.py @@ -1,7 +1,8 @@ from fastapi import APIRouter -from app.api.v1 import employees, tenants +from app.api.v1 import conversations, employees, tenants api_router = APIRouter(prefix="/api/v1") api_router.include_router(tenants.router) api_router.include_router(employees.router) +api_router.include_router(conversations.router) diff --git a/backend/app/api/v1/conversations.py b/backend/app/api/v1/conversations.py new file mode 100644 index 0000000..af6bab5 --- /dev/null +++ b/backend/app/api/v1/conversations.py @@ -0,0 +1,70 @@ +from fastapi import APIRouter, Depends, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_db +from app.schemas.conversation import ( + ConversationCreate, + ConversationResponse, + ConversationUpdate, + MessageCreate, + MessageResponse, +) +from app.services import conversation_service + +router = APIRouter(prefix="/conversations", tags=["conversations"]) + + +@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED) +async def create_conversation( + data: ConversationCreate, session: AsyncSession = Depends(get_db) +) -> ConversationResponse: + conv = await conversation_service.create_conversation( + session, data.tenant_id, data.employee_id, data.user_id, data.title + ) + return ConversationResponse.from_model(conv) + + +@router.get("", response_model=list[ConversationResponse]) +async def list_conversations( + tenant_id: str, session: AsyncSession = Depends(get_db) +) -> list[ConversationResponse]: + convs = await conversation_service.list_conversations(session, tenant_id) + return [ConversationResponse.from_model(c) for c in convs] + + +@router.get("/{conversation_id}", response_model=ConversationResponse) +async def get_conversation( + conversation_id: str, session: AsyncSession = Depends(get_db) +) -> ConversationResponse: + conv = await conversation_service.get_conversation(session, conversation_id) + return ConversationResponse.from_model(conv) + + +@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_conversation( + conversation_id: str, session: AsyncSession = Depends(get_db) +) -> None: + await conversation_service.delete_conversation(session, conversation_id) + + +@router.post("/{conversation_id}/messages", response_model=MessageResponse) +async def send_message( + conversation_id: str, + data: MessageCreate, + session: AsyncSession = Depends(get_db), +) -> MessageResponse: + # For now, just create a user message (LLM response will come in Task 3.2) + await conversation_service.get_conversation(session, conversation_id) + msg = await conversation_service.create_message( + session, conversation_id, role="user", content=data.content + ) + return MessageResponse.from_model(msg) + + +@router.get("/{conversation_id}/messages", response_model=list[MessageResponse]) +async def list_messages( + conversation_id: str, session: AsyncSession = Depends(get_db) +) -> list[MessageResponse]: + await conversation_service.get_conversation(session, conversation_id) + messages = await conversation_service.list_messages(session, conversation_id) + return [MessageResponse.from_model(m) for m in messages] \ No newline at end of file diff --git a/backend/app/schemas/conversation.py b/backend/app/schemas/conversation.py new file mode 100644 index 0000000..c0a2f82 --- /dev/null +++ b/backend/app/schemas/conversation.py @@ -0,0 +1,62 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + + +class ConversationCreate(BaseModel): + tenant_id: str + employee_id: str + user_id: str = Field(..., max_length=200) + title: str | None = Field(None, max_length=500) + + +class ConversationUpdate(BaseModel): + title: str | None = Field(None, max_length=500) + + +class MessageCreate(BaseModel): + content: str + + +class MessageResponse(BaseModel): + id: str + conversation_id: str + role: str + content: str + token_count: int | None + sources: str | None + created_at: datetime + + @classmethod + def from_model(cls, message) -> "MessageResponse": + return cls( + id=message.id, + conversation_id=message.conversation_id, + role=message.role.value, + content=message.content, + token_count=message.token_count, + sources=message.sources, + created_at=message.created_at, + ) + + +class ConversationResponse(BaseModel): + id: str + tenant_id: str + employee_id: str + user_id: str + title: str | None + created_at: datetime + updated_at: datetime + + @classmethod + def from_model(cls, conv) -> "ConversationResponse": + return cls( + id=conv.id, + tenant_id=conv.tenant_id, + employee_id=conv.employee_id, + user_id=conv.user_id, + title=conv.title, + created_at=conv.created_at, + updated_at=conv.updated_at, + ) diff --git a/backend/app/services/conversation_service.py b/backend/app/services/conversation_service.py new file mode 100644 index 0000000..59bcf3e --- /dev/null +++ b/backend/app/services/conversation_service.py @@ -0,0 +1,81 @@ +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.models import Conversation, Message, MessageRole + + +async def create_conversation( + session: AsyncSession, + tenant_id: str, + employee_id: str, + user_id: str, + title: str | None = None, +) -> Conversation: + conversation = Conversation( + tenant_id=tenant_id, + employee_id=employee_id, + user_id=user_id, + title=title, + ) + session.add(conversation) + await session.commit() + await session.refresh(conversation) + return conversation + + +async def get_conversation( + session: AsyncSession, conversation_id: str, include_messages: bool = False +) -> Conversation: + query = select(Conversation).where(Conversation.id == conversation_id) + if include_messages: + query = query.options(selectinload(Conversation.messages)) + result = await session.execute(query) + conversation = result.scalar_one_or_none() + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + return conversation + + +async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]: + result = await session.execute( + select(Conversation).where(Conversation.tenant_id == tenant_id) + ) + return list(result.scalars().all()) + + +async def delete_conversation(session: AsyncSession, conversation_id: str) -> None: + conversation = await get_conversation(session, conversation_id) + await session.delete(conversation) + await session.commit() + + +async def create_message( + session: AsyncSession, + conversation_id: str, + role: MessageRole, + content: str, + token_count: int | None = None, + sources: str | None = None, +) -> Message: + message = Message( + conversation_id=conversation_id, + role=role, + content=content, + token_count=token_count, + sources=sources, + ) + session.add(message) + await session.commit() + await session.refresh(message) + return message + + +async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]: + result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id) + .order_by(Message.created_at) + ) + return list(result.scalars().all()) \ No newline at end of file diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 8699684..4b611c8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -44,3 +44,9 @@ async def client(setup_db) -> AsyncGenerator[AsyncClient, None]: async with AsyncClient(transport=transport, base_url="http://test") as c: yield c app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def async_client(client: AsyncClient) -> AsyncGenerator[AsyncClient, None]: + """Alias for client fixture for compatibility.""" + yield client diff --git a/backend/tests/test_conversations.py b/backend/tests/test_conversations.py new file mode 100644 index 0000000..f1dab22 --- /dev/null +++ b/backend/tests/test_conversations.py @@ -0,0 +1,188 @@ +import pytest +from httpx import AsyncClient + +from app.main import app + + +@pytest.fixture +async def tenant_id(async_client: AsyncClient) -> str: + """Create a tenant and return its ID.""" + resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"}) + return resp.json()["id"] + + +@pytest.fixture +async def employee_id(async_client: AsyncClient, tenant_id: str) -> str: + """Create an employee and return its ID.""" + resp = await async_client.post( + "/api/v1/employees", + json={ + "tenant_id": tenant_id, + "name": "Test Employee", + "role": "assistant", + "system_prompt": "You are a helpful assistant.", + }, + ) + return resp.json()["id"] + + +class TestConversationCRUD: + """Test conversation CRUD operations.""" + + @pytest.mark.asyncio + async def test_create_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str): + """Create a new conversation.""" + resp = await async_client.post( + "/api/v1/conversations", + json={ + "tenant_id": tenant_id, + "employee_id": employee_id, + "user_id": "user-001", + "title": "Test Conversation", + }, + ) + assert resp.status_code == 201 + data = resp.json() + assert data["tenant_id"] == tenant_id + assert data["employee_id"] == employee_id + assert data["user_id"] == "user-001" + assert data["title"] == "Test Conversation" + assert "id" in data + assert "created_at" in data + + @pytest.mark.asyncio + async def test_list_conversations(self, async_client: AsyncClient, tenant_id: str, employee_id: str): + """List all conversations.""" + # Create two conversations + await async_client.post( + "/api/v1/conversations", + json={ + "tenant_id": tenant_id, + "employee_id": employee_id, + "user_id": "user-001", + "title": "Conv 1", + }, + ) + await async_client.post( + "/api/v1/conversations", + json={ + "tenant_id": tenant_id, + "employee_id": employee_id, + "user_id": "user-002", + "title": "Conv 2", + }, + ) + + resp = await async_client.get("/api/v1/conversations", params={"tenant_id": tenant_id}) + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 2 + + @pytest.mark.asyncio + async def test_get_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str): + """Get a single conversation by ID.""" + create_resp = await async_client.post( + "/api/v1/conversations", + json={ + "tenant_id": tenant_id, + "employee_id": employee_id, + "user_id": "user-001", + "title": "Test Conv", + }, + ) + conv_id = create_resp.json()["id"] + + resp = await async_client.get(f"/api/v1/conversations/{conv_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == conv_id + assert data["title"] == "Test Conv" + + @pytest.mark.asyncio + async def test_delete_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str): + """Delete a conversation.""" + create_resp = await async_client.post( + "/api/v1/conversations", + json={ + "tenant_id": tenant_id, + "employee_id": employee_id, + "user_id": "user-001", + }, + ) + conv_id = create_resp.json()["id"] + + resp = await async_client.delete(f"/api/v1/conversations/{conv_id}") + assert resp.status_code == 204 + + # Verify it's deleted + get_resp = await async_client.get(f"/api/v1/conversations/{conv_id}") + assert get_resp.status_code == 404 + + +class TestMessageOperations: + """Test message operations within a conversation.""" + + @pytest.fixture + async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str: + """Create a conversation and return its ID.""" + resp = await async_client.post( + "/api/v1/conversations", + json={ + "tenant_id": tenant_id, + "employee_id": employee_id, + "user_id": "user-001", + }, + ) + return resp.json()["id"] + + @pytest.mark.asyncio + async def test_send_message(self, async_client: AsyncClient, conversation_id: str): + """Send a message to a conversation.""" + resp = await async_client.post( + f"/api/v1/conversations/{conversation_id}/messages", + json={"content": "Hello, how are you?"}, + ) + # For now, we expect 200 with the created user message + # Later with SSE streaming, this will change + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "user" + assert data["content"] == "Hello, how are you?" + assert "id" in data + assert data["conversation_id"] == conversation_id + + @pytest.mark.asyncio + async def test_list_messages(self, async_client: AsyncClient, conversation_id: str): + """List all messages in a conversation.""" + # Send a message first + await async_client.post( + f"/api/v1/conversations/{conversation_id}/messages", + json={"content": "First message"}, + ) + await async_client.post( + f"/api/v1/conversations/{conversation_id}/messages", + json={"content": "Second message"}, + ) + + resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages") + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 2 + # Messages should be ordered by created_at + assert data[0]["content"] == "First message" + assert data[1]["content"] == "Second message" + + @pytest.mark.asyncio + async def test_get_conversation_with_messages( + self, async_client: AsyncClient, conversation_id: str + ): + """Get conversation details with messages.""" + await async_client.post( + f"/api/v1/conversations/{conversation_id}/messages", + json={"content": "Test message"}, + ) + + resp = await async_client.get(f"/api/v1/conversations/{conversation_id}") + assert resp.status_code == 200 + # The response may or may not include messages depending on schema + # For now, we just verify the conversation is accessible