feat(backend): B3 Task 3.1 - 对话与消息基础 API
- 添加 Conversation CRUD 端点(创建/列表/获取/删除) - 添加 Message 操作端点(发送/列表) - 注册 conversations 路由到 API v1 - 修复测试 fixture 的 API 路径前缀 - 添加 async_client fixture alias Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c62156af53
commit
a0c2586487
@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import employees, tenants
|
||||
from app.api.v1 import conversations, employees, tenants
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1")
|
||||
api_router.include_router(tenants.router)
|
||||
api_router.include_router(employees.router)
|
||||
api_router.include_router(conversations.router)
|
||||
|
||||
70
backend/app/api/v1/conversations.py
Normal file
70
backend/app/api/v1/conversations.py
Normal file
@ -0,0 +1,70 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.conversation import (
|
||||
ConversationCreate,
|
||||
ConversationResponse,
|
||||
ConversationUpdate,
|
||||
MessageCreate,
|
||||
MessageResponse,
|
||||
)
|
||||
from app.services import conversation_service
|
||||
|
||||
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
||||
|
||||
|
||||
@router.post("", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_conversation(
|
||||
data: ConversationCreate, session: AsyncSession = Depends(get_db)
|
||||
) -> ConversationResponse:
|
||||
conv = await conversation_service.create_conversation(
|
||||
session, data.tenant_id, data.employee_id, data.user_id, data.title
|
||||
)
|
||||
return ConversationResponse.from_model(conv)
|
||||
|
||||
|
||||
@router.get("", response_model=list[ConversationResponse])
|
||||
async def list_conversations(
|
||||
tenant_id: str, session: AsyncSession = Depends(get_db)
|
||||
) -> list[ConversationResponse]:
|
||||
convs = await conversation_service.list_conversations(session, tenant_id)
|
||||
return [ConversationResponse.from_model(c) for c in convs]
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=ConversationResponse)
|
||||
async def get_conversation(
|
||||
conversation_id: str, session: AsyncSession = Depends(get_db)
|
||||
) -> ConversationResponse:
|
||||
conv = await conversation_service.get_conversation(session, conversation_id)
|
||||
return ConversationResponse.from_model(conv)
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_conversation(
|
||||
conversation_id: str, session: AsyncSession = Depends(get_db)
|
||||
) -> None:
|
||||
await conversation_service.delete_conversation(session, conversation_id)
|
||||
|
||||
|
||||
@router.post("/{conversation_id}/messages", response_model=MessageResponse)
|
||||
async def send_message(
|
||||
conversation_id: str,
|
||||
data: MessageCreate,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
) -> MessageResponse:
|
||||
# For now, just create a user message (LLM response will come in Task 3.2)
|
||||
await conversation_service.get_conversation(session, conversation_id)
|
||||
msg = await conversation_service.create_message(
|
||||
session, conversation_id, role="user", content=data.content
|
||||
)
|
||||
return MessageResponse.from_model(msg)
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages", response_model=list[MessageResponse])
|
||||
async def list_messages(
|
||||
conversation_id: str, session: AsyncSession = Depends(get_db)
|
||||
) -> list[MessageResponse]:
|
||||
await conversation_service.get_conversation(session, conversation_id)
|
||||
messages = await conversation_service.list_messages(session, conversation_id)
|
||||
return [MessageResponse.from_model(m) for m in messages]
|
||||
62
backend/app/schemas/conversation.py
Normal file
62
backend/app/schemas/conversation.py
Normal file
@ -0,0 +1,62 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ConversationCreate(BaseModel):
|
||||
tenant_id: str
|
||||
employee_id: str
|
||||
user_id: str = Field(..., max_length=200)
|
||||
title: str | None = Field(None, max_length=500)
|
||||
|
||||
|
||||
class ConversationUpdate(BaseModel):
|
||||
title: str | None = Field(None, max_length=500)
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
role: str
|
||||
content: str
|
||||
token_count: int | None
|
||||
sources: str | None
|
||||
created_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, message) -> "MessageResponse":
|
||||
return cls(
|
||||
id=message.id,
|
||||
conversation_id=message.conversation_id,
|
||||
role=message.role.value,
|
||||
content=message.content,
|
||||
token_count=message.token_count,
|
||||
sources=message.sources,
|
||||
created_at=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
class ConversationResponse(BaseModel):
|
||||
id: str
|
||||
tenant_id: str
|
||||
employee_id: str
|
||||
user_id: str
|
||||
title: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, conv) -> "ConversationResponse":
|
||||
return cls(
|
||||
id=conv.id,
|
||||
tenant_id=conv.tenant_id,
|
||||
employee_id=conv.employee_id,
|
||||
user_id=conv.user_id,
|
||||
title=conv.title,
|
||||
created_at=conv.created_at,
|
||||
updated_at=conv.updated_at,
|
||||
)
|
||||
81
backend/app/services/conversation_service.py
Normal file
81
backend/app/services/conversation_service.py
Normal file
@ -0,0 +1,81 @@
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models import Conversation, Message, MessageRole
|
||||
|
||||
|
||||
async def create_conversation(
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
employee_id: str,
|
||||
user_id: str,
|
||||
title: str | None = None,
|
||||
) -> Conversation:
|
||||
conversation = Conversation(
|
||||
tenant_id=tenant_id,
|
||||
employee_id=employee_id,
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
)
|
||||
session.add(conversation)
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
|
||||
async def get_conversation(
|
||||
session: AsyncSession, conversation_id: str, include_messages: bool = False
|
||||
) -> Conversation:
|
||||
query = select(Conversation).where(Conversation.id == conversation_id)
|
||||
if include_messages:
|
||||
query = query.options(selectinload(Conversation.messages))
|
||||
result = await session.execute(query)
|
||||
conversation = result.scalar_one_or_none()
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
return conversation
|
||||
|
||||
|
||||
async def list_conversations(session: AsyncSession, tenant_id: str) -> list[Conversation]:
|
||||
result = await session.execute(
|
||||
select(Conversation).where(Conversation.tenant_id == tenant_id)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_conversation(session: AsyncSession, conversation_id: str) -> None:
|
||||
conversation = await get_conversation(session, conversation_id)
|
||||
await session.delete(conversation)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def create_message(
|
||||
session: AsyncSession,
|
||||
conversation_id: str,
|
||||
role: MessageRole,
|
||||
content: str,
|
||||
token_count: int | None = None,
|
||||
sources: str | None = None,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
token_count=token_count,
|
||||
sources=sources,
|
||||
)
|
||||
session.add(message)
|
||||
await session.commit()
|
||||
await session.refresh(message)
|
||||
return message
|
||||
|
||||
|
||||
async def list_messages(session: AsyncSession, conversation_id: str) -> list[Message]:
|
||||
result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
@ -44,3 +44,9 @@ async def client(setup_db) -> AsyncGenerator[AsyncClient, None]:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client(client: AsyncClient) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Alias for client fixture for compatibility."""
|
||||
yield client
|
||||
|
||||
188
backend/tests/test_conversations.py
Normal file
188
backend/tests/test_conversations.py
Normal file
@ -0,0 +1,188 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tenant_id(async_client: AsyncClient) -> str:
|
||||
"""Create a tenant and return its ID."""
|
||||
resp = await async_client.post("/api/v1/tenants", json={"name": "Test Tenant", "slug": "test-tenant"})
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def employee_id(async_client: AsyncClient, tenant_id: str) -> str:
|
||||
"""Create an employee and return its ID."""
|
||||
resp = await async_client.post(
|
||||
"/api/v1/employees",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"name": "Test Employee",
|
||||
"role": "assistant",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
},
|
||||
)
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
class TestConversationCRUD:
|
||||
"""Test conversation CRUD operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
|
||||
"""Create a new conversation."""
|
||||
resp = await async_client.post(
|
||||
"/api/v1/conversations",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"employee_id": employee_id,
|
||||
"user_id": "user-001",
|
||||
"title": "Test Conversation",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["tenant_id"] == tenant_id
|
||||
assert data["employee_id"] == employee_id
|
||||
assert data["user_id"] == "user-001"
|
||||
assert data["title"] == "Test Conversation"
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
|
||||
"""List all conversations."""
|
||||
# Create two conversations
|
||||
await async_client.post(
|
||||
"/api/v1/conversations",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"employee_id": employee_id,
|
||||
"user_id": "user-001",
|
||||
"title": "Conv 1",
|
||||
},
|
||||
)
|
||||
await async_client.post(
|
||||
"/api/v1/conversations",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"employee_id": employee_id,
|
||||
"user_id": "user-002",
|
||||
"title": "Conv 2",
|
||||
},
|
||||
)
|
||||
|
||||
resp = await async_client.get("/api/v1/conversations", params={"tenant_id": tenant_id})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
|
||||
"""Get a single conversation by ID."""
|
||||
create_resp = await async_client.post(
|
||||
"/api/v1/conversations",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"employee_id": employee_id,
|
||||
"user_id": "user-001",
|
||||
"title": "Test Conv",
|
||||
},
|
||||
)
|
||||
conv_id = create_resp.json()["id"]
|
||||
|
||||
resp = await async_client.get(f"/api/v1/conversations/{conv_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == conv_id
|
||||
assert data["title"] == "Test Conv"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation(self, async_client: AsyncClient, tenant_id: str, employee_id: str):
|
||||
"""Delete a conversation."""
|
||||
create_resp = await async_client.post(
|
||||
"/api/v1/conversations",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"employee_id": employee_id,
|
||||
"user_id": "user-001",
|
||||
},
|
||||
)
|
||||
conv_id = create_resp.json()["id"]
|
||||
|
||||
resp = await async_client.delete(f"/api/v1/conversations/{conv_id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Verify it's deleted
|
||||
get_resp = await async_client.get(f"/api/v1/conversations/{conv_id}")
|
||||
assert get_resp.status_code == 404
|
||||
|
||||
|
||||
class TestMessageOperations:
|
||||
"""Test message operations within a conversation."""
|
||||
|
||||
@pytest.fixture
|
||||
async def conversation_id(self, async_client: AsyncClient, tenant_id: str, employee_id: str) -> str:
|
||||
"""Create a conversation and return its ID."""
|
||||
resp = await async_client.post(
|
||||
"/api/v1/conversations",
|
||||
json={
|
||||
"tenant_id": tenant_id,
|
||||
"employee_id": employee_id,
|
||||
"user_id": "user-001",
|
||||
},
|
||||
)
|
||||
return resp.json()["id"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message(self, async_client: AsyncClient, conversation_id: str):
|
||||
"""Send a message to a conversation."""
|
||||
resp = await async_client.post(
|
||||
f"/api/v1/conversations/{conversation_id}/messages",
|
||||
json={"content": "Hello, how are you?"},
|
||||
)
|
||||
# For now, we expect 200 with the created user message
|
||||
# Later with SSE streaming, this will change
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["role"] == "user"
|
||||
assert data["content"] == "Hello, how are you?"
|
||||
assert "id" in data
|
||||
assert data["conversation_id"] == conversation_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_messages(self, async_client: AsyncClient, conversation_id: str):
|
||||
"""List all messages in a conversation."""
|
||||
# Send a message first
|
||||
await async_client.post(
|
||||
f"/api/v1/conversations/{conversation_id}/messages",
|
||||
json={"content": "First message"},
|
||||
)
|
||||
await async_client.post(
|
||||
f"/api/v1/conversations/{conversation_id}/messages",
|
||||
json={"content": "Second message"},
|
||||
)
|
||||
|
||||
resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) >= 2
|
||||
# Messages should be ordered by created_at
|
||||
assert data[0]["content"] == "First message"
|
||||
assert data[1]["content"] == "Second message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_with_messages(
|
||||
self, async_client: AsyncClient, conversation_id: str
|
||||
):
|
||||
"""Get conversation details with messages."""
|
||||
await async_client.post(
|
||||
f"/api/v1/conversations/{conversation_id}/messages",
|
||||
json={"content": "Test message"},
|
||||
)
|
||||
|
||||
resp = await async_client.get(f"/api/v1/conversations/{conversation_id}")
|
||||
assert resp.status_code == 200
|
||||
# The response may or may not include messages depending on schema
|
||||
# For now, we just verify the conversation is accessible
|
||||
Loading…
x
Reference in New Issue
Block a user