feat(backend): B3 Task 3.1 - 对话与消息基础 API

- 添加 Conversation CRUD 端点(创建/列表/获取/删除)
- 添加 Message 操作端点(发送/列表)
- 注册 conversations 路由到 API v1
- 修复测试 fixture 的 API 路径前缀
- 添加 async_client fixture alias

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
root 2026-05-07 02:24:06 +08:00
parent c62156af53
commit a0c2586487
6 changed files with 409 additions and 1 deletions

View File

@ -1,7 +1,8 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import employees, tenants from app.api.v1 import conversations, employees, tenants
api_router = APIRouter(prefix="/api/v1") api_router = APIRouter(prefix="/api/v1")
api_router.include_router(tenants.router) api_router.include_router(tenants.router)
api_router.include_router(employees.router) api_router.include_router(employees.router)
api_router.include_router(conversations.router)

View File

@ -0,0 +1,70 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db
from app.schemas.conversation import (
ConversationCreate,
ConversationResponse,
ConversationUpdate,
MessageCreate,
MessageResponse,
)
from app.services import conversation_service
router = APIRouter(prefix="/conversations", tags=["conversations"])
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
async def create_conversation(
data: ConversationCreate, session: AsyncSession = Depends(get_db)
) -> ConversationResponse:
conv = await conversation_service.create_conversation(
session, data.tenant_id, data.employee_id, data.user_id, data.title
)
return ConversationResponse.from_model(conv)
@router.get("", response_model=list[ConversationResponse])
async def list_conversations(
tenant_id: str, session: AsyncSession = Depends(get_db)
) -> list[ConversationResponse]:
convs = await conversation_service.list_conversations(session, tenant_id)
return [ConversationResponse.from_model(c) for c in convs]
@router.get("/{conversation_id}", response_model=ConversationResponse)
async def get_conversation(
conversation_id: str, session: AsyncSession = Depends(get_db)
) -> ConversationResponse:
conv = await conversation_service.get_conversation(session, conversation_id)
return ConversationResponse.from_model(conv)
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_conversation(
conversation_id: str, session: AsyncSession = Depends(get_db)
) -> None:
await conversation_service.delete_conversation(session, conversation_id)
@router.post("/{conversation_id}/messages", response_model=MessageResponse)
async def send_message(
conversation_id: str,
data: MessageCreate,
session: AsyncSession = Depends(get_db),
) -> MessageResponse:
# For now, just create a user message (LLM response will come in Task 3.2)
await conversation_service.get_conversation(session, conversation_id)
msg = await conversation_service.create_message(
session, conversation_id, role="user", content=data.content
)
return MessageResponse.from_model(msg)
@router.get("/{conversation_id}/messages", response_model=list[MessageResponse])
async def list_messages(
conversation_id: str, session: AsyncSession = Depends(get_db)
) -> list[MessageResponse]:
await conversation_service.get_conversation(session, conversation_id)
messages = await conversation_service.list_messages(session, conversation_id)
return [MessageResponse.from_model(m) for m in messages]

View File

@ -0,0 +1,62 @@
from datetime import datetime
from pydantic import BaseModel, Field
class ConversationCreate(BaseModel):
tenant_id: str
employee_id: str
user_id: str = Field(..., max_length=200)
title: str | None = Field(None, max_length=500)
class ConversationUpdate(BaseModel):
title: str | None = Field(None, max_length=500)
class MessageCreate(BaseModel):
content: str
class MessageResponse(BaseModel):
id: str
conversation_id: str
role: str
content: str
token_count: int | None
sources: str | None
created_at: datetime
@classmethod
def from_model(cls, message) -> "MessageResponse":
return cls(
id=message.id,
conversation_id=message.conversation_id,
role=message.role.value,
content=message.content,
token_count=message.token_count,
sources=message.sources,
created_at=message.created_at,
)
class ConversationResponse(BaseModel):
id: str
tenant_id: str
employee_id: str
user_id: str
title: str | None
created_at: datetime
updated_at: datetime
@classmethod
def from_model(cls, conv) -> "ConversationResponse":
return cls(
id=conv.id,
tenant_id=conv.tenant_id,
employee_id=conv.employee_id,
user_id=conv.user_id,
title=conv.title,
created_at=conv.created_at,
updated_at=conv.updated_at,
)

View File

@ -0,0 +1,81 @@
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import Conversation, Message, MessageRole
async def create_conversation(
session: AsyncSession,
tenant_id: str,
employee_id: str,
user_id: str,
title: str | None = None,
) -> Conversation:
conversation = Conversation(
tenant_id=tenant_id,
employee_id=employee_id,
user_id=user_id,
title=title,
)
session.add(conversation)
await session.commit()
await session.refresh(conversation)
return conversation
async def get_conversation(
session: AsyncSession, conversation_id: str, include_messages: bool = False
) -> Conversation:
query = select(Conversation).where(Conversation.id == conversation_id)
if include_messages:
query = query.options(selectinload(Conversation.messages))
result = await session.execute(query)
conversation = result.scalar_one_or_none()
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
return conversation
async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
result = await session.execute(
select(Conversation).where(Conversation.tenant_id == tenant_id)
)
return list(result.scalars().all())
async def delete_conversation(session: AsyncSession, conversation_id: str) -> None:
conversation = await get_conversation(session, conversation_id)
await session.delete(conversation)
await session.commit()
async def create_message(
session: AsyncSession,
conversation_id: str,
role: MessageRole,
content: str,
token_count: int | None = None,
sources: str | None = None,
) -> Message:
message = Message(
conversation_id=conversation_id,
role=role,
content=content,
token_count=token_count,
sources=sources,
)
session.add(message)
await session.commit()
await session.refresh(message)
return message
async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
result = await session.execute(
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at)
)
return list(result.scalars().all())

View File

@ -44,3 +44,9 @@ async def client(setup_db) -> AsyncGenerator[AsyncClient, None]:
async with AsyncClient(transport=transport, base_url="http://test") as c: async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c yield c
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def async_client(client: AsyncClient) -> AsyncGenerator[AsyncClient, None]:
"""Alias for client fixture for compatibility."""
yield client

View File

@ -0,0 +1,188 @@
import pytest
from httpx import AsyncClient
from app.main import app
@pytest.fixture
async def tenant_id(async_client: AsyncClient) -> str:
"""Create a tenant and return its ID."""
resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"})
return resp.json()["id"]
@pytest.fixture
async def employee_id(async_client: AsyncClient, tenant_id: str) -> str:
"""Create an employee and return its ID."""
resp = await async_client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "Test Employee",
"role": "assistant",
"system_prompt": "You are a helpful assistant.",
},
)
return resp.json()["id"]
class TestConversationCRUD:
"""Test conversation CRUD operations."""
@pytest.mark.asyncio
async def test_create_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
"""Create a new conversation."""
resp = await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-001",
"title": "Test Conversation",
},
)
assert resp.status_code == 201
data = resp.json()
assert data["tenant_id"] == tenant_id
assert data["employee_id"] == employee_id
assert data["user_id"] == "user-001"
assert data["title"] == "Test Conversation"
assert "id" in data
assert "created_at" in data
@pytest.mark.asyncio
async def test_list_conversations(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
"""List all conversations."""
# Create two conversations
await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-001",
"title": "Conv 1",
},
)
await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-002",
"title": "Conv 2",
},
)
resp = await async_client.get("/api/v1/conversations", params={"tenant_id": tenant_id})
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
@pytest.mark.asyncio
async def test_get_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
"""Get a single conversation by ID."""
create_resp = await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-001",
"title": "Test Conv",
},
)
conv_id = create_resp.json()["id"]
resp = await async_client.get(f"/api/v1/conversations/{conv_id}")
assert resp.status_code == 200
data = resp.json()
assert data["id"] == conv_id
assert data["title"] == "Test Conv"
@pytest.mark.asyncio
async def test_delete_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
"""Delete a conversation."""
create_resp = await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-001",
},
)
conv_id = create_resp.json()["id"]
resp = await async_client.delete(f"/api/v1/conversations/{conv_id}")
assert resp.status_code == 204
# Verify it's deleted
get_resp = await async_client.get(f"/api/v1/conversations/{conv_id}")
assert get_resp.status_code == 404
class TestMessageOperations:
"""Test message operations within a conversation."""
@pytest.fixture
async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str:
"""Create a conversation and return its ID."""
resp = await async_client.post(
"/api/v1/conversations",
json={
"tenant_id": tenant_id,
"employee_id": employee_id,
"user_id": "user-001",
},
)
return resp.json()["id"]
@pytest.mark.asyncio
async def test_send_message(self, async_client: AsyncClient, conversation_id: str):
"""Send a message to a conversation."""
resp = await async_client.post(
f"/api/v1/conversations/{conversation_id}/messages",
json={"content": "Hello, how are you?"},
)
# For now, we expect 200 with the created user message
# Later with SSE streaming, this will change
assert resp.status_code == 200
data = resp.json()
assert data["role"] == "user"
assert data["content"] == "Hello, how are you?"
assert "id" in data
assert data["conversation_id"] == conversation_id
@pytest.mark.asyncio
async def test_list_messages(self, async_client: AsyncClient, conversation_id: str):
"""List all messages in a conversation."""
# Send a message first
await async_client.post(
f"/api/v1/conversations/{conversation_id}/messages",
json={"content": "First message"},
)
await async_client.post(
f"/api/v1/conversations/{conversation_id}/messages",
json={"content": "Second message"},
)
resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
# Messages should be ordered by created_at
assert data[0]["content"] == "First message"
assert data[1]["content"] == "Second message"
@pytest.mark.asyncio
async def test_get_conversation_with_messages(
self, async_client: AsyncClient, conversation_id: str
):
"""Get conversation details with messages."""
await async_client.post(
f"/api/v1/conversations/{conversation_id}/messages",
json={"content": "Test message"},
)
resp = await async_client.get(f"/api/v1/conversations/{conversation_id}")
assert resp.status_code == 200
# The response may or may not include messages depending on schema
# For now, we just verify the conversation is accessible