feat(backend): 数字员工平台 B1+B2 批次实现

B1: 项目脚手架 + 数据模型 + 租户管理
- Task 1.1: FastAPI 项目脚手架、SQLite + async SQLAlchemy
- Task 1.2: 7 个数据模型 (Tenant, TenantConfig, DigitalEmployee, Conversation, Message, KnowledgeBase, Document)
- Task 1.3: 租户 CRUD API + LLM 配置(含 API Key AES 加密)

B2: 数字员工配置 + LLM Provider 抽象层
- Task 2.1: 数字员工 CRUD API(关联知识库)
- Task 2.2: BaseLLMProvider 抽象接口 + OpenAI/Qwen Provider
- Task 2.3: Provider 动态实例化 + test-provider 端点

验证: 26 个测试全部通过

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
root 2026-05-06 11:29:48 +08:00
parent 44c37420af
commit c62156af53
37 changed files with 4419 additions and 1 deletions

16
backend/.env.example Normal file
View File

@ -0,0 +1,16 @@
# LLM Provider
OPENAI_API_KEY=sk-your-openai-key
QWEN_API_KEY=sk-your-qwen-key
# Encryption key for storing tenant LLM API keys (generate: python -c "import secrets; print(secrets.token_hex(32))")
ENCRYPTION_KEY=your-32-byte-hex-key
# Database
DATABASE_URL=sqlite+aiosqlite:///./data.db
# Upload
UPLOAD_DIR=./uploads
MAX_FILE_SIZE_MB=20
# Rate limiting
RATE_LIMIT_PER_MINUTE=60

0
backend/app/__init__.py Normal file
View File

View File

11
backend/app/api/deps.py Normal file
View File

@ -0,0 +1,11 @@
from typing import AsyncGenerator
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import async_session
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
yield session

View File

@ -0,0 +1,7 @@
from fastapi import APIRouter
from app.api.v1 import employees, tenants
api_router = APIRouter(prefix="/api/v1")
api_router.include_router(tenants.router)
api_router.include_router(employees.router)

View File

@ -0,0 +1,76 @@
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db
from app.schemas.employee import (
EmployeeCreate,
EmployeeResponse,
EmployeeUpdate,
)
from app.services import employee_service
router = APIRouter(prefix="/employees", tags=["employees"])
@router.post("", response_model=EmployeeResponse, status_code=status.HTTP_201_CREATED)
async def create_employee(
data: EmployeeCreate, session: AsyncSession = Depends(get_db)
) -> EmployeeResponse:
employee = await employee_service.create_employee(
session,
tenant_id=data.tenant_id,
name=data.name,
role=data.role,
system_prompt=data.system_prompt,
greeting=data.greeting,
avatar_url=data.avatar_url,
temperature=data.temperature,
max_context_messages=data.max_context_messages,
knowledge_base_ids=data.knowledge_base_ids,
)
return EmployeeResponse.from_model(employee)
@router.get("", response_model=list[EmployeeResponse])
async def list_employees(
tenant_id: str = Query(...), session: AsyncSession = Depends(get_db)
) -> list[EmployeeResponse]:
employees = await employee_service.list_employees(session, tenant_id)
return [EmployeeResponse.from_model(e) for e in employees]
@router.get("/{employee_id}", response_model=EmployeeResponse)
async def get_employee(
employee_id: str, session: AsyncSession = Depends(get_db)
) -> EmployeeResponse:
employee = await employee_service.get_employee(session, employee_id)
return EmployeeResponse.from_model(employee)
@router.patch("/{employee_id}", response_model=EmployeeResponse)
async def update_employee(
employee_id: str,
data: EmployeeUpdate,
session: AsyncSession = Depends(get_db),
) -> EmployeeResponse:
employee = await employee_service.update_employee(
session,
employee_id,
name=data.name,
role=data.role,
system_prompt=data.system_prompt,
greeting=data.greeting,
avatar_url=data.avatar_url,
temperature=data.temperature,
max_context_messages=data.max_context_messages,
knowledge_base_ids=data.knowledge_base_ids,
status=data.status,
)
return EmployeeResponse.from_model(employee)
@router.delete("/{employee_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_employee(
employee_id: str, session: AsyncSession = Depends(get_db)
) -> None:
await employee_service.delete_employee(session, employee_id)

View File

@ -0,0 +1,128 @@
from fastapi import APIRouter, Depends, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db
from app.schemas.tenant import (
TenantConfigCreate,
TenantConfigResponse,
TenantCreate,
TenantResponse,
TenantUpdate,
)
from app.services import tenant_service
router = APIRouter(prefix="/tenants", tags=["tenants"])
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED)
async def create_tenant(
data: TenantCreate, session: AsyncSession = Depends(get_db)
) -> TenantResponse:
tenant = await tenant_service.create_tenant(session, data.name, data.slug)
return TenantResponse(
id=tenant.id,
name=tenant.name,
slug=tenant.slug,
status=tenant.status.value,
created_at=tenant.created_at,
updated_at=tenant.updated_at,
)
@router.get("", response_model=list[TenantResponse])
async def list_tenants(session: AsyncSession = Depends(get_db)) -> list[TenantResponse]:
tenants = await tenant_service.list_tenants(session)
return [
TenantResponse(
id=t.id,
name=t.name,
slug=t.slug,
status=t.status.value,
created_at=t.created_at,
updated_at=t.updated_at,
)
for t in tenants
]
@router.get("/{tenant_id}", response_model=TenantResponse)
async def get_tenant(
tenant_id: str, session: AsyncSession = Depends(get_db)
) -> TenantResponse:
tenant = await tenant_service.get_tenant(session, tenant_id)
return TenantResponse(
id=tenant.id,
name=tenant.name,
slug=tenant.slug,
status=tenant.status.value,
created_at=tenant.created_at,
updated_at=tenant.updated_at,
)
@router.patch("/{tenant_id}", response_model=TenantResponse)
async def update_tenant(
tenant_id: str,
data: TenantUpdate,
session: AsyncSession = Depends(get_db),
) -> TenantResponse:
tenant = await tenant_service.update_tenant(session, tenant_id, data.name, data.slug)
return TenantResponse(
id=tenant.id,
name=tenant.name,
slug=tenant.slug,
status=tenant.status.value,
created_at=tenant.created_at,
updated_at=tenant.updated_at,
)
@router.delete("/{tenant_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_tenant(tenant_id: str, session: AsyncSession = Depends(get_db)) -> None:
await tenant_service.delete_tenant(session, tenant_id)
@router.put("/{tenant_id}/config", response_model=TenantConfigResponse)
async def update_tenant_config(
tenant_id: str,
data: TenantConfigCreate,
session: AsyncSession = Depends(get_db),
) -> TenantConfigResponse:
config = await tenant_service.update_tenant_config(
session,
tenant_id,
data.llm_provider,
data.llm_api_key,
data.llm_model,
data.llm_base_url,
data.max_tokens_per_month or 1000000,
)
return TenantConfigResponse(
id=config.id,
tenant_id=config.tenant_id,
llm_provider=config.llm_provider,
llm_model=config.llm_model,
llm_base_url=config.llm_base_url,
max_tokens_per_month=config.max_tokens_per_month,
created_at=config.created_at,
updated_at=config.updated_at,
)
class ProviderTestResponse(BaseModel):
provider: str
model: str
status: str
@router.post("/{tenant_id}/test-provider", response_model=ProviderTestResponse)
async def test_provider(
tenant_id: str, session: AsyncSession = Depends(get_db)
) -> ProviderTestResponse:
provider = await tenant_service.get_provider_for_tenant(session, tenant_id)
return ProviderTestResponse(
provider=provider.__class__.__name__.replace("Provider", "").lower(),
model=provider.model,
status="ok",
)

15
backend/app/config.py Normal file
View File

@ -0,0 +1,15 @@
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
database_url: str = "sqlite+aiosqlite:///./data.db"
secret_key: str = "dev-secret-key-change-in-production"
encryption_salt: str = "dev-salt-change-in-production"
upload_dir: str = "./uploads"
max_file_size_mb: int = 20
rate_limit_per_minute: int = 60
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
settings = Settings()

21
backend/app/database.py Normal file
View File

@ -0,0 +1,21 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
engine = create_async_engine(settings.database_url, echo=False)
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
class Base(DeclarativeBase):
pass
async def get_db() -> AsyncSession: # type: ignore[misc]
async with async_session() as session:
yield session
async def init_db() -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

25
backend/app/main.py Normal file
View File

@ -0,0 +1,25 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.api.v1 import api_router
from app.database import init_db
@asynccontextmanager
async def lifespan(app: FastAPI):
await init_db()
yield
app = FastAPI(
title="Digital Employee Platform",
version="0.1.0",
lifespan=lifespan,
)
app.include_router(api_router)
@app.get("/health")
async def health_check():
return {"status": "ok"}

View File

View File

@ -0,0 +1,18 @@
from app.models.conversation import Conversation, Message, MessageRole
from app.models.employee import DigitalEmployee, EmployeeStatus
from app.models.knowledge import Document, DocumentStatus, KnowledgeBase
from app.models.tenant import Tenant, TenantConfig, TenantStatus
__all__ = [
"Tenant",
"TenantConfig",
"TenantStatus",
"DigitalEmployee",
"EmployeeStatus",
"Conversation",
"Message",
"MessageRole",
"KnowledgeBase",
"Document",
"DocumentStatus",
]

View File

@ -0,0 +1,58 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.utcnow()
class MessageRole(str, enum.Enum):
system = "system"
user = "user"
assistant = "assistant"
class Conversation(Base):
__tablename__ = "conversations"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
tenant_id: Mapped[str] = mapped_column(
String(36), ForeignKey("tenants.id"), nullable=False, index=True
)
employee_id: Mapped[str] = mapped_column(
String(36), ForeignKey("digital_employees.id"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(String(200), nullable=False)
title: Mapped[str | None] = mapped_column(String(500), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
messages: Mapped[list["Message"]] = relationship(
back_populates="conversation", cascade="all, delete-orphan", lazy="selectin"
)
class Message(Base):
__tablename__ = "messages"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
conversation_id: Mapped[str] = mapped_column(
String(36), ForeignKey("conversations.id"), nullable=False, index=True
)
role: Mapped[MessageRole] = mapped_column(Enum(MessageRole), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
sources: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
conversation: Mapped["Conversation"] = relationship(back_populates="messages")

View File

@ -0,0 +1,45 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.utcnow()
class EmployeeStatus(str, enum.Enum):
active = "active"
inactive = "inactive"
class DigitalEmployee(Base):
__tablename__ = "digital_employees"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
tenant_id: Mapped[str] = mapped_column(
String(36), ForeignKey("tenants.id"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(200), nullable=False)
role: Mapped[str] = mapped_column(String(100), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
system_prompt: Mapped[str] = mapped_column(Text, nullable=False)
greeting: Mapped[str | None] = mapped_column(Text, nullable=True)
temperature: Mapped[float] = mapped_column(Float, default=0.7)
max_context_messages: Mapped[int] = mapped_column(Integer, default=20)
knowledge_base_ids: Mapped[str] = mapped_column(Text, default="[]")
status: Mapped[EmployeeStatus] = mapped_column(
Enum(EmployeeStatus), default=EmployeeStatus.active
)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
tenant = relationship("Tenant", lazy="selectin")

View File

@ -0,0 +1,63 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.utcnow()
class DocumentStatus(str, enum.Enum):
pending = "pending"
processing = "processing"
completed = "completed"
failed = "failed"
class KnowledgeBase(Base):
__tablename__ = "knowledge_bases"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
tenant_id: Mapped[str] = mapped_column(
String(36), ForeignKey("tenants.id"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(200), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
embedding_model: Mapped[str] = mapped_column(String(100), default="text-embedding-3-small")
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
documents: Mapped[list["Document"]] = relationship(
back_populates="knowledge_base", cascade="all, delete-orphan", lazy="selectin"
)
tenant = relationship("Tenant", lazy="selectin")
class Document(Base):
__tablename__ = "documents"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
knowledge_base_id: Mapped[str] = mapped_column(
String(36), ForeignKey("knowledge_bases.id"), nullable=False, index=True
)
filename: Mapped[str] = mapped_column(String(500), nullable=False)
file_type: Mapped[str] = mapped_column(String(20), nullable=False)
file_size: Mapped[int] = mapped_column(Integer, default=0)
chunk_count: Mapped[int] = mapped_column(Integer, default=0)
status: Mapped[DocumentStatus] = mapped_column(
Enum(DocumentStatus), default=DocumentStatus.pending
)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
knowledge_base: Mapped["KnowledgeBase"] = relationship(back_populates="documents")

View File

@ -0,0 +1,55 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.utcnow()
class TenantStatus(str, enum.Enum):
active = "active"
suspended = "suspended"
deleted = "deleted"
class Tenant(Base):
__tablename__ = "tenants"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(200), nullable=False)
slug: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
status: Mapped[TenantStatus] = mapped_column(
Enum(TenantStatus), default=TenantStatus.active
)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
config: Mapped["TenantConfig | None"] = relationship(back_populates="tenant", uselist=False)
class TenantConfig(Base):
__tablename__ = "tenant_configs"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
tenant_id: Mapped[str] = mapped_column(
String(36), ForeignKey("tenants.id"), unique=True, nullable=False
)
llm_provider: Mapped[str] = mapped_column(String(50), nullable=False)
llm_api_key: Mapped[str] = mapped_column(Text, nullable=False)
llm_model: Mapped[str] = mapped_column(String(100), nullable=False)
llm_base_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
max_tokens_per_month: Mapped[int] = mapped_column(Integer, default=1000000)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
tenant: Mapped["Tenant"] = relationship(back_populates="config")

View File

@ -0,0 +1,23 @@
"""LLM Provider 注册表"""
from app.providers.base import BaseLLMProvider
from app.providers.openai_provider import OpenAIProvider
from app.providers.qwen_provider import QwenProvider
PROVIDERS = {
"openai": OpenAIProvider,
"qwen": QwenProvider,
}
def get_provider(
provider_type: str,
api_key: str,
model: str,
base_url: str | None = None,
) -> BaseLLMProvider:
"""根据类型获取 Provider 实例"""
if provider_type not in PROVIDERS:
raise ValueError(f"Unknown provider: {provider_type}")
provider_cls = PROVIDERS[provider_type]
return provider_cls(api_key=api_key, model=model, base_url=base_url)

View File

@ -0,0 +1,35 @@
"""LLM Provider 基础抽象层"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class LLMMessage:
role: str # "system", "user", "assistant"
content: str
@dataclass
class LLMResponse:
content: str
model: str
usage: dict # {"total_tokens": int, "prompt_tokens": int, "completion_tokens": int}
class BaseLLMProvider(ABC):
"""LLM Provider 抽象基类"""
@abstractmethod
async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
"""非流式对话"""
pass
@abstractmethod
async def chat_stream(self, messages: list[LLMMessage]) -> str:
"""流式对话,返回 token 异步生成器"""
pass
@abstractmethod
async def embed(self, texts: list[str]) -> list[list[float]]:
"""生成嵌入向量"""
pass

View File

@ -0,0 +1,43 @@
"""OpenAI Provider"""
from openai import AsyncOpenAI
from app.providers.base import BaseLLMProvider, LLMMessage, LLMResponse
class OpenAIProvider(BaseLLMProvider):
def __init__(self, api_key: str, model: str = "gpt-4", base_url: str | None = None):
self.api_key = api_key
self.model = model
self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": m.role, "content": m.content} for m in messages],
)
return LLMResponse(
content=response.choices[0].message.content or "",
model=response.model,
usage={
"total_tokens": response.usage.total_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
},
)
async def chat_stream(self, messages: list[LLMMessage]) -> str:
stream = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": m.role, "content": m.content} for m in messages],
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
async def embed(self, texts: list[str]) -> list[list[float]]:
response = await self.client.embeddings.create(
model=self.model,
input=texts,
)
return [item.embedding for item in response.data]

View File

@ -0,0 +1,51 @@
"""Qwen Provider (兼容 OpenAI SDK)"""
from openai import AsyncOpenAI
from app.providers.base import BaseLLMProvider, LLMMessage, LLMResponse
class QwenProvider(BaseLLMProvider):
"""通义千问 Provider使用 OpenAI SDK 兼容模式"""
def __init__(
self,
api_key: str,
model: str = "qwen-turbo",
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
):
self.api_key = api_key
self.model = model
self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": m.role, "content": m.content} for m in messages],
)
return LLMResponse(
content=response.choices[0].message.content or "",
model=response.model,
usage={
"total_tokens": response.usage.total_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
},
)
async def chat_stream(self, messages: list[LLMMessage]) -> str:
stream = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": m.role, "content": m.content} for m in messages],
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
async def embed(self, texts: list[str]) -> list[list[float]]:
# Qwen embedding 使用不同接口
response = await self.client.embeddings.create(
model="text-embedding-v3",
input=texts,
)
return [item.embedding for item in response.data]

View File

View File

@ -0,0 +1,63 @@
import json
from datetime import datetime
from pydantic import BaseModel, Field
class EmployeeCreate(BaseModel):
tenant_id: str
name: str = Field(..., max_length=200)
role: str = Field(..., max_length=100)
system_prompt: str
greeting: str | None = None
avatar_url: str | None = Field(None, max_length=500)
temperature: float = 0.7
max_context_messages: int = 20
knowledge_base_ids: list[str] = []
class EmployeeUpdate(BaseModel):
name: str | None = Field(None, max_length=200)
role: str | None = Field(None, max_length=100)
system_prompt: str | None = None
greeting: str | None = None
avatar_url: str | None = Field(None, max_length=500)
temperature: float | None = None
max_context_messages: int | None = None
knowledge_base_ids: list[str] | None = None
status: str | None = None
class EmployeeResponse(BaseModel):
id: str
tenant_id: str
name: str
role: str
system_prompt: str
greeting: str | None
avatar_url: str | None
temperature: float
max_context_messages: int
knowledge_base_ids: list[str]
status: str
created_at: datetime
updated_at: datetime
@classmethod
def from_model(cls, employee) -> "EmployeeResponse":
kb_ids = json.loads(employee.knowledge_base_ids) if employee.knowledge_base_ids else []
return cls(
id=employee.id,
tenant_id=employee.tenant_id,
name=employee.name,
role=employee.role,
system_prompt=employee.system_prompt,
greeting=employee.greeting,
avatar_url=employee.avatar_url,
temperature=employee.temperature,
max_context_messages=employee.max_context_messages,
knowledge_base_ids=kb_ids,
status=employee.status.value,
created_at=employee.created_at,
updated_at=employee.updated_at,
)

View File

@ -0,0 +1,41 @@
from datetime import datetime
from pydantic import BaseModel, Field
class TenantCreate(BaseModel):
name: str = Field(..., max_length=200)
slug: str = Field(..., max_length=100)
class TenantUpdate(BaseModel):
name: str | None = Field(None, max_length=200)
slug: str | None = Field(None, max_length=100)
class TenantResponse(BaseModel):
id: str
name: str
slug: str
status: str
created_at: datetime
updated_at: datetime
class TenantConfigCreate(BaseModel):
llm_provider: str = Field(..., max_length=50)
llm_api_key: str
llm_model: str = Field(..., max_length=100)
llm_base_url: str | None = Field(None, max_length=500)
max_tokens_per_month: int | None = None
class TenantConfigResponse(BaseModel):
id: str
tenant_id: str
llm_provider: str
llm_model: str
llm_base_url: str | None
max_tokens_per_month: int
created_at: datetime
updated_at: datetime

View File

View File

@ -0,0 +1,33 @@
import base64
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from app.config import settings
def _get_encryption_key() -> bytes:
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=settings.encryption_salt.encode(),
iterations=480000,
)
return base64.urlsafe_b64encode(kdf.derive(settings.secret_key.encode()))
def encrypt_api_key(plaintext: str) -> str:
if not plaintext:
return ""
fernet = Fernet(_get_encryption_key())
encrypted = fernet.encrypt(plaintext.encode())
return base64.urlsafe_b64encode(encrypted).decode()
def decrypt_api_key(ciphertext: str) -> str:
if not ciphertext:
return ""
fernet = Fernet(_get_encryption_key())
encrypted = base64.urlsafe_b64decode(ciphertext.encode())
return fernet.decrypt(encrypted).decode()

View File

@ -0,0 +1,104 @@
import json
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import DigitalEmployee, EmployeeStatus
async def create_employee(
session: AsyncSession,
tenant_id: str,
name: str,
role: str,
system_prompt: str,
greeting: str | None = None,
avatar_url: str | None = None,
temperature: float = 0.7,
max_context_messages: int = 20,
knowledge_base_ids: list[str] | None = None,
) -> DigitalEmployee:
employee = DigitalEmployee(
tenant_id=tenant_id,
name=name,
role=role,
system_prompt=system_prompt,
greeting=greeting,
avatar_url=avatar_url,
temperature=temperature,
max_context_messages=max_context_messages,
knowledge_base_ids=json.dumps(knowledge_base_ids or []),
)
session.add(employee)
await session.commit()
await session.refresh(employee)
return employee
async def get_employee(session: AsyncSession, employee_id: str) -> DigitalEmployee:
result = await session.execute(
select(DigitalEmployee).where(
DigitalEmployee.id == employee_id,
DigitalEmployee.status != EmployeeStatus.inactive,
)
)
employee = result.scalar_one_or_none()
if not employee:
raise HTTPException(status_code=404, detail="Employee not found")
return employee
async def list_employees(session: AsyncSession, tenant_id: str) -> list[DigitalEmployee]:
result = await session.execute(
select(DigitalEmployee).where(
DigitalEmployee.tenant_id == tenant_id,
DigitalEmployee.status != EmployeeStatus.inactive,
)
)
return list(result.scalars().all())
async def update_employee(
session: AsyncSession,
employee_id: str,
name: str | None = None,
role: str | None = None,
system_prompt: str | None = None,
greeting: str | None = None,
avatar_url: str | None = None,
temperature: float | None = None,
max_context_messages: int | None = None,
knowledge_base_ids: list[str] | None = None,
status: str | None = None,
) -> DigitalEmployee:
employee = await get_employee(session, employee_id)
if name is not None:
employee.name = name
if role is not None:
employee.role = role
if system_prompt is not None:
employee.system_prompt = system_prompt
if greeting is not None:
employee.greeting = greeting
if avatar_url is not None:
employee.avatar_url = avatar_url
if temperature is not None:
employee.temperature = temperature
if max_context_messages is not None:
employee.max_context_messages = max_context_messages
if knowledge_base_ids is not None:
employee.knowledge_base_ids = json.dumps(knowledge_base_ids)
if status is not None:
employee.status = EmployeeStatus(status)
await session.commit()
await session.refresh(employee)
return employee
async def delete_employee(session: AsyncSession, employee_id: str) -> None:
employee = await get_employee(session, employee_id)
employee.status = EmployeeStatus.inactive
await session.commit()

View File

@ -0,0 +1,125 @@
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Tenant, TenantConfig, TenantStatus
from app.providers import get_provider
from app.providers.base import BaseLLMProvider
from app.services.crypto import decrypt_api_key, encrypt_api_key
async def get_provider_for_tenant(
session: AsyncSession, tenant_id: str
) -> BaseLLMProvider:
"""根据租户配置动态实例化 Provider"""
await get_tenant(session, tenant_id)
result = await session.execute(
select(TenantConfig).where(TenantConfig.tenant_id == tenant_id)
)
config = result.scalar_one_or_none()
if not config:
raise HTTPException(status_code=404, detail="Tenant config not found")
api_key = decrypt_api_key(config.llm_api_key)
return get_provider(
provider_type=config.llm_provider,
api_key=api_key,
model=config.llm_model,
base_url=config.llm_base_url,
)
async def create_tenant(session: AsyncSession, name: str, slug: str) -> Tenant:
existing = await session.execute(select(Tenant).where(Tenant.slug == slug))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Slug already exists")
tenant = Tenant(name=name, slug=slug)
session.add(tenant)
await session.commit()
await session.refresh(tenant)
return tenant
async def get_tenant(session: AsyncSession, tenant_id: str) -> Tenant:
result = await session.execute(
select(Tenant).where(Tenant.id == tenant_id, Tenant.status != TenantStatus.deleted)
)
tenant = result.scalar_one_or_none()
if not tenant:
raise HTTPException(status_code=404, detail="Tenant not found")
return tenant
async def list_tenants(session: AsyncSession) -> list[Tenant]:
result = await session.execute(
select(Tenant).where(Tenant.status != TenantStatus.deleted)
)
return list(result.scalars().all())
async def update_tenant(
session: AsyncSession, tenant_id: str, name: str | None, slug: str | None
) -> Tenant:
tenant = await get_tenant(session, tenant_id)
if slug and slug != tenant.slug:
existing = await session.execute(select(Tenant).where(Tenant.slug == slug))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Slug already exists")
tenant.slug = slug
if name:
tenant.name = name
await session.commit()
await session.refresh(tenant)
return tenant
async def delete_tenant(session: AsyncSession, tenant_id: str) -> None:
tenant = await get_tenant(session, tenant_id)
tenant.status = TenantStatus.deleted
await session.commit()
async def update_tenant_config(
session: AsyncSession,
tenant_id: str,
llm_provider: str,
llm_api_key: str,
llm_model: str,
llm_base_url: str | None = None,
max_tokens_per_month: int = 1000000,
) -> TenantConfig:
await get_tenant(session, tenant_id)
result = await session.execute(
select(TenantConfig).where(TenantConfig.tenant_id == tenant_id)
)
config = result.scalar_one_or_none()
encrypted_key = encrypt_api_key(llm_api_key)
if config:
config.llm_provider = llm_provider
config.llm_api_key = encrypted_key
config.llm_model = llm_model
config.llm_base_url = llm_base_url
config.max_tokens_per_month = max_tokens_per_month
else:
config = TenantConfig(
tenant_id=tenant_id,
llm_provider=llm_provider,
llm_api_key=encrypted_key,
llm_model=llm_model,
llm_base_url=llm_base_url,
max_tokens_per_month=max_tokens_per_month,
)
session.add(config)
await session.commit()
await session.refresh(config)
return config

44
backend/pyproject.toml Normal file
View File

@ -0,0 +1,44 @@
[project]
name = "digital-employee-platform"
version = "0.1.0"
description = "Multi-tenant digital employee platform with RAG knowledge base"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.30.0",
"sqlalchemy[asyncio]>=2.0.0",
"aiosqlite>=0.20.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"httpx>=0.27.0",
"cryptography>=43.0.0",
"openai>=1.40.0",
"chromadb>=0.5.0",
"pypdf2>=3.0.0",
"python-docx>=1.1.0",
"python-multipart>=0.0.9",
"sse-starlette>=2.0.0",
]
[dependency-groups]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"ruff>=0.6.0",
"mypy>=1.11.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.ruff]
target-version = "py311"
line-length = 100
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
[tool.mypy]
python_version = "3.11"
strict = true

View File

46
backend/tests/conftest.py Normal file
View File

@ -0,0 +1,46 @@
import asyncio
from typing import AsyncGenerator
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.api.deps import get_db
from app.database import Base
from app.main import app
TEST_DB_URL = "sqlite+aiosqlite:///:memory:"
test_engine = create_async_engine(TEST_DB_URL, echo=False)
TestSession = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="function")
async def setup_db():
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
async with TestSession() as session:
yield session
@pytest_asyncio.fixture
async def client(setup_db) -> AsyncGenerator[AsyncClient, None]:
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
app.dependency_overrides.clear()

View File

@ -0,0 +1,188 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_create_employee(client: AsyncClient):
"""创建数字员工"""
# 先创建租户
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "Employee Test", "slug": "employee-test"},
)
tenant_id = tenant_resp.json()["id"]
response = await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "客服助手",
"role": "customer_service",
"system_prompt": "你是一个专业的客服助手。",
"greeting": "你好,有什么可以帮助您的?",
"temperature": 0.7,
"max_context_messages": 20,
"knowledge_base_ids": [],
},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "客服助手"
assert data["role"] == "customer_service"
assert data["status"] == "active"
assert "id" in data
@pytest.mark.asyncio
async def test_create_employee_with_knowledge_bases(client: AsyncClient):
"""创建数字员工时关联知识库"""
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "KB Test", "slug": "kb-test"},
)
tenant_id = tenant_resp.json()["id"]
# 创建知识库
from app.models import KnowledgeBase
from tests.conftest import TestSession
async with TestSession() as session:
kb1 = KnowledgeBase(tenant_id=tenant_id, name="产品文档")
kb2 = KnowledgeBase(tenant_id=tenant_id, name="FAQ")
session.add_all([kb1, kb2])
await session.commit()
kb1_id = kb1.id
kb2_id = kb2.id
response = await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "智能客服",
"role": "customer_service",
"system_prompt": "你是一个智能客服。",
"knowledge_base_ids": [kb1_id, kb2_id],
},
)
assert response.status_code == 201
data = response.json()
assert len(data["knowledge_base_ids"]) == 2
@pytest.mark.asyncio
async def test_list_employees(client: AsyncClient):
"""列出数字员工"""
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "List Test", "slug": "list-test"},
)
tenant_id = tenant_resp.json()["id"]
await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "员工A",
"role": "sales",
"system_prompt": "销售助手",
},
)
await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "员工B",
"role": "support",
"system_prompt": "支持助手",
},
)
response = await client.get(f"/api/v1/employees?tenant_id={tenant_id}")
assert response.status_code == 200
data = response.json()
assert len(data) >= 2
@pytest.mark.asyncio
async def test_get_employee(client: AsyncClient):
"""获取单个数字员工"""
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "Get Emp Test", "slug": "get-emp-test"},
)
tenant_id = tenant_resp.json()["id"]
create_resp = await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "测试员工",
"role": "test",
"system_prompt": "测试",
},
)
employee_id = create_resp.json()["id"]
response = await client.get(f"/api/v1/employees/{employee_id}")
assert response.status_code == 200
assert response.json()["name"] == "测试员工"
@pytest.mark.asyncio
async def test_update_employee(client: AsyncClient):
"""更新数字员工"""
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "Update Emp Test", "slug": "update-emp-test"},
)
tenant_id = tenant_resp.json()["id"]
create_resp = await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "旧名称",
"role": "old_role",
"system_prompt": "旧提示词",
},
)
employee_id = create_resp.json()["id"]
response = await client.patch(
f"/api/v1/employees/{employee_id}",
json={
"name": "新名称",
"temperature": 0.5,
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "新名称"
assert data["temperature"] == 0.5
@pytest.mark.asyncio
async def test_delete_employee(client: AsyncClient):
"""删除数字员工"""
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "Delete Emp Test", "slug": "delete-emp-test"},
)
tenant_id = tenant_resp.json()["id"]
create_resp = await client.post(
"/api/v1/employees",
json={
"tenant_id": tenant_id,
"name": "待删除",
"role": "temp",
"system_prompt": "临时",
},
)
employee_id = create_resp.json()["id"]
response = await client.delete(f"/api/v1/employees/{employee_id}")
assert response.status_code == 204
get_resp = await client.get(f"/api/v1/employees/{employee_id}")
assert get_resp.status_code == 404

View File

@ -0,0 +1,9 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_health_check(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}

View File

@ -14,7 +14,7 @@ from tests.conftest import test_engine
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_tables_created(): async def test_all_tables_created(setup_db):
expected_tables = { expected_tables = {
"tenants", "tenants",
"tenant_configs", "tenant_configs",

View File

@ -0,0 +1,196 @@
"""LLM Provider 抽象层测试"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.providers.base import LLMMessage, LLMResponse
@pytest.mark.asyncio
async def test_base_provider_interface():
"""验证 BaseLLMProvider 接口定义"""
from app.providers.base import BaseLLMProvider
# 接口应该定义 chat, chat_stream, embed 方法
assert hasattr(BaseLLMProvider, "chat")
assert hasattr(BaseLLMProvider, "chat_stream")
assert hasattr(BaseLLMProvider, "embed")
@pytest.mark.asyncio
async def test_llm_message_dataclass():
"""验证 LLMMessage 数据类"""
msg = LLMMessage(role="user", content="Hello")
assert msg.role == "user"
assert msg.content == "Hello"
@pytest.mark.asyncio
async def test_llm_response_dataclass():
"""验证 LLMResponse 数据类"""
resp = LLMResponse(content="Hi there!", model="gpt-4", usage={"total_tokens": 10})
assert resp.content == "Hi there!"
assert resp.model == "gpt-4"
assert resp.usage["total_tokens"] == 10
@pytest.mark.asyncio
async def test_openai_provider_chat():
"""测试 OpenAI Provider chat 方法mock"""
from app.providers.openai_provider import OpenAIProvider
provider = OpenAIProvider(api_key="sk-test", model="gpt-4")
with patch.object(provider, "client") as mock_client:
mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
mock_response.usage = MagicMock(total_tokens=100, prompt_tokens=50, completion_tokens=50)
mock_response.model = "gpt-4"
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
result = await provider.chat([LLMMessage(role="user", content="Hi")])
assert result.content == "Test response"
assert result.model == "gpt-4"
@pytest.mark.asyncio
async def test_openai_provider_chat_stream():
"""测试 OpenAI Provider chat_stream 方法mock"""
from app.providers.openai_provider import OpenAIProvider
provider = OpenAIProvider(api_key="sk-test", model="gpt-4")
async def mock_stream():
class MockChunk:
choices = [MagicMock(delta=MagicMock(content="token"))]
yield MockChunk()
yield MockChunk()
with patch.object(provider, "client") as mock_client:
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream())
tokens = []
async for chunk in provider.chat_stream([LLMMessage(role="user", content="Hi")]):
tokens.append(chunk)
assert len(tokens) >= 1
@pytest.mark.asyncio
async def test_openai_provider_embed():
"""测试 OpenAI Provider embed 方法mock"""
from app.providers.openai_provider import OpenAIProvider
provider = OpenAIProvider(api_key="sk-test", model="text-embedding-3-small")
with patch.object(provider, "client") as mock_client:
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
mock_response.usage = MagicMock(total_tokens=10)
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
result = await provider.embed(["Hello world"])
assert len(result) == 1
assert len(result[0]) == 3
@pytest.mark.asyncio
async def test_qwen_provider_chat():
"""测试 Qwen Provider chat 方法mock兼容 OpenAI SDK"""
from app.providers.qwen_provider import QwenProvider
provider = QwenProvider(
api_key="sk-qwen-test",
model="qwen-turbo",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
with patch.object(provider, "client") as mock_client:
mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="Qwen response"))]
mock_response.usage = MagicMock(total_tokens=80)
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
result = await provider.chat([LLMMessage(role="user", content="你好")])
assert result.content == "Qwen response"
@pytest.mark.asyncio
async def test_provider_registry():
"""测试 Provider 注册表"""
from app.providers import get_provider
# 根据 provider 类型返回正确的实例
openai_provider = get_provider("openai", api_key="sk-test", model="gpt-4")
assert openai_provider.__class__.__name__ == "OpenAIProvider"
qwen_provider = get_provider("qwen", api_key="sk-qwen", model="qwen-turbo")
assert qwen_provider.__class__.__name__ == "QwenProvider"
@pytest.mark.asyncio
async def test_provider_registry_invalid():
"""测试 Provider 注册表无效类型"""
from app.providers import get_provider
with pytest.raises(ValueError, match="Unknown provider"):
get_provider("invalid_provider", api_key="sk-test", model="model")
# ============= Task 2.3: Provider 集成到租户配置 =============
@pytest.mark.asyncio
async def test_get_provider_for_tenant(client: AsyncClient):
"""测试根据租户配置获取 Provider"""
from app.services.tenant_service import get_provider_for_tenant
from tests.conftest import TestSession
# 创建租户并配置
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "Provider Test", "slug": "provider-test"},
)
tenant_id = tenant_resp.json()["id"]
await client.put(
f"/api/v1/tenants/{tenant_id}/config",
json={
"llm_provider": "openai",
"llm_api_key": "sk-test-provider-key",
"llm_model": "gpt-4o",
},
)
# 获取 Provider
async with TestSession() as session:
provider = await get_provider_for_tenant(session, tenant_id)
assert provider.__class__.__name__ == "OpenAIProvider"
assert provider.model == "gpt-4o"
@pytest.mark.asyncio
async def test_test_provider_endpoint(client: AsyncClient):
"""测试 Provider 测试端点"""
# 创建租户并配置
tenant_resp = await client.post(
"/api/v1/tenants",
json={"name": "Provider Endpoint Test", "slug": "provider-endpoint-test"},
)
tenant_id = tenant_resp.json()["id"]
await client.put(
f"/api/v1/tenants/{tenant_id}/config",
json={
"llm_provider": "openai",
"llm_api_key": "sk-test-endpoint",
"llm_model": "gpt-4",
},
)
# 测试 Provider
response = await client.post(f"/api/v1/tenants/{tenant_id}/test-provider")
assert response.status_code == 200
data = response.json()
assert data["provider"] == "openai"
assert data["model"] == "gpt-4"
assert "status" in data

View File

@ -0,0 +1,127 @@
import pytest
from httpx import AsyncClient
from app.models import Tenant
@pytest.mark.asyncio
async def test_create_tenant(client: AsyncClient):
response = await client.post(
"/api/v1/tenants",
json={"name": "Test Company", "slug": "test-company"},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "Test Company"
assert data["slug"] == "test-company"
assert "id" in data
@pytest.mark.asyncio
async def test_list_tenants(client: AsyncClient):
await client.post("/api/v1/tenants", json={"name": "A", "slug": "a"})
await client.post("/api/v1/tenants", json={"name": "B", "slug": "b"})
response = await client.get("/api/v1/tenants")
assert response.status_code == 200
data = response.json()
assert len(data) >= 2
@pytest.mark.asyncio
async def test_get_tenant(client: AsyncClient):
create_resp = await client.post(
"/api/v1/tenants",
json={"name": "Get Test", "slug": "get-test"},
)
tenant_id = create_resp.json()["id"]
response = await client.get(f"/api/v1/tenants/{tenant_id}")
assert response.status_code == 200
assert response.json()["name"] == "Get Test"
@pytest.mark.asyncio
async def test_update_tenant(client: AsyncClient):
create_resp = await client.post(
"/api/v1/tenants",
json={"name": "Old Name", "slug": "old-slug"},
)
tenant_id = create_resp.json()["id"]
response = await client.patch(
f"/api/v1/tenants/{tenant_id}",
json={"name": "New Name"},
)
assert response.status_code == 200
assert response.json()["name"] == "New Name"
@pytest.mark.asyncio
async def test_delete_tenant(client: AsyncClient):
create_resp = await client.post(
"/api/v1/tenants",
json={"name": "To Delete", "slug": "to-delete"},
)
tenant_id = create_resp.json()["id"]
response = await client.delete(f"/api/v1/tenants/{tenant_id}")
assert response.status_code == 204
get_resp = await client.get(f"/api/v1/tenants/{tenant_id}")
assert get_resp.status_code == 404
@pytest.mark.asyncio
async def test_update_tenant_config(client: AsyncClient):
create_resp = await client.post(
"/api/v1/tenants",
json={"name": "Config Test", "slug": "config-test"},
)
tenant_id = create_resp.json()["id"]
response = await client.put(
f"/api/v1/tenants/{tenant_id}/config",
json={
"llm_provider": "openai",
"llm_api_key": "sk-test-key-12345",
"llm_model": "gpt-4",
"llm_base_url": None,
},
)
assert response.status_code == 200
data = response.json()
assert data["llm_provider"] == "openai"
assert data["llm_model"] == "gpt-4"
assert "llm_api_key" not in data
@pytest.mark.asyncio
async def test_api_key_encrypted(client: AsyncClient):
from tests.conftest import TestSession
create_resp = await client.post(
"/api/v1/tenants",
json={"name": "Encrypt Test", "slug": "encrypt-test"},
)
tenant_id = create_resp.json()["id"]
await client.put(
f"/api/v1/tenants/{tenant_id}/config",
json={
"llm_provider": "openai",
"llm_api_key": "sk-secret-key-99999",
"llm_model": "gpt-4",
},
)
async with TestSession() as session:
from sqlalchemy import select
from app.models import TenantConfig
result = await session.execute(
select(TenantConfig).where(TenantConfig.tenant_id == tenant_id)
)
config = result.scalar_one()
assert config.llm_api_key != "sk-secret-key-99999"
assert len(config.llm_api_key) > 20

2735
backend/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

17
memory/2026-05-06.md Normal file
View File

@ -0,0 +1,17 @@
# 2026-05-06
## 纪要
- 工作区已初始化,可在此持续记录当天上下文。
## 进展
### B1 批次完成:项目脚手架 + 数据模型 + 租户管理
- ✅ Task 1.1 项目脚手架
- ✅ Task 1.2 数据模型7个模型
- ✅ Task 1.3 租户管理 API7个测试全部通过
- 实现schemas/tenant.py, api/deps.py, services/tenant_service.py, api/v1/tenants.py
- 修复conftest.py fixture 顺序和 get_db 依赖覆盖
### 下一步
- B2 批次:数字员工配置 + LLM Provider 抽象层