feat(backend): 数字员工平台 B1+B2 批次实现
B1: 项目脚手架 + 数据模型 + 租户管理 - Task 1.1: FastAPI 项目脚手架、SQLite + async SQLAlchemy - Task 1.2: 7 个数据模型 (Tenant, TenantConfig, DigitalEmployee, Conversation, Message, KnowledgeBase, Document) - Task 1.3: 租户 CRUD API + LLM 配置(含 API Key AES 加密) B2: 数字员工配置 + LLM Provider 抽象层 - Task 2.1: 数字员工 CRUD API(关联知识库) - Task 2.2: BaseLLMProvider 抽象接口 + OpenAI/Qwen Provider - Task 2.3: Provider 动态实例化 + test-provider 端点 验证: 26 个测试全部通过 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
44c37420af
commit
c62156af53
16
backend/.env.example
Normal file
16
backend/.env.example
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# LLM Provider
|
||||||
|
OPENAI_API_KEY=sk-your-openai-key
|
||||||
|
QWEN_API_KEY=sk-your-qwen-key
|
||||||
|
|
||||||
|
# Encryption key for storing tenant LLM API keys (generate: python -c "import secrets; print(secrets.token_hex(32))")
|
||||||
|
ENCRYPTION_KEY=your-32-byte-hex-key
|
||||||
|
|
||||||
|
# Database
|
||||||
|
DATABASE_URL=sqlite+aiosqlite:///./data.db
|
||||||
|
|
||||||
|
# Upload
|
||||||
|
UPLOAD_DIR=./uploads
|
||||||
|
MAX_FILE_SIZE_MB=20
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
RATE_LIMIT_PER_MINUTE=60
|
||||||
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
11
backend/app/api/deps.py
Normal file
11
backend/app/api/deps.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.database import async_session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
7
backend/app/api/v1/__init__.py
Normal file
7
backend/app/api/v1/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.api.v1 import employees, tenants
|
||||||
|
|
||||||
|
api_router = APIRouter(prefix="/api/v1")
|
||||||
|
api_router.include_router(tenants.router)
|
||||||
|
api_router.include_router(employees.router)
|
||||||
76
backend/app/api/v1/employees.py
Normal file
76
backend/app/api/v1/employees.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.schemas.employee import (
|
||||||
|
EmployeeCreate,
|
||||||
|
EmployeeResponse,
|
||||||
|
EmployeeUpdate,
|
||||||
|
)
|
||||||
|
from app.services import employee_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/employees", tags=["employees"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=EmployeeResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_employee(
|
||||||
|
data: EmployeeCreate, session: AsyncSession = Depends(get_db)
|
||||||
|
) -> EmployeeResponse:
|
||||||
|
employee = await employee_service.create_employee(
|
||||||
|
session,
|
||||||
|
tenant_id=data.tenant_id,
|
||||||
|
name=data.name,
|
||||||
|
role=data.role,
|
||||||
|
system_prompt=data.system_prompt,
|
||||||
|
greeting=data.greeting,
|
||||||
|
avatar_url=data.avatar_url,
|
||||||
|
temperature=data.temperature,
|
||||||
|
max_context_messages=data.max_context_messages,
|
||||||
|
knowledge_base_ids=data.knowledge_base_ids,
|
||||||
|
)
|
||||||
|
return EmployeeResponse.from_model(employee)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[EmployeeResponse])
|
||||||
|
async def list_employees(
|
||||||
|
tenant_id: str = Query(...), session: AsyncSession = Depends(get_db)
|
||||||
|
) -> list[EmployeeResponse]:
|
||||||
|
employees = await employee_service.list_employees(session, tenant_id)
|
||||||
|
return [EmployeeResponse.from_model(e) for e in employees]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{employee_id}", response_model=EmployeeResponse)
|
||||||
|
async def get_employee(
|
||||||
|
employee_id: str, session: AsyncSession = Depends(get_db)
|
||||||
|
) -> EmployeeResponse:
|
||||||
|
employee = await employee_service.get_employee(session, employee_id)
|
||||||
|
return EmployeeResponse.from_model(employee)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{employee_id}", response_model=EmployeeResponse)
|
||||||
|
async def update_employee(
|
||||||
|
employee_id: str,
|
||||||
|
data: EmployeeUpdate,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
) -> EmployeeResponse:
|
||||||
|
employee = await employee_service.update_employee(
|
||||||
|
session,
|
||||||
|
employee_id,
|
||||||
|
name=data.name,
|
||||||
|
role=data.role,
|
||||||
|
system_prompt=data.system_prompt,
|
||||||
|
greeting=data.greeting,
|
||||||
|
avatar_url=data.avatar_url,
|
||||||
|
temperature=data.temperature,
|
||||||
|
max_context_messages=data.max_context_messages,
|
||||||
|
knowledge_base_ids=data.knowledge_base_ids,
|
||||||
|
status=data.status,
|
||||||
|
)
|
||||||
|
return EmployeeResponse.from_model(employee)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{employee_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_employee(
|
||||||
|
employee_id: str, session: AsyncSession = Depends(get_db)
|
||||||
|
) -> None:
|
||||||
|
await employee_service.delete_employee(session, employee_id)
|
||||||
128
backend/app/api/v1/tenants.py
Normal file
128
backend/app/api/v1/tenants.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.schemas.tenant import (
|
||||||
|
TenantConfigCreate,
|
||||||
|
TenantConfigResponse,
|
||||||
|
TenantCreate,
|
||||||
|
TenantResponse,
|
||||||
|
TenantUpdate,
|
||||||
|
)
|
||||||
|
from app.services import tenant_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tenants", tags=["tenants"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_tenant(
|
||||||
|
data: TenantCreate, session: AsyncSession = Depends(get_db)
|
||||||
|
) -> TenantResponse:
|
||||||
|
tenant = await tenant_service.create_tenant(session, data.name, data.slug)
|
||||||
|
return TenantResponse(
|
||||||
|
id=tenant.id,
|
||||||
|
name=tenant.name,
|
||||||
|
slug=tenant.slug,
|
||||||
|
status=tenant.status.value,
|
||||||
|
created_at=tenant.created_at,
|
||||||
|
updated_at=tenant.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[TenantResponse])
|
||||||
|
async def list_tenants(session: AsyncSession = Depends(get_db)) -> list[TenantResponse]:
|
||||||
|
tenants = await tenant_service.list_tenants(session)
|
||||||
|
return [
|
||||||
|
TenantResponse(
|
||||||
|
id=t.id,
|
||||||
|
name=t.name,
|
||||||
|
slug=t.slug,
|
||||||
|
status=t.status.value,
|
||||||
|
created_at=t.created_at,
|
||||||
|
updated_at=t.updated_at,
|
||||||
|
)
|
||||||
|
for t in tenants
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{tenant_id}", response_model=TenantResponse)
|
||||||
|
async def get_tenant(
|
||||||
|
tenant_id: str, session: AsyncSession = Depends(get_db)
|
||||||
|
) -> TenantResponse:
|
||||||
|
tenant = await tenant_service.get_tenant(session, tenant_id)
|
||||||
|
return TenantResponse(
|
||||||
|
id=tenant.id,
|
||||||
|
name=tenant.name,
|
||||||
|
slug=tenant.slug,
|
||||||
|
status=tenant.status.value,
|
||||||
|
created_at=tenant.created_at,
|
||||||
|
updated_at=tenant.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{tenant_id}", response_model=TenantResponse)
|
||||||
|
async def update_tenant(
|
||||||
|
tenant_id: str,
|
||||||
|
data: TenantUpdate,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
) -> TenantResponse:
|
||||||
|
tenant = await tenant_service.update_tenant(session, tenant_id, data.name, data.slug)
|
||||||
|
return TenantResponse(
|
||||||
|
id=tenant.id,
|
||||||
|
name=tenant.name,
|
||||||
|
slug=tenant.slug,
|
||||||
|
status=tenant.status.value,
|
||||||
|
created_at=tenant.created_at,
|
||||||
|
updated_at=tenant.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{tenant_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_tenant(tenant_id: str, session: AsyncSession = Depends(get_db)) -> None:
|
||||||
|
await tenant_service.delete_tenant(session, tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{tenant_id}/config", response_model=TenantConfigResponse)
|
||||||
|
async def update_tenant_config(
|
||||||
|
tenant_id: str,
|
||||||
|
data: TenantConfigCreate,
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
) -> TenantConfigResponse:
|
||||||
|
config = await tenant_service.update_tenant_config(
|
||||||
|
session,
|
||||||
|
tenant_id,
|
||||||
|
data.llm_provider,
|
||||||
|
data.llm_api_key,
|
||||||
|
data.llm_model,
|
||||||
|
data.llm_base_url,
|
||||||
|
data.max_tokens_per_month or 1000000,
|
||||||
|
)
|
||||||
|
return TenantConfigResponse(
|
||||||
|
id=config.id,
|
||||||
|
tenant_id=config.tenant_id,
|
||||||
|
llm_provider=config.llm_provider,
|
||||||
|
llm_model=config.llm_model,
|
||||||
|
llm_base_url=config.llm_base_url,
|
||||||
|
max_tokens_per_month=config.max_tokens_per_month,
|
||||||
|
created_at=config.created_at,
|
||||||
|
updated_at=config.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderTestResponse(BaseModel):
|
||||||
|
provider: str
|
||||||
|
model: str
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{tenant_id}/test-provider", response_model=ProviderTestResponse)
|
||||||
|
async def test_provider(
|
||||||
|
tenant_id: str, session: AsyncSession = Depends(get_db)
|
||||||
|
) -> ProviderTestResponse:
|
||||||
|
provider = await tenant_service.get_provider_for_tenant(session, tenant_id)
|
||||||
|
return ProviderTestResponse(
|
||||||
|
provider=provider.__class__.__name__.replace("Provider", "").lower(),
|
||||||
|
model=provider.model,
|
||||||
|
status="ok",
|
||||||
|
)
|
||||||
15
backend/app/config.py
Normal file
15
backend/app/config.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
database_url: str = "sqlite+aiosqlite:///./data.db"
|
||||||
|
secret_key: str = "dev-secret-key-change-in-production"
|
||||||
|
encryption_salt: str = "dev-salt-change-in-production"
|
||||||
|
upload_dir: str = "./uploads"
|
||||||
|
max_file_size_mb: int = 20
|
||||||
|
rate_limit_per_minute: int = 60
|
||||||
|
|
||||||
|
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
21
backend/app/database.py
Normal file
21
backend/app/database.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(settings.database_url, echo=False)
|
||||||
|
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncSession: # type: ignore[misc]
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db() -> None:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
25
backend/app/main.py
Normal file
25
backend/app/main.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.api.v1 import api_router
|
||||||
|
from app.database import init_db
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
await init_db()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Digital Employee Platform",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
app.include_router(api_router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return {"status": "ok"}
|
||||||
0
backend/app/middleware/__init__.py
Normal file
0
backend/app/middleware/__init__.py
Normal file
18
backend/app/models/__init__.py
Normal file
18
backend/app/models/__init__.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from app.models.conversation import Conversation, Message, MessageRole
|
||||||
|
from app.models.employee import DigitalEmployee, EmployeeStatus
|
||||||
|
from app.models.knowledge import Document, DocumentStatus, KnowledgeBase
|
||||||
|
from app.models.tenant import Tenant, TenantConfig, TenantStatus
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Tenant",
|
||||||
|
"TenantConfig",
|
||||||
|
"TenantStatus",
|
||||||
|
"DigitalEmployee",
|
||||||
|
"EmployeeStatus",
|
||||||
|
"Conversation",
|
||||||
|
"Message",
|
||||||
|
"MessageRole",
|
||||||
|
"KnowledgeBase",
|
||||||
|
"Document",
|
||||||
|
"DocumentStatus",
|
||||||
|
]
|
||||||
58
backend/app/models/conversation.py
Normal file
58
backend/app/models/conversation.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, enum.Enum):
|
||||||
|
system = "system"
|
||||||
|
user = "user"
|
||||||
|
assistant = "assistant"
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(Base):
|
||||||
|
__tablename__ = "conversations"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("tenants.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
employee_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("digital_employees.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
|
title: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
|
||||||
|
|
||||||
|
messages: Mapped[list["Message"]] = relationship(
|
||||||
|
back_populates="conversation", cascade="all, delete-orphan", lazy="selectin"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Message(Base):
|
||||||
|
__tablename__ = "messages"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
conversation_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("conversations.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
role: Mapped[MessageRole] = mapped_column(Enum(MessageRole), nullable=False)
|
||||||
|
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
sources: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
|
||||||
|
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
|
||||||
45
backend/app/models/employee.py
Normal file
45
backend/app/models/employee.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class EmployeeStatus(str, enum.Enum):
|
||||||
|
active = "active"
|
||||||
|
inactive = "inactive"
|
||||||
|
|
||||||
|
|
||||||
|
class DigitalEmployee(Base):
|
||||||
|
__tablename__ = "digital_employees"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("tenants.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
|
role: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
system_prompt: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
greeting: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
temperature: Mapped[float] = mapped_column(Float, default=0.7)
|
||||||
|
max_context_messages: Mapped[int] = mapped_column(Integer, default=20)
|
||||||
|
knowledge_base_ids: Mapped[str] = mapped_column(Text, default="[]")
|
||||||
|
status: Mapped[EmployeeStatus] = mapped_column(
|
||||||
|
Enum(EmployeeStatus), default=EmployeeStatus.active
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
|
||||||
|
|
||||||
|
tenant = relationship("Tenant", lazy="selectin")
|
||||||
63
backend/app/models/knowledge.py
Normal file
63
backend/app/models/knowledge.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentStatus(str, enum.Enum):
|
||||||
|
pending = "pending"
|
||||||
|
processing = "processing"
|
||||||
|
completed = "completed"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBase(Base):
|
||||||
|
__tablename__ = "knowledge_bases"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("tenants.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
embedding_model: Mapped[str] = mapped_column(String(100), default="text-embedding-3-small")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
|
||||||
|
|
||||||
|
documents: Mapped[list["Document"]] = relationship(
|
||||||
|
back_populates="knowledge_base", cascade="all, delete-orphan", lazy="selectin"
|
||||||
|
)
|
||||||
|
tenant = relationship("Tenant", lazy="selectin")
|
||||||
|
|
||||||
|
|
||||||
|
class Document(Base):
|
||||||
|
__tablename__ = "documents"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
knowledge_base_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("knowledge_bases.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
filename: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
file_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||||
|
file_size: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
chunk_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
status: Mapped[DocumentStatus] = mapped_column(
|
||||||
|
Enum(DocumentStatus), default=DocumentStatus.pending
|
||||||
|
)
|
||||||
|
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
|
||||||
|
|
||||||
|
knowledge_base: Mapped["KnowledgeBase"] = relationship(back_populates="documents")
|
||||||
55
backend/app/models/tenant.py
Normal file
55
backend/app/models/tenant.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class TenantStatus(str, enum.Enum):
|
||||||
|
active = "active"
|
||||||
|
suspended = "suspended"
|
||||||
|
deleted = "deleted"
|
||||||
|
|
||||||
|
|
||||||
|
class Tenant(Base):
|
||||||
|
__tablename__ = "tenants"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
|
slug: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||||
|
status: Mapped[TenantStatus] = mapped_column(
|
||||||
|
Enum(TenantStatus), default=TenantStatus.active
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
|
||||||
|
|
||||||
|
config: Mapped["TenantConfig | None"] = relationship(back_populates="tenant", uselist=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TenantConfig(Base):
|
||||||
|
__tablename__ = "tenant_configs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(
|
||||||
|
String(36), ForeignKey("tenants.id"), unique=True, nullable=False
|
||||||
|
)
|
||||||
|
llm_provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
llm_api_key: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
llm_model: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
llm_base_url: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
max_tokens_per_month: Mapped[int] = mapped_column(Integer, default=1000000)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_now, onupdate=_now)
|
||||||
|
|
||||||
|
tenant: Mapped["Tenant"] = relationship(back_populates="config")
|
||||||
23
backend/app/providers/__init__.py
Normal file
23
backend/app/providers/__init__.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""LLM Provider 注册表"""
|
||||||
|
from app.providers.base import BaseLLMProvider
|
||||||
|
from app.providers.openai_provider import OpenAIProvider
|
||||||
|
from app.providers.qwen_provider import QwenProvider
|
||||||
|
|
||||||
|
PROVIDERS = {
|
||||||
|
"openai": OpenAIProvider,
|
||||||
|
"qwen": QwenProvider,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider_type: str,
|
||||||
|
api_key: str,
|
||||||
|
model: str,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> BaseLLMProvider:
|
||||||
|
"""根据类型获取 Provider 实例"""
|
||||||
|
if provider_type not in PROVIDERS:
|
||||||
|
raise ValueError(f"Unknown provider: {provider_type}")
|
||||||
|
|
||||||
|
provider_cls = PROVIDERS[provider_type]
|
||||||
|
return provider_cls(api_key=api_key, model=model, base_url=base_url)
|
||||||
35
backend/app/providers/base.py
Normal file
35
backend/app/providers/base.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""LLM Provider 基础抽象层"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMMessage:
|
||||||
|
role: str # "system", "user", "assistant"
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
usage: dict # {"total_tokens": int, "prompt_tokens": int, "completion_tokens": int}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMProvider(ABC):
|
||||||
|
"""LLM Provider 抽象基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
|
||||||
|
"""非流式对话"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_stream(self, messages: list[LLMMessage]) -> str:
|
||||||
|
"""流式对话,返回 token 异步生成器"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
"""生成嵌入向量"""
|
||||||
|
pass
|
||||||
43
backend/app/providers/openai_provider.py
Normal file
43
backend/app/providers/openai_provider.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
"""OpenAI Provider"""
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from app.providers.base import BaseLLMProvider, LLMMessage, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseLLMProvider):
|
||||||
|
def __init__(self, api_key: str, model: str = "gpt-4", base_url: str | None = None):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.model = model
|
||||||
|
self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
|
async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||||
|
)
|
||||||
|
return LLMResponse(
|
||||||
|
content=response.choices[0].message.content or "",
|
||||||
|
model=response.model,
|
||||||
|
usage={
|
||||||
|
"total_tokens": response.usage.total_tokens,
|
||||||
|
"prompt_tokens": response.usage.prompt_tokens,
|
||||||
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(self, messages: list[LLMMessage]) -> str:
|
||||||
|
stream = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|
||||||
|
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
response = await self.client.embeddings.create(
|
||||||
|
model=self.model,
|
||||||
|
input=texts,
|
||||||
|
)
|
||||||
|
return [item.embedding for item in response.data]
|
||||||
51
backend/app/providers/qwen_provider.py
Normal file
51
backend/app/providers/qwen_provider.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Qwen Provider (兼容 OpenAI SDK)"""
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from app.providers.base import BaseLLMProvider, LLMMessage, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
class QwenProvider(BaseLLMProvider):
|
||||||
|
"""通义千问 Provider,使用 OpenAI SDK 兼容模式"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
model: str = "qwen-turbo",
|
||||||
|
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.model = model
|
||||||
|
self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
|
async def chat(self, messages: list[LLMMessage]) -> LLMResponse:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||||
|
)
|
||||||
|
return LLMResponse(
|
||||||
|
content=response.choices[0].message.content or "",
|
||||||
|
model=response.model,
|
||||||
|
usage={
|
||||||
|
"total_tokens": response.usage.total_tokens,
|
||||||
|
"prompt_tokens": response.usage.prompt_tokens,
|
||||||
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(self, messages: list[LLMMessage]) -> str:
|
||||||
|
stream = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[{"role": m.role, "content": m.content} for m in messages],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|
||||||
|
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
# Qwen embedding 使用不同接口
|
||||||
|
response = await self.client.embeddings.create(
|
||||||
|
model="text-embedding-v3",
|
||||||
|
input=texts,
|
||||||
|
)
|
||||||
|
return [item.embedding for item in response.data]
|
||||||
0
backend/app/schemas/__init__.py
Normal file
0
backend/app/schemas/__init__.py
Normal file
63
backend/app/schemas/employee.py
Normal file
63
backend/app/schemas/employee.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class EmployeeCreate(BaseModel):
|
||||||
|
tenant_id: str
|
||||||
|
name: str = Field(..., max_length=200)
|
||||||
|
role: str = Field(..., max_length=100)
|
||||||
|
system_prompt: str
|
||||||
|
greeting: str | None = None
|
||||||
|
avatar_url: str | None = Field(None, max_length=500)
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_context_messages: int = 20
|
||||||
|
knowledge_base_ids: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
class EmployeeUpdate(BaseModel):
|
||||||
|
name: str | None = Field(None, max_length=200)
|
||||||
|
role: str | None = Field(None, max_length=100)
|
||||||
|
system_prompt: str | None = None
|
||||||
|
greeting: str | None = None
|
||||||
|
avatar_url: str | None = Field(None, max_length=500)
|
||||||
|
temperature: float | None = None
|
||||||
|
max_context_messages: int | None = None
|
||||||
|
knowledge_base_ids: list[str] | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmployeeResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
tenant_id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
system_prompt: str
|
||||||
|
greeting: str | None
|
||||||
|
avatar_url: str | None
|
||||||
|
temperature: float
|
||||||
|
max_context_messages: int
|
||||||
|
knowledge_base_ids: list[str]
|
||||||
|
status: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_model(cls, employee) -> "EmployeeResponse":
|
||||||
|
kb_ids = json.loads(employee.knowledge_base_ids) if employee.knowledge_base_ids else []
|
||||||
|
return cls(
|
||||||
|
id=employee.id,
|
||||||
|
tenant_id=employee.tenant_id,
|
||||||
|
name=employee.name,
|
||||||
|
role=employee.role,
|
||||||
|
system_prompt=employee.system_prompt,
|
||||||
|
greeting=employee.greeting,
|
||||||
|
avatar_url=employee.avatar_url,
|
||||||
|
temperature=employee.temperature,
|
||||||
|
max_context_messages=employee.max_context_messages,
|
||||||
|
knowledge_base_ids=kb_ids,
|
||||||
|
status=employee.status.value,
|
||||||
|
created_at=employee.created_at,
|
||||||
|
updated_at=employee.updated_at,
|
||||||
|
)
|
||||||
41
backend/app/schemas/tenant.py
Normal file
41
backend/app/schemas/tenant.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TenantCreate(BaseModel):
|
||||||
|
name: str = Field(..., max_length=200)
|
||||||
|
slug: str = Field(..., max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class TenantUpdate(BaseModel):
|
||||||
|
name: str | None = Field(None, max_length=200)
|
||||||
|
slug: str | None = Field(None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class TenantResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
slug: str
|
||||||
|
status: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class TenantConfigCreate(BaseModel):
|
||||||
|
llm_provider: str = Field(..., max_length=50)
|
||||||
|
llm_api_key: str
|
||||||
|
llm_model: str = Field(..., max_length=100)
|
||||||
|
llm_base_url: str | None = Field(None, max_length=500)
|
||||||
|
max_tokens_per_month: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TenantConfigResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
tenant_id: str
|
||||||
|
llm_provider: str
|
||||||
|
llm_model: str
|
||||||
|
llm_base_url: str | None
|
||||||
|
max_tokens_per_month: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
33
backend/app/services/crypto.py
Normal file
33
backend/app/services/crypto.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _get_encryption_key() -> bytes:
|
||||||
|
kdf = PBKDF2HMAC(
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
length=32,
|
||||||
|
salt=settings.encryption_salt.encode(),
|
||||||
|
iterations=480000,
|
||||||
|
)
|
||||||
|
return base64.urlsafe_b64encode(kdf.derive(settings.secret_key.encode()))
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_api_key(plaintext: str) -> str:
|
||||||
|
if not plaintext:
|
||||||
|
return ""
|
||||||
|
fernet = Fernet(_get_encryption_key())
|
||||||
|
encrypted = fernet.encrypt(plaintext.encode())
|
||||||
|
return base64.urlsafe_b64encode(encrypted).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_api_key(ciphertext: str) -> str:
|
||||||
|
if not ciphertext:
|
||||||
|
return ""
|
||||||
|
fernet = Fernet(_get_encryption_key())
|
||||||
|
encrypted = base64.urlsafe_b64decode(ciphertext.encode())
|
||||||
|
return fernet.decrypt(encrypted).decode()
|
||||||
104
backend/app/services/employee_service.py
Normal file
104
backend/app/services/employee_service.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import DigitalEmployee, EmployeeStatus
|
||||||
|
|
||||||
|
|
||||||
|
async def create_employee(
|
||||||
|
session: AsyncSession,
|
||||||
|
tenant_id: str,
|
||||||
|
name: str,
|
||||||
|
role: str,
|
||||||
|
system_prompt: str,
|
||||||
|
greeting: str | None = None,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_context_messages: int = 20,
|
||||||
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
) -> DigitalEmployee:
|
||||||
|
employee = DigitalEmployee(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
name=name,
|
||||||
|
role=role,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
greeting=greeting,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
temperature=temperature,
|
||||||
|
max_context_messages=max_context_messages,
|
||||||
|
knowledge_base_ids=json.dumps(knowledge_base_ids or []),
|
||||||
|
)
|
||||||
|
session.add(employee)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(employee)
|
||||||
|
return employee
|
||||||
|
|
||||||
|
|
||||||
|
async def get_employee(session: AsyncSession, employee_id: str) -> DigitalEmployee:
|
||||||
|
result = await session.execute(
|
||||||
|
select(DigitalEmployee).where(
|
||||||
|
DigitalEmployee.id == employee_id,
|
||||||
|
DigitalEmployee.status != EmployeeStatus.inactive,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
employee = result.scalar_one_or_none()
|
||||||
|
if not employee:
|
||||||
|
raise HTTPException(status_code=404, detail="Employee not found")
|
||||||
|
return employee
|
||||||
|
|
||||||
|
|
||||||
|
async def list_employees(session: AsyncSession, tenant_id: str) -> list[DigitalEmployee]:
|
||||||
|
result = await session.execute(
|
||||||
|
select(DigitalEmployee).where(
|
||||||
|
DigitalEmployee.tenant_id == tenant_id,
|
||||||
|
DigitalEmployee.status != EmployeeStatus.inactive,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def update_employee(
|
||||||
|
session: AsyncSession,
|
||||||
|
employee_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
role: str | None = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
greeting: str | None = None,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
|
max_context_messages: int | None = None,
|
||||||
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> DigitalEmployee:
|
||||||
|
employee = await get_employee(session, employee_id)
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
employee.name = name
|
||||||
|
if role is not None:
|
||||||
|
employee.role = role
|
||||||
|
if system_prompt is not None:
|
||||||
|
employee.system_prompt = system_prompt
|
||||||
|
if greeting is not None:
|
||||||
|
employee.greeting = greeting
|
||||||
|
if avatar_url is not None:
|
||||||
|
employee.avatar_url = avatar_url
|
||||||
|
if temperature is not None:
|
||||||
|
employee.temperature = temperature
|
||||||
|
if max_context_messages is not None:
|
||||||
|
employee.max_context_messages = max_context_messages
|
||||||
|
if knowledge_base_ids is not None:
|
||||||
|
employee.knowledge_base_ids = json.dumps(knowledge_base_ids)
|
||||||
|
if status is not None:
|
||||||
|
employee.status = EmployeeStatus(status)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(employee)
|
||||||
|
return employee
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_employee(session: AsyncSession, employee_id: str) -> None:
|
||||||
|
employee = await get_employee(session, employee_id)
|
||||||
|
employee.status = EmployeeStatus.inactive
|
||||||
|
await session.commit()
|
||||||
125
backend/app/services/tenant_service.py
Normal file
125
backend/app/services/tenant_service.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import Tenant, TenantConfig, TenantStatus
|
||||||
|
from app.providers import get_provider
|
||||||
|
from app.providers.base import BaseLLMProvider
|
||||||
|
from app.services.crypto import decrypt_api_key, encrypt_api_key
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_for_tenant(
|
||||||
|
session: AsyncSession, tenant_id: str
|
||||||
|
) -> BaseLLMProvider:
|
||||||
|
"""根据租户配置动态实例化 Provider"""
|
||||||
|
await get_tenant(session, tenant_id)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(TenantConfig).where(TenantConfig.tenant_id == tenant_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise HTTPException(status_code=404, detail="Tenant config not found")
|
||||||
|
|
||||||
|
api_key = decrypt_api_key(config.llm_api_key)
|
||||||
|
return get_provider(
|
||||||
|
provider_type=config.llm_provider,
|
||||||
|
api_key=api_key,
|
||||||
|
model=config.llm_model,
|
||||||
|
base_url=config.llm_base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_tenant(session: AsyncSession, name: str, slug: str) -> Tenant:
|
||||||
|
existing = await session.execute(select(Tenant).where(Tenant.slug == slug))
|
||||||
|
if existing.scalar_one_or_none():
|
||||||
|
raise HTTPException(status_code=400, detail="Slug already exists")
|
||||||
|
|
||||||
|
tenant = Tenant(name=name, slug=slug)
|
||||||
|
session.add(tenant)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(tenant)
|
||||||
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tenant(session: AsyncSession, tenant_id: str) -> Tenant:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Tenant).where(Tenant.id == tenant_id, Tenant.status != TenantStatus.deleted)
|
||||||
|
)
|
||||||
|
tenant = result.scalar_one_or_none()
|
||||||
|
if not tenant:
|
||||||
|
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||||
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
|
async def list_tenants(session: AsyncSession) -> list[Tenant]:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Tenant).where(Tenant.status != TenantStatus.deleted)
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def update_tenant(
|
||||||
|
session: AsyncSession, tenant_id: str, name: str | None, slug: str | None
|
||||||
|
) -> Tenant:
|
||||||
|
tenant = await get_tenant(session, tenant_id)
|
||||||
|
|
||||||
|
if slug and slug != tenant.slug:
|
||||||
|
existing = await session.execute(select(Tenant).where(Tenant.slug == slug))
|
||||||
|
if existing.scalar_one_or_none():
|
||||||
|
raise HTTPException(status_code=400, detail="Slug already exists")
|
||||||
|
tenant.slug = slug
|
||||||
|
|
||||||
|
if name:
|
||||||
|
tenant.name = name
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(tenant)
|
||||||
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_tenant(session: AsyncSession, tenant_id: str) -> None:
|
||||||
|
tenant = await get_tenant(session, tenant_id)
|
||||||
|
tenant.status = TenantStatus.deleted
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def update_tenant_config(
|
||||||
|
session: AsyncSession,
|
||||||
|
tenant_id: str,
|
||||||
|
llm_provider: str,
|
||||||
|
llm_api_key: str,
|
||||||
|
llm_model: str,
|
||||||
|
llm_base_url: str | None = None,
|
||||||
|
max_tokens_per_month: int = 1000000,
|
||||||
|
) -> TenantConfig:
|
||||||
|
await get_tenant(session, tenant_id)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(TenantConfig).where(TenantConfig.tenant_id == tenant_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
encrypted_key = encrypt_api_key(llm_api_key)
|
||||||
|
|
||||||
|
if config:
|
||||||
|
config.llm_provider = llm_provider
|
||||||
|
config.llm_api_key = encrypted_key
|
||||||
|
config.llm_model = llm_model
|
||||||
|
config.llm_base_url = llm_base_url
|
||||||
|
config.max_tokens_per_month = max_tokens_per_month
|
||||||
|
else:
|
||||||
|
config = TenantConfig(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
llm_provider=llm_provider,
|
||||||
|
llm_api_key=encrypted_key,
|
||||||
|
llm_model=llm_model,
|
||||||
|
llm_base_url=llm_base_url,
|
||||||
|
max_tokens_per_month=max_tokens_per_month,
|
||||||
|
)
|
||||||
|
session.add(config)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(config)
|
||||||
|
return config
|
||||||
44
backend/pyproject.toml
Normal file
44
backend/pyproject.toml
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
[project]
|
||||||
|
name = "digital-employee-platform"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Multi-tenant digital employee platform with RAG knowledge base"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"uvicorn[standard]>=0.30.0",
|
||||||
|
"sqlalchemy[asyncio]>=2.0.0",
|
||||||
|
"aiosqlite>=0.20.0",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
|
"pydantic-settings>=2.0.0",
|
||||||
|
"httpx>=0.27.0",
|
||||||
|
"cryptography>=43.0.0",
|
||||||
|
"openai>=1.40.0",
|
||||||
|
"chromadb>=0.5.0",
|
||||||
|
"pypdf2>=3.0.0",
|
||||||
|
"python-docx>=1.1.0",
|
||||||
|
"python-multipart>=0.0.9",
|
||||||
|
"sse-starlette>=2.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.24.0",
|
||||||
|
"ruff>=0.6.0",
|
||||||
|
"mypy>=1.11.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "N", "W", "UP"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
strict = true
|
||||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
46
backend/tests/conftest.py
Normal file
46
backend/tests/conftest.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.api.deps import get_db
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
TEST_DB_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
test_engine = create_async_engine(TEST_DB_URL, echo=False)
|
||||||
|
TestSession = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def setup_db():
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
async with test_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with TestSession() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(setup_db) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||||
|
yield c
|
||||||
|
app.dependency_overrides.clear()
|
||||||
188
backend/tests/test_employees.py
Normal file
188
backend/tests/test_employees.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_employee(client: AsyncClient):
|
||||||
|
"""创建数字员工"""
|
||||||
|
# 先创建租户
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Employee Test", "slug": "employee-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "客服助手",
|
||||||
|
"role": "customer_service",
|
||||||
|
"system_prompt": "你是一个专业的客服助手。",
|
||||||
|
"greeting": "你好,有什么可以帮助您的?",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_context_messages": 20,
|
||||||
|
"knowledge_base_ids": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "客服助手"
|
||||||
|
assert data["role"] == "customer_service"
|
||||||
|
assert data["status"] == "active"
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_employee_with_knowledge_bases(client: AsyncClient):
|
||||||
|
"""创建数字员工时关联知识库"""
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "KB Test", "slug": "kb-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
# 创建知识库
|
||||||
|
from app.models import KnowledgeBase
|
||||||
|
from tests.conftest import TestSession
|
||||||
|
|
||||||
|
async with TestSession() as session:
|
||||||
|
kb1 = KnowledgeBase(tenant_id=tenant_id, name="产品文档")
|
||||||
|
kb2 = KnowledgeBase(tenant_id=tenant_id, name="FAQ")
|
||||||
|
session.add_all([kb1, kb2])
|
||||||
|
await session.commit()
|
||||||
|
kb1_id = kb1.id
|
||||||
|
kb2_id = kb2.id
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "智能客服",
|
||||||
|
"role": "customer_service",
|
||||||
|
"system_prompt": "你是一个智能客服。",
|
||||||
|
"knowledge_base_ids": [kb1_id, kb2_id],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["knowledge_base_ids"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_employees(client: AsyncClient):
|
||||||
|
"""列出数字员工"""
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "List Test", "slug": "list-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "员工A",
|
||||||
|
"role": "sales",
|
||||||
|
"system_prompt": "销售助手",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "员工B",
|
||||||
|
"role": "support",
|
||||||
|
"system_prompt": "支持助手",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.get(f"/api/v1/employees?tenant_id={tenant_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_employee(client: AsyncClient):
|
||||||
|
"""获取单个数字员工"""
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Get Emp Test", "slug": "get-emp-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "测试员工",
|
||||||
|
"role": "test",
|
||||||
|
"system_prompt": "测试",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
employee_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.get(f"/api/v1/employees/{employee_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "测试员工"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_employee(client: AsyncClient):
|
||||||
|
"""更新数字员工"""
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Update Emp Test", "slug": "update-emp-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "旧名称",
|
||||||
|
"role": "old_role",
|
||||||
|
"system_prompt": "旧提示词",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
employee_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/employees/{employee_id}",
|
||||||
|
json={
|
||||||
|
"name": "新名称",
|
||||||
|
"temperature": 0.5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "新名称"
|
||||||
|
assert data["temperature"] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_employee(client: AsyncClient):
|
||||||
|
"""删除数字员工"""
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Delete Emp Test", "slug": "delete-emp-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/employees",
|
||||||
|
json={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": "待删除",
|
||||||
|
"role": "temp",
|
||||||
|
"system_prompt": "临时",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
employee_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.delete(f"/api/v1/employees/{employee_id}")
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
get_resp = await client.get(f"/api/v1/employees/{employee_id}")
|
||||||
|
assert get_resp.status_code == 404
|
||||||
9
backend/tests/test_health.py
Normal file
9
backend/tests/test_health.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check(client: AsyncClient):
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
@ -14,7 +14,7 @@ from tests.conftest import test_engine
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_all_tables_created():
|
async def test_all_tables_created(setup_db):
|
||||||
expected_tables = {
|
expected_tables = {
|
||||||
"tenants",
|
"tenants",
|
||||||
"tenant_configs",
|
"tenant_configs",
|
||||||
|
|||||||
196
backend/tests/test_providers.py
Normal file
196
backend/tests/test_providers.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
"""LLM Provider 抽象层测试"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from app.providers.base import LLMMessage, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_base_provider_interface():
|
||||||
|
"""验证 BaseLLMProvider 接口定义"""
|
||||||
|
from app.providers.base import BaseLLMProvider
|
||||||
|
|
||||||
|
# 接口应该定义 chat, chat_stream, embed 方法
|
||||||
|
assert hasattr(BaseLLMProvider, "chat")
|
||||||
|
assert hasattr(BaseLLMProvider, "chat_stream")
|
||||||
|
assert hasattr(BaseLLMProvider, "embed")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_message_dataclass():
|
||||||
|
"""验证 LLMMessage 数据类"""
|
||||||
|
msg = LLMMessage(role="user", content="Hello")
|
||||||
|
assert msg.role == "user"
|
||||||
|
assert msg.content == "Hello"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_response_dataclass():
|
||||||
|
"""验证 LLMResponse 数据类"""
|
||||||
|
resp = LLMResponse(content="Hi there!", model="gpt-4", usage={"total_tokens": 10})
|
||||||
|
assert resp.content == "Hi there!"
|
||||||
|
assert resp.model == "gpt-4"
|
||||||
|
assert resp.usage["total_tokens"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_provider_chat():
|
||||||
|
"""测试 OpenAI Provider chat 方法(mock)"""
|
||||||
|
from app.providers.openai_provider import OpenAIProvider
|
||||||
|
|
||||||
|
provider = OpenAIProvider(api_key="sk-test", model="gpt-4")
|
||||||
|
|
||||||
|
with patch.object(provider, "client") as mock_client:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))]
|
||||||
|
mock_response.usage = MagicMock(total_tokens=100, prompt_tokens=50, completion_tokens=50)
|
||||||
|
mock_response.model = "gpt-4"
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
result = await provider.chat([LLMMessage(role="user", content="Hi")])
|
||||||
|
assert result.content == "Test response"
|
||||||
|
assert result.model == "gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_provider_chat_stream():
|
||||||
|
"""测试 OpenAI Provider chat_stream 方法(mock)"""
|
||||||
|
from app.providers.openai_provider import OpenAIProvider
|
||||||
|
|
||||||
|
provider = OpenAIProvider(api_key="sk-test", model="gpt-4")
|
||||||
|
|
||||||
|
async def mock_stream():
|
||||||
|
class MockChunk:
|
||||||
|
choices = [MagicMock(delta=MagicMock(content="token"))]
|
||||||
|
yield MockChunk()
|
||||||
|
yield MockChunk()
|
||||||
|
|
||||||
|
with patch.object(provider, "client") as mock_client:
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream())
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
async for chunk in provider.chat_stream([LLMMessage(role="user", content="Hi")]):
|
||||||
|
tokens.append(chunk)
|
||||||
|
|
||||||
|
assert len(tokens) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_provider_embed():
|
||||||
|
"""测试 OpenAI Provider embed 方法(mock)"""
|
||||||
|
from app.providers.openai_provider import OpenAIProvider
|
||||||
|
|
||||||
|
provider = OpenAIProvider(api_key="sk-test", model="text-embedding-3-small")
|
||||||
|
|
||||||
|
with patch.object(provider, "client") as mock_client:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
|
||||||
|
mock_response.usage = MagicMock(total_tokens=10)
|
||||||
|
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
result = await provider.embed(["Hello world"])
|
||||||
|
assert len(result) == 1
|
||||||
|
assert len(result[0]) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qwen_provider_chat():
|
||||||
|
"""测试 Qwen Provider chat 方法(mock,兼容 OpenAI SDK)"""
|
||||||
|
from app.providers.qwen_provider import QwenProvider
|
||||||
|
|
||||||
|
provider = QwenProvider(
|
||||||
|
api_key="sk-qwen-test",
|
||||||
|
model="qwen-turbo",
|
||||||
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(provider, "client") as mock_client:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock(message=MagicMock(content="Qwen response"))]
|
||||||
|
mock_response.usage = MagicMock(total_tokens=80)
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
result = await provider.chat([LLMMessage(role="user", content="你好")])
|
||||||
|
assert result.content == "Qwen response"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_registry():
|
||||||
|
"""测试 Provider 注册表"""
|
||||||
|
from app.providers import get_provider
|
||||||
|
|
||||||
|
# 根据 provider 类型返回正确的实例
|
||||||
|
openai_provider = get_provider("openai", api_key="sk-test", model="gpt-4")
|
||||||
|
assert openai_provider.__class__.__name__ == "OpenAIProvider"
|
||||||
|
|
||||||
|
qwen_provider = get_provider("qwen", api_key="sk-qwen", model="qwen-turbo")
|
||||||
|
assert qwen_provider.__class__.__name__ == "QwenProvider"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_registry_invalid():
|
||||||
|
"""测试 Provider 注册表无效类型"""
|
||||||
|
from app.providers import get_provider
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown provider"):
|
||||||
|
get_provider("invalid_provider", api_key="sk-test", model="model")
|
||||||
|
|
||||||
|
|
||||||
|
# ============= Task 2.3: Provider 集成到租户配置 =============
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_provider_for_tenant(client: AsyncClient):
|
||||||
|
"""测试根据租户配置获取 Provider"""
|
||||||
|
from app.services.tenant_service import get_provider_for_tenant
|
||||||
|
from tests.conftest import TestSession
|
||||||
|
|
||||||
|
# 创建租户并配置
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Provider Test", "slug": "provider-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
await client.put(
|
||||||
|
f"/api/v1/tenants/{tenant_id}/config",
|
||||||
|
json={
|
||||||
|
"llm_provider": "openai",
|
||||||
|
"llm_api_key": "sk-test-provider-key",
|
||||||
|
"llm_model": "gpt-4o",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取 Provider
|
||||||
|
async with TestSession() as session:
|
||||||
|
provider = await get_provider_for_tenant(session, tenant_id)
|
||||||
|
assert provider.__class__.__name__ == "OpenAIProvider"
|
||||||
|
assert provider.model == "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_test_provider_endpoint(client: AsyncClient):
|
||||||
|
"""测试 Provider 测试端点"""
|
||||||
|
# 创建租户并配置
|
||||||
|
tenant_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Provider Endpoint Test", "slug": "provider-endpoint-test"},
|
||||||
|
)
|
||||||
|
tenant_id = tenant_resp.json()["id"]
|
||||||
|
|
||||||
|
await client.put(
|
||||||
|
f"/api/v1/tenants/{tenant_id}/config",
|
||||||
|
json={
|
||||||
|
"llm_provider": "openai",
|
||||||
|
"llm_api_key": "sk-test-endpoint",
|
||||||
|
"llm_model": "gpt-4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试 Provider
|
||||||
|
response = await client.post(f"/api/v1/tenants/{tenant_id}/test-provider")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["provider"] == "openai"
|
||||||
|
assert data["model"] == "gpt-4"
|
||||||
|
assert "status" in data
|
||||||
127
backend/tests/test_tenants.py
Normal file
127
backend/tests/test_tenants.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from app.models import Tenant
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_tenant(client: AsyncClient):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Test Company", "slug": "test-company"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Test Company"
|
||||||
|
assert data["slug"] == "test-company"
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tenants(client: AsyncClient):
|
||||||
|
await client.post("/api/v1/tenants", json={"name": "A", "slug": "a"})
|
||||||
|
await client.post("/api/v1/tenants", json={"name": "B", "slug": "b"})
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/tenants")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tenant(client: AsyncClient):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Get Test", "slug": "get-test"},
|
||||||
|
)
|
||||||
|
tenant_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.get(f"/api/v1/tenants/{tenant_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "Get Test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_tenant(client: AsyncClient):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Old Name", "slug": "old-slug"},
|
||||||
|
)
|
||||||
|
tenant_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/tenants/{tenant_id}",
|
||||||
|
json={"name": "New Name"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "New Name"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_tenant(client: AsyncClient):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "To Delete", "slug": "to-delete"},
|
||||||
|
)
|
||||||
|
tenant_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.delete(f"/api/v1/tenants/{tenant_id}")
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
get_resp = await client.get(f"/api/v1/tenants/{tenant_id}")
|
||||||
|
assert get_resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_tenant_config(client: AsyncClient):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Config Test", "slug": "config-test"},
|
||||||
|
)
|
||||||
|
tenant_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
f"/api/v1/tenants/{tenant_id}/config",
|
||||||
|
json={
|
||||||
|
"llm_provider": "openai",
|
||||||
|
"llm_api_key": "sk-test-key-12345",
|
||||||
|
"llm_model": "gpt-4",
|
||||||
|
"llm_base_url": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["llm_provider"] == "openai"
|
||||||
|
assert data["llm_model"] == "gpt-4"
|
||||||
|
assert "llm_api_key" not in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_key_encrypted(client: AsyncClient):
|
||||||
|
from tests.conftest import TestSession
|
||||||
|
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/tenants",
|
||||||
|
json={"name": "Encrypt Test", "slug": "encrypt-test"},
|
||||||
|
)
|
||||||
|
tenant_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
await client.put(
|
||||||
|
f"/api/v1/tenants/{tenant_id}/config",
|
||||||
|
json={
|
||||||
|
"llm_provider": "openai",
|
||||||
|
"llm_api_key": "sk-secret-key-99999",
|
||||||
|
"llm_model": "gpt-4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async with TestSession() as session:
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import TenantConfig
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(TenantConfig).where(TenantConfig.tenant_id == tenant_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one()
|
||||||
|
assert config.llm_api_key != "sk-secret-key-99999"
|
||||||
|
assert len(config.llm_api_key) > 20
|
||||||
2735
backend/uv.lock
generated
Normal file
2735
backend/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
17
memory/2026-05-06.md
Normal file
17
memory/2026-05-06.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# 2026-05-06
|
||||||
|
|
||||||
|
## 纪要
|
||||||
|
|
||||||
|
- 工作区已初始化,可在此持续记录当天上下文。
|
||||||
|
|
||||||
|
## 进展
|
||||||
|
|
||||||
|
### B1 批次完成:项目脚手架 + 数据模型 + 租户管理
|
||||||
|
- ✅ Task 1.1 项目脚手架
|
||||||
|
- ✅ Task 1.2 数据模型(7个模型)
|
||||||
|
- ✅ Task 1.3 租户管理 API(7个测试全部通过)
|
||||||
|
- 实现:schemas/tenant.py, api/deps.py, services/tenant_service.py, api/v1/tenants.py
|
||||||
|
- 修复:conftest.py fixture 顺序和 get_db 依赖覆盖
|
||||||
|
|
||||||
|
### 下一步
|
||||||
|
- B2 批次:数字员工配置 + LLM Provider 抽象层
|
||||||
Loading…
x
Reference in New Issue
Block a user