@router.post("/{conversation_id}/messages/stream")
async def send_message_stream(
    conversation_id: str,
    data: MessageCreate,
    session: AsyncSession = Depends(get_db),
) -> StreamingResponse:
    """Send a user message and stream the reply as Server-Sent Events.

    Args:
        conversation_id: ID of the conversation the message belongs to.
        data: Request body; only ``data.content`` is used.
        session: Database session injected by ``get_db``.

    Returns:
        A ``text/event-stream`` response whose body is produced by
        ``conversation_service.stream_chat``.
    """
    # Look the conversation up before the response is created, the same way
    # send_message does. Anything raised inside the stream_chat generator
    # happens after the streaming response has already started, so it can no
    # longer be turned into a proper error status for the client.
    # NOTE(review): presumably get_conversation raises HTTPException for an
    # unknown ID - confirm against conversation_service.
    await conversation_service.get_conversation(session, conversation_id)

    return StreamingResponse(
        conversation_service.stream_chat(session, conversation_id, data.content),
        media_type="text/event-stream",
        # Event streams must not be served from a cache.
        headers={"Cache-Control": "no-cache"},
    )
async def build_messages_for_llm(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
) -> tuple[list[LLMMessage], DigitalEmployee]:
    """Build the message list to send to the LLM for one user turn.

    The list is assembled in this order: the employee's system prompt (only
    when it is non-empty), the most recent stored messages of the
    conversation (at most ``employee.max_context_messages`` of them), and
    finally ``user_content`` as a ``user`` message.

    Args:
        session: Database session used to load the conversation and employee.
        conversation_id: ID of the conversation whose history is loaded.
        user_content: Text of the user message being sent now.

    Returns:
        A ``(messages, employee)`` tuple. The employee is returned alongside
        the messages so the caller can reuse it.
    """
    conversation = await get_conversation(session, conversation_id, include_messages=True)
    employee = await get_employee(session, conversation.employee_id)

    messages: list[LLMMessage] = []

    if employee.system_prompt:
        messages.append(LLMMessage(role="system", content=employee.system_prompt))

    # Trim the history to the newest ``max_context`` entries.
    history = conversation.messages or []
    max_context = employee.max_context_messages
    if max_context is not None:
        # Slice from an explicit start index rather than
        # ``history[-max_context:]``: with a limit of 0 that expression is
        # ``history[0:]`` and keeps the whole history instead of none of it.
        # A limit of 0 or less therefore keeps no history at all.
        # NOTE(review): None is treated as "no limit" - confirm that this
        # matches how the employee model defines the field.
        history = history[max(len(history) - max_context, 0):]

    messages.extend(
        LLMMessage(role=msg.role.value, content=msg.content) for msg in history
    )

    # The message currently being sent always comes last.
    messages.append(LLMMessage(role="user", content=user_content))

    return messages, employee


def _sse_event(payload: dict) -> str:
    """Format ``payload`` as one SSE ``data:`` event terminated by a blank line."""
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat(
    session: AsyncSession,
    conversation_id: str,
    user_content: str,
):
    """Run one chat turn and yield it as Server-Sent Events strings.

    Yields, in order:
      1. a ``message_start`` event,
      2. one ``token`` event per chunk produced by the provider, carrying the
         chunk in ``content``,
      3. a ``message_end`` event carrying ``token_count``.

    ``token_count`` is the number of chunks the provider yielded.

    Side effects: the user message is stored before streaming starts, and the
    assistant message (all chunks concatenated) is stored after the provider
    stream is exhausted, just before ``message_end`` is yielded.

    Args:
        session: Database session used for all reads and writes.
        conversation_id: ID of the conversation this turn belongs to.
        user_content: Text of the user message.
    """
    # Build the prompt before storing the user message: the builder appends
    # ``user_content`` itself on top of the stored history.
    messages, employee = await build_messages_for_llm(session, conversation_id, user_content)

    await create_message(session, conversation_id, MessageRole.user, user_content)

    provider = await get_provider_for_tenant(session, employee.tenant_id)

    yield _sse_event({'type': 'message_start'})

    # Collect the chunks and join once at the end instead of growing a string
    # with ``+=`` on every iteration.
    chunks: list[str] = []
    async for token in provider.chat_stream(messages):
        chunks.append(token)
        yield _sse_event({'type': 'token', 'content': token})

    full_content = "".join(chunks)
    token_count = len(chunks)

    await create_message(
        session,
        conversation_id,
        MessageRole.assistant,
        full_content,
        token_count=token_count,
    )

    yield _sse_event({'type': 'message_end', 'token_count': token_count})
@pytest.mark.asyncio
async def test_sse_stream_persists_messages(
    self, async_client: AsyncClient, conversation_id: str
):
    """The stream endpoint stores both the user message and the assistant reply."""
    from unittest.mock import AsyncMock, MagicMock, patch

    reply_tokens = ["Test", " ", "response"]

    # Stand-in for the provider's streaming call: yields a fixed reply.
    async def mock_stream(messages):
        for token in reply_tokens:
            yield token

    mock_provider = MagicMock()
    mock_provider.chat_stream = mock_stream

    with patch(
        "app.services.conversation_service.get_provider_for_tenant",
        AsyncMock(return_value=mock_provider),
    ):
        async with async_client.stream(
            "POST",
            f"/api/v1/conversations/{conversation_id}/messages/stream",
            json={"content": "Test question"},
        ) as response:
            assert response.status_code == 200
            # Drain the stream; the assistant message is only stored once
            # the provider stream has been fully consumed.
            async for _ in response.aiter_lines():
                pass

    resp = await async_client.get(f"/api/v1/conversations/{conversation_id}/messages")
    assert resp.status_code == 200
    messages = resp.json()

    # The user message comes first, followed by the assistant reply.
    assert len(messages) >= 2
    assert messages[0]["role"] == "user"
    assert messages[0]["content"] == "Test question"
    assert messages[1]["role"] == "assistant"
    # The stored reply must be the concatenation of every streamed chunk.
    assert messages[1]["content"] == "".join(reply_tokens)