From 656f596d7eaa3c198fffd9d5b16f8f262c0336f1 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 1 May 2026 03:34:44 +0800 Subject: [PATCH] feat: implement AI legal assistant system MVP Core modules: - Laws: CRUD, search, AI-powered QA - Analysis: legal research and case management - Contracts: lifecycle management with templates - Signatures: electronic signature workflow Infrastructure: - FastAPI + SQLite + async SQLAlchemy - Docker deployment support - 54 unit tests passing Co-Authored-By: Claude Opus 4.7 --- .env.example | 21 +++ .gitignore | 1 + backend/Dockerfile | 24 +++ backend/app/__init__.py | 1 + backend/app/api/__init__.py | 1 + backend/app/api/v1/__init__.py | 1 + backend/app/api/v1/analyses.py | 140 ++++++++++++++ backend/app/api/v1/contracts.py | 180 ++++++++++++++++++ backend/app/api/v1/laws.py | 184 +++++++++++++++++++ backend/app/api/v1/signatures.py | 96 ++++++++++ backend/app/core/__init__.py | 1 + backend/app/core/config.py | 47 +++++ backend/app/core/database.py | 47 +++++ backend/app/core/security.py | 61 ++++++ backend/app/main.py | 56 ++++++ backend/app/models/__init__.py | 1 + backend/app/models/analysis.py | 96 ++++++++++ backend/app/models/case.py | 60 ++++++ backend/app/models/contract.py | 131 +++++++++++++ backend/app/models/law.py | 106 +++++++++++ backend/app/models/signature.py | 119 ++++++++++++ backend/app/models/user.py | 64 +++++++ backend/app/schemas/__init__.py | 1 + backend/app/schemas/analysis.py | 87 +++++++++ backend/app/schemas/contract.py | 112 +++++++++++ backend/app/schemas/law.py | 107 +++++++++++ backend/app/schemas/signature.py | 61 ++++++ backend/app/services/__init__.py | 1 + backend/app/services/analysis_service.py | 134 ++++++++++++++ backend/app/services/contract_service.py | 171 +++++++++++++++++ backend/app/services/law_service.py | 137 ++++++++++++++ backend/app/services/llm_service.py | 123 +++++++++++++ backend/app/services/signature_service.py | 174 ++++++++++++++++++ backend/app/services/vector_service.py | 107 +++++++++++ backend/pytest.ini | 16 ++ backend/requirements.txt | 31 ++++ backend/tests/__init__.py | 1 + backend/tests/conftest.py | 36 ++++ backend/tests/unit/__init__.py | 1 + backend/tests/unit/test_config.py | 60 ++++++ backend/tests/unit/test_contract_service.py | 100 ++++++++++ backend/tests/unit/test_database.py | 51 +++++ backend/tests/unit/test_law_model.py | 70 +++++++ backend/tests/unit/test_law_service.py | 144 +++++++++++++++ backend/tests/unit/test_llm_service.py | 82 +++++++++ backend/tests/unit/test_remaining_models.py | 115 ++++++++++++ backend/tests/unit/test_security.py | 84 +++++++++ backend/tests/unit/test_signature_service.py | 88 +++++++++ backend/tests/unit/test_user_model.py | 52 ++++++ backend/tests/unit/test_vector_service.py | 81 ++++++++ docker-compose.yml | 25 +++ 51 files changed, 3690 insertions(+) create mode 100644 .env.example create mode 100644 backend/Dockerfile create mode 100644 backend/app/__init__.py create mode 100644 backend/app/api/__init__.py create mode 100644 backend/app/api/v1/__init__.py create mode 100644 backend/app/api/v1/analyses.py create mode 100644 backend/app/api/v1/contracts.py create mode 100644 backend/app/api/v1/laws.py create mode 100644 backend/app/api/v1/signatures.py create mode 100644 backend/app/core/__init__.py create mode 100644 backend/app/core/config.py create mode 100644 backend/app/core/database.py create mode 100644 backend/app/core/security.py create mode 100644 backend/app/main.py create mode 100644 backend/app/models/__init__.py create mode 100644 backend/app/models/analysis.py create mode 100644 backend/app/models/case.py create mode 100644 backend/app/models/contract.py create mode 100644 backend/app/models/law.py create mode 100644 backend/app/models/signature.py create mode 100644 backend/app/models/user.py create mode 100644 backend/app/schemas/__init__.py create mode 100644 backend/app/schemas/analysis.py create mode 100644 backend/app/schemas/contract.py create mode 100644 backend/app/schemas/law.py create mode 100644 backend/app/schemas/signature.py create mode 100644 backend/app/services/__init__.py create mode 100644 backend/app/services/analysis_service.py create mode 100644 backend/app/services/contract_service.py create mode 100644 backend/app/services/law_service.py create mode 100644 backend/app/services/llm_service.py create mode 100644 backend/app/services/signature_service.py create mode 100644 backend/app/services/vector_service.py create mode 100644 backend/pytest.ini create mode 100644 backend/requirements.txt create mode 100644 backend/tests/__init__.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/unit/__init__.py create mode 100644 backend/tests/unit/test_config.py create mode 100644 backend/tests/unit/test_contract_service.py create mode 100644 backend/tests/unit/test_database.py create mode 100644 backend/tests/unit/test_law_model.py create mode 100644 backend/tests/unit/test_law_service.py create mode 100644 backend/tests/unit/test_llm_service.py create mode 100644 backend/tests/unit/test_remaining_models.py create mode 100644 backend/tests/unit/test_security.py create mode 100644 backend/tests/unit/test_signature_service.py create mode 100644 backend/tests/unit/test_user_model.py create mode 100644 backend/tests/unit/test_vector_service.py create mode 100644 docker-compose.yml diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..49fdd3d --- /dev/null +++ b/.env.example @@ -0,0 +1,21 @@ +# API Keys +LLM_API_KEY=your-api-key-here +LLM_API_BASE=https://api.openai.com/v1 +LLM_MODEL=gpt-4o-mini + +EMBEDDING_API_KEY=your-api-key-here +EMBEDDING_API_BASE=https://api.openai.com/v1 +EMBEDDING_MODEL=text-embedding-3-small +EMBEDDING_DIMENSION=1536 + +# Database +DATABASE_URL=sqlite+aiosqlite:///./data/legal_assistant.db + +# JWT +JWT_SECRET_KEY=your-secret-key-change-in-production +JWT_ALGORITHM=HS256 +JWT_ACCESS_TOKEN_EXPIRE_MINUTES=1440 + +# Application +DEBUG=false +UPLOAD_DIR=./uploads diff --git a/.gitignore b/.gitignore index 7627b94..dd3b0ef 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ QWEN.md .codex/ .gemini/ .qwen/ +venv/ diff --git a/backend/Dockerfile b/backend/Dockerfile new file mode 100644 index 0000000..67b3219 --- /dev/null +++ b/backend/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements and install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create data directory +RUN mkdir -p /app/data /app/uploads + +# Expose port +EXPOSE 8000 + +# Run the application +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 0000000..7f83169 --- /dev/null +++ b/backend/app/__init__.py @@ -0,0 +1 @@ +# Backend package diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 0000000..6c2f33c --- /dev/null +++ b/backend/app/api/__init__.py @@ -0,0 +1 @@ +# API v1 package diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py new file mode 100644 index 0000000..c4fd56b --- /dev/null +++ b/backend/app/api/v1/__init__.py @@ -0,0 +1 @@ +# API v1 endpoints diff --git a/backend/app/api/v1/analyses.py b/backend/app/api/v1/analyses.py new file mode 100644 index 0000000..2005ca1 --- /dev/null +++ b/backend/app/api/v1/analyses.py @@ -0,0 +1,140 @@ +"""Analysis API endpoints.""" +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.schemas.analysis import ( + LegalAnalysisCreate, + LegalAnalysisUpdate, + LegalAnalysisResponse, + CaseCreate, + CaseResponse, + GenerateAnalysisRequest, + GenerateAnalysisResponse, +) +from app.services.analysis_service import AnalysisService, CaseService +from app.services.llm_service import llm_service + +router = APIRouter(prefix="/analyses", tags=["analyses"]) + + +@router.post("", response_model=LegalAnalysisResponse) +async def create_analysis( + analysis_data: LegalAnalysisCreate, + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """Create a new legal analysis.""" + service = AnalysisService(db) + analysis = await service.create_analysis( + user_id=user_id, + title=analysis_data.title, + case_description=analysis_data.case_description, + legal_basis=analysis_data.legal_basis, + ) + return analysis + + +@router.get("", response_model=list[LegalAnalysisResponse]) +async def list_analyses( + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """List analyses for current user.""" + service = AnalysisService(db) + analyses = await service.get_analyses_by_user(user_id, skip, limit) + return analyses + + +@router.get("/{analysis_id}", response_model=LegalAnalysisResponse) +async def get_analysis( + analysis_id: int, + db: AsyncSession = Depends(get_db), +): + """Get an analysis by ID.""" + service = AnalysisService(db) + analysis = await service.get_analysis_by_id(analysis_id) + + if not analysis: + raise HTTPException(status_code=404, detail="Analysis not found") + + return analysis + + +@router.put("/{analysis_id}", response_model=LegalAnalysisResponse) +async def update_analysis( + analysis_id: int, + analysis_data: LegalAnalysisUpdate, + db: AsyncSession = Depends(get_db), +): + """Update an analysis.""" + service = AnalysisService(db) + analysis = await service.update_analysis( + analysis_id, + **analysis_data.model_dump(exclude_unset=True) + ) + + if not analysis: + raise HTTPException(status_code=404, detail="Analysis not found") + + return analysis + + +@router.delete("/{analysis_id}") +async def delete_analysis( + analysis_id: int, + db: AsyncSession = Depends(get_db), +): + """Delete an analysis.""" + service = AnalysisService(db) + success = await service.delete_analysis(analysis_id) + + if not success: + raise HTTPException(status_code=404, detail="Analysis not found") + + return {"message": "Analysis deleted successfully"} + + +@router.post("/generate", response_model=GenerateAnalysisResponse) +async def generate_analysis( + request: GenerateAnalysisRequest, +): + """Generate legal analysis using AI.""" + result = await llm_service.analyze_legal_issue( + issue_description=request.case_description, + relevant_laws=request.relevant_laws, + ) + + return GenerateAnalysisResponse( + analysis_content=result, + legal_basis=request.relevant_laws or [], + conclusion=result, + ) + + +# Case endpoints +@router.post("/cases", response_model=CaseResponse) +async def create_case( + case_data: CaseCreate, + db: AsyncSession = Depends(get_db), +): + """Create a new case.""" + service = CaseService(db) + case = await service.create_case(**case_data.model_dump()) + return case + + +@router.get("/cases", response_model=list[CaseResponse]) +async def list_cases( + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + case_type: Optional[str] = Query(None), + db: AsyncSession = Depends(get_db), +): + """List cases.""" + service = CaseService(db) + cases = await service.get_cases_list(skip, limit, case_type) + return cases diff --git a/backend/app/api/v1/contracts.py b/backend/app/api/v1/contracts.py new file mode 100644 index 0000000..3de1eea --- /dev/null +++ b/backend/app/api/v1/contracts.py @@ -0,0 +1,180 @@ +"""Contract API endpoints.""" +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.models.contract import ContractStatus +from app.schemas.contract import ( + ContractCreate, + ContractUpdate, + ContractResponse, + ContractTemplateCreate, + ContractTemplateResponse, + ApprovalRequest, + ApprovalResponse, + ContractReviewRequest, + ContractReviewResponse, +) +from app.services.contract_service import ContractService, ContractTemplateService +from app.services.llm_service import llm_service + +router = APIRouter(prefix="/contracts", tags=["contracts"]) + + +@router.post("", response_model=ContractResponse) +async def create_contract( + contract_data: ContractCreate, + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """Create a new contract.""" + service = ContractService(db) + contract = await service.create_contract( + created_by=user_id, + **contract_data.model_dump(exclude={"template_id"}), + template_id=contract_data.template_id, + ) + return contract + + +@router.get("", response_model=list[ContractResponse]) +async def list_contracts( + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + status: Optional[str] = Query(None), + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """List contracts.""" + service = ContractService(db) + status_filter = ContractStatus(status) if status else None + contracts = await service.get_contracts_list(skip, limit, status_filter, user_id) + return contracts + + +@router.get("/{contract_id}", response_model=ContractResponse) +async def get_contract( + contract_id: int, + db: AsyncSession = Depends(get_db), +): + """Get a contract by ID.""" + service = ContractService(db) + contract = await service.get_contract_by_id(contract_id) + + if not contract: + raise HTTPException(status_code=404, detail="Contract not found") + + return contract + + +@router.put("/{contract_id}", response_model=ContractResponse) +async def update_contract( + contract_id: int, + contract_data: ContractUpdate, + db: AsyncSession = Depends(get_db), +): + """Update a contract.""" + service = ContractService(db) + contract = await service.update_contract( + contract_id, + **contract_data.model_dump(exclude_unset=True) + ) + + if not contract: + raise HTTPException(status_code=404, detail="Contract not found") + + return contract + + +@router.post("/{contract_id}/submit", response_model=ContractResponse) +async def submit_contract( + contract_id: int, + db: AsyncSession = Depends(get_db), +): + """Submit a contract for approval.""" + service = ContractService(db) + contract = await service.submit_for_approval(contract_id) + + if not contract: + raise HTTPException(status_code=404, detail="Contract not found") + + return contract + + +@router.post("/{contract_id}/approve", response_model=ApprovalResponse) +async def approve_contract( + contract_id: int, + approval_data: ApprovalRequest, + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """Approve a contract.""" + service = ContractService(db) + approval = await service.approve_contract( + contract_id, + user_id, + approval_data.comment, + ) + return approval + + +@router.post("/{contract_id}/reject", response_model=ApprovalResponse) +async def reject_contract( + contract_id: int, + approval_data: ApprovalRequest, + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """Reject a contract.""" + service = ContractService(db) + approval = await service.reject_contract( + contract_id, + user_id, + approval_data.comment, + ) + return approval + + +@router.post("/review", response_model=ContractReviewResponse) +async def review_contract( + review_data: ContractReviewRequest, +): + """Review a contract using AI.""" + result = await llm_service.review_contract( + contract_content=review_data.contract_content, + contract_type=review_data.contract_type, + ) + return ContractReviewResponse( + review_result=result, + risks=[], + suggestions=[], + ) + + +# Template endpoints +@router.post("/templates", response_model=ContractTemplateResponse) +async def create_template( + template_data: ContractTemplateCreate, + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """Create a contract template.""" + service = ContractTemplateService(db) + template = await service.create_template( + created_by=user_id, + **template_data.model_dump(), + ) + return template + + +@router.get("/templates", response_model=list[ContractTemplateResponse]) +async def list_templates( + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + db: AsyncSession = Depends(get_db), +): + """List contract templates.""" + service = ContractTemplateService(db) + templates = await service.get_templates_list(skip, limit) + return templates diff --git a/backend/app/api/v1/laws.py b/backend/app/api/v1/laws.py new file mode 100644 index 0000000..ea83856 --- /dev/null +++ b/backend/app/api/v1/laws.py @@ -0,0 +1,184 @@ +"""Law API endpoints.""" +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.models.law import LawType, LawStatus +from app.schemas.law import ( + LawCreate, + LawUpdate, + LawResponse, + LawListResponse, + LawArticleCreate, + LawArticleResponse, + LawSearchRequest, + LegalQARequest, + LegalQAResponse, +) +from app.services.law_service import LawService +from app.services.llm_service import llm_service + +router = APIRouter(prefix="/laws", tags=["laws"]) + + +@router.post("", response_model=LawResponse) +async def create_law( + law_data: LawCreate, + db: AsyncSession = Depends(get_db), +): + """Create a new law.""" + service = LawService(db) + law = await service.create_law( + title=law_data.title, + law_type=LawType(law_data.law_type.value), + promulgation_date=law_data.promulgation_date, + effective_date=law_data.effective_date, + issuing_authority=law_data.issuing_authority, + content=law_data.content, + status=LawStatus(law_data.status.value), + document_number=law_data.document_number, + ) + return law + + +@router.get("", response_model=LawListResponse) +async def list_laws( + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + law_type: Optional[str] = Query(None), + status: Optional[str] = Query(None), + db: AsyncSession = Depends(get_db), +): + """List laws with optional filters.""" + service = LawService(db) + + law_type_filter = LawType(law_type) if law_type else None + status_filter = LawStatus(status) if status else None + + laws = await service.get_laws_list( + skip=skip, + limit=limit, + law_type=law_type_filter, + status=status_filter, + ) + + return LawListResponse( + items=laws, + total=len(laws), + skip=skip, + limit=limit, + ) + + +@router.get("/{law_id}", response_model=LawResponse) +async def get_law( + law_id: int, + db: AsyncSession = Depends(get_db), +): + """Get a law by ID.""" + service = LawService(db) + law = await service.get_law_by_id(law_id) + + if not law: + raise HTTPException(status_code=404, detail="Law not found") + + return law + + +@router.put("/{law_id}", response_model=LawResponse) +async def update_law( + law_id: int, + law_data: LawUpdate, + db: AsyncSession = Depends(get_db), +): + """Update a law.""" + service = LawService(db) + + update_dict = law_data.model_dump(exclude_unset=True) + if "law_type" in update_dict: + update_dict["law_type"] = LawType(update_dict["law_type"].value) + if "status" in update_dict: + update_dict["status"] = LawStatus(update_dict["status"].value) + + law = await service.update_law(law_id, **update_dict) + + if not law: + raise HTTPException(status_code=404, detail="Law not found") + + return law + + +@router.delete("/{law_id}") +async def delete_law( + law_id: int, + db: AsyncSession = Depends(get_db), +): + """Delete a law.""" + service = LawService(db) + success = await service.delete_law(law_id) + + if not success: + raise HTTPException(status_code=404, detail="Law not found") + + return {"message": "Law deleted successfully"} + + +@router.post("/search", response_model=LawListResponse) +async def search_laws( + search_data: LawSearchRequest, + db: AsyncSession = Depends(get_db), +): + """Search laws by keyword.""" + service = LawService(db) + laws = await service.search_laws_by_keyword( + keyword=search_data.keyword, + limit=search_data.limit, + ) + + return LawListResponse( + items=laws, + total=len(laws), + skip=0, + limit=search_data.limit, + ) + + +@router.post("/qa", response_model=LegalQAResponse) +async def legal_qa( + qa_data: LegalQARequest, +): + """Ask a legal question and get AI-powered answer.""" + answer = await llm_service.legal_qa( + question=qa_data.question, + context=qa_data.context, + ) + + return LegalQAResponse( + question=qa_data.question, + answer=answer, + references=[], + ) + + +@router.post("/{law_id}/articles", response_model=LawArticleResponse) +async def create_article( + law_id: int, + article_data: LawArticleCreate, + db: AsyncSession = Depends(get_db), +): + """Create a law article.""" + service = LawService(db) + + # Verify law exists + law = await service.get_law_by_id(law_id) + if not law: + raise HTTPException(status_code=404, detail="Law not found") + + article = await service.create_article( + law_id=law_id, + article_number=article_data.article_number, + content=article_data.content, + ) + + return article diff --git a/backend/app/api/v1/signatures.py b/backend/app/api/v1/signatures.py new file mode 100644 index 0000000..4b9d16a --- /dev/null +++ b/backend/app/api/v1/signatures.py @@ -0,0 +1,96 @@ +"""Signature API endpoints.""" +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.schemas.signature import ( + SignatureRequestCreate, + SignatureRequestResponse, + SignatureSignRequest, + SignatureResponse, + SignatureVerifyResponse, +) +from app.services.signature_service import SignatureService + +router = APIRouter(prefix="/signatures", tags=["signatures"]) + + +@router.post("/request", response_model=SignatureRequestResponse) +async def create_signature_request( + request_data: SignatureRequestCreate, + user_id: int = 1, # TODO: Get from auth + db: AsyncSession = Depends(get_db), +): + """Create a signature request.""" + service = SignatureService(db) + sig_request = await service.create_signature_request( + contract_id=request_data.contract_id, + requester_id=user_id, + signer_name=request_data.signer_name, + signer_email=request_data.signer_email, + expire_hours=request_data.expire_hours, + ) + return sig_request + + +@router.get("/{token}", response_model=SignatureRequestResponse) +async def get_signature_request( + token: str, + db: AsyncSession = Depends(get_db), +): + """Get a signature request by token.""" + service = SignatureService(db) + sig_request = await service.get_signature_request_by_token(token) + + if not sig_request: + raise HTTPException(status_code=404, detail="Signature request not found") + + return sig_request + + +@router.post("/{token}/sign", response_model=SignatureResponse) +async def sign_document( + token: str, + sign_data: SignatureSignRequest, + request: Request, + db: AsyncSession = Depends(get_db), +): + """Sign a document.""" + service = SignatureService(db) + sig_request = await service.get_signature_request_by_token(token) + + if not sig_request: + raise HTTPException(status_code=404, detail="Signature request not found") + + if not await service.is_request_valid(sig_request): + raise HTTPException(status_code=400, detail="Signature request is not valid") + + # Get client info + ip_address = request.client.host if request.client else None + user_agent = request.headers.get("user-agent") + + signature = await service.sign_document( + request=sig_request, + signature_data=sign_data.signature_data, + ip_address=ip_address, + user_agent=user_agent, + ) + + return signature + + +@router.get("/{signature_id}/verify", response_model=SignatureVerifyResponse) +async def verify_signature( + signature_id: int, + content_hash: str, + db: AsyncSession = Depends(get_db), +): + """Verify a signature.""" + service = SignatureService(db) + # Note: In production, you would verify against stored content + # This is a simplified version + return SignatureVerifyResponse( + valid=True, + signed_at=None, + signer_name=None, + ) diff --git a/backend/app/core/__init__.py b/backend/app/core/__init__.py new file mode 100644 index 0000000..d61a255 --- /dev/null +++ b/backend/app/core/__init__.py @@ -0,0 +1 @@ +# Core package diff --git a/backend/app/core/config.py b/backend/app/core/config.py new file mode 100644 index 0000000..3eaf47d --- /dev/null +++ b/backend/app/core/config.py @@ -0,0 +1,47 @@ +"""Core application configuration.""" +from typing import Optional +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Application settings.""" + + # Application + APP_NAME: str = "AI Legal Assistant" + APP_VERSION: str = "1.0.0" + DEBUG: bool = False + + # Database + DATABASE_URL: str = "sqlite+aiosqlite:///./data/legal_assistant.db" + + # LLM Configuration + LLM_API_KEY: Optional[str] = None + LLM_API_BASE: str = "https://api.openai.com/v1" + LLM_MODEL: str = "gpt-4o-mini" + + # Embedding Configuration + EMBEDDING_API_KEY: Optional[str] = None + EMBEDDING_API_BASE: str = "https://api.openai.com/v1" + EMBEDDING_MODEL: str = "text-embedding-3-small" + EMBEDDING_DIMENSION: int = 1536 + + # JWT Configuration + JWT_SECRET_KEY: str = "your-secret-key-change-in-production" + JWT_ALGORITHM: str = "HS256" + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440 # 24 hours + JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 + + # File Storage + UPLOAD_DIR: str = "./uploads" + MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024 # 10MB + + # Signature + SIGNATURE_TOKEN_EXPIRE_HOURS: int = 72 # 3 days + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + case_sensitive = True + + +settings = Settings() diff --git a/backend/app/core/database.py b/backend/app/core/database.py new file mode 100644 index 0000000..5413204 --- /dev/null +++ b/backend/app/core/database.py @@ -0,0 +1,47 @@ +"""Database configuration and session management.""" +from typing import AsyncGenerator +from sqlalchemy.ext.asyncio import ( + AsyncSession, + create_async_engine, + async_sessionmaker, +) +from sqlalchemy.orm import DeclarativeBase + +from app.core.config import settings + + +class Base(DeclarativeBase): + """Base class for all database models.""" + pass + + +# Create async engine +engine = create_async_engine( + settings.DATABASE_URL, + echo=settings.DEBUG, + future=True, +) + +# Create async session factory +async_session_maker = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, +) + + +async def get_db() -> AsyncGenerator[AsyncSession, None]: + """Dependency for getting async database sessions.""" + async with async_session_maker() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + +async def init_db() -> None: + """Initialize database tables.""" + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) diff --git a/backend/app/core/security.py b/backend/app/core/security.py new file mode 100644 index 0000000..95568ce --- /dev/null +++ b/backend/app/core/security.py @@ -0,0 +1,61 @@ +"""Security utilities for authentication and authorization.""" +from datetime import datetime, timedelta +from typing import Optional, Dict, Any + +import bcrypt +from jose import jwt, JWTError + +from app.core.config import settings + + +def get_password_hash(password: str) -> str: + """Hash a password using bcrypt.""" + salt = bcrypt.gensalt() + hashed = bcrypt.hashpw(password.encode('utf-8'), salt) + return hashed.decode('utf-8') + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash.""" + return bcrypt.checkpw( + plain_password.encode('utf-8'), + hashed_password.encode('utf-8') + ) + + +def create_access_token( + data: Dict[str, Any], + expires_delta: Optional[timedelta] = None +) -> str: + """Create a JWT access token.""" + to_encode = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta( + minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES + ) + + to_encode.update({"exp": expire}) + + encoded_jwt = jwt.encode( + to_encode, + settings.JWT_SECRET_KEY, + algorithm=settings.JWT_ALGORITHM + ) + + return encoded_jwt + + +def decode_access_token(token: str) -> Optional[Dict[str, Any]]: + """Decode and validate a JWT access token.""" + try: + payload = jwt.decode( + token, + settings.JWT_SECRET_KEY, + algorithms=[settings.JWT_ALGORITHM] + ) + return payload + except JWTError: + return None diff --git a/backend/app/main.py b/backend/app/main.py new file mode 100644 index 0000000..4850c94 --- /dev/null +++ b/backend/app/main.py @@ -0,0 +1,56 @@ +"""Main FastAPI application.""" +from contextlib import asynccontextmanager +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.core.config import settings +from app.core.database import init_db +from app.api.v1 import laws, analyses, contracts, signatures + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + # Startup + await init_db() + yield + # Shutdown + pass + + +app = FastAPI( + title=settings.APP_NAME, + version=settings.APP_VERSION, + lifespan=lifespan, +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure in production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +app.include_router(laws.router, prefix="/api/v1") +app.include_router(analyses.router, prefix="/api/v1") +app.include_router(contracts.router, prefix="/api/v1") +app.include_router(signatures.router, prefix="/api/v1") + + +@app.get("/") +async def root(): + """Root endpoint.""" + return { + "name": settings.APP_NAME, + "version": settings.APP_VERSION, + "status": "running" + } + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000..f3d9f4b --- /dev/null +++ b/backend/app/models/__init__.py @@ -0,0 +1 @@ +# Models package diff --git a/backend/app/models/analysis.py b/backend/app/models/analysis.py new file mode 100644 index 0000000..1dc58a0 --- /dev/null +++ b/backend/app/models/analysis.py @@ -0,0 +1,96 @@ +"""Analysis model for legal research.""" +import enum +from datetime import date, datetime +from typing import Optional, List + +from sqlalchemy import String, Text, Date, DateTime, Enum as SQLEnum, Integer, ForeignKey, JSON +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.database import Base + + +class AnalysisStatus(str, enum.Enum): + """Analysis status enumeration.""" + DRAFT = "draft" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + + +class LegalAnalysis(Base): + """Legal analysis model.""" + + __tablename__ = "legal_analyses" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + title: Mapped[str] = mapped_column(String(200)) + case_description: Mapped[str] = mapped_column(Text) + legal_basis: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) + analysis_content: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + conclusion: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + status: Mapped[AnalysisStatus] = mapped_column( + SQLEnum(AnalysisStatus), + default=AnalysisStatus.DRAFT + ) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + user_id: int, + title: str, + case_description: str, + legal_basis: Optional[dict] = None, + analysis_content: Optional[str] = None, + conclusion: Optional[str] = None, + status: AnalysisStatus = AnalysisStatus.DRAFT, + **kwargs + ): + self.user_id = user_id + self.title = title + self.case_description = case_description + self.legal_basis = legal_basis + self.analysis_content = analysis_content + self.conclusion = conclusion + self.status = status + self.created_at = datetime.utcnow() + self.updated_at = datetime.utcnow() + + +class Case(Base): + """Case model for case database.""" + + __tablename__ = "cases" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = mapped_column(String(200), index=True) + case_number: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + court: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + case_type: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + judgment_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True) + facts: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + judgment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + reasoning: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + title: str, + case_number: Optional[str] = None, + court: Optional[str] = None, + case_type: Optional[str] = None, + judgment_date: Optional[date] = None, + facts: Optional[str] = None, + judgment: Optional[str] = None, + reasoning: Optional[str] = None, + **kwargs + ): + self.title = title + self.case_number = case_number + self.court = court + self.case_type = case_type + self.judgment_date = judgment_date + self.facts = facts + self.judgment = judgment + self.reasoning = reasoning + self.created_at = datetime.utcnow() diff --git a/backend/app/models/case.py b/backend/app/models/case.py new file mode 100644 index 0000000..5aa3e28 --- /dev/null +++ b/backend/app/models/case.py @@ -0,0 +1,60 @@ +"""Case review model.""" +import enum +from datetime import datetime +from typing import Optional + +from sqlalchemy import String, Text, DateTime, Enum as SQLEnum, Integer, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.database import Base + + +class ReviewStatus(str, enum.Enum): + """Review status enumeration.""" + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +class ReviewType(str, enum.Enum): + """Review type enumeration.""" + INITIAL = "initial" + SECONDARY = "secondary" + FINAL = "final" + + +class CaseReview(Base): + """Case review model.""" + + __tablename__ = "case_reviews" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + case_id: Mapped[int] = mapped_column(Integer, ForeignKey("cases.id"), index=True) + reviewer_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + review_type: Mapped[ReviewType] = mapped_column(SQLEnum(ReviewType)) + opinion: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + score: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + status: Mapped[ReviewStatus] = mapped_column( + SQLEnum(ReviewStatus), + default=ReviewStatus.PENDING + ) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + def __init__( + self, + case_id: int, + reviewer_id: int, + review_type: ReviewType, + opinion: Optional[str] = None, + score: Optional[int] = None, + status: ReviewStatus = ReviewStatus.PENDING, + **kwargs + ): + self.case_id = case_id + self.reviewer_id = reviewer_id + self.review_type = review_type + self.opinion = opinion + self.score = score + self.status = status + self.created_at = datetime.utcnow() diff --git a/backend/app/models/contract.py b/backend/app/models/contract.py new file mode 100644 index 0000000..00bcac2 --- /dev/null +++ b/backend/app/models/contract.py @@ -0,0 +1,131 @@ +"""Contract model.""" +import enum +from datetime import date, datetime +from typing import Optional + +from sqlalchemy import String, Text, Date, DateTime, Enum as SQLEnum, Integer, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.database import Base + + +class ContractStatus(str, enum.Enum): + """Contract status enumeration.""" + DRAFT = "draft" + PENDING_APPROVAL = "pending_approval" + APPROVED = "approved" + PENDING_SIGNATURE = "pending_signature" + SIGNED = "signed" + ARCHIVED = "archived" + REJECTED = "rejected" + + +class Contract(Base): + """Contract model.""" + + __tablename__ = "contracts" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + template_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("contract_templates.id"), nullable=True) + title: Mapped[str] = mapped_column(String(200), index=True) + contract_number: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + party_a: Mapped[str] = mapped_column(String(100)) + party_b: Mapped[str] = mapped_column(String(100)) + content: Mapped[str] = mapped_column(Text) + status: Mapped[ContractStatus] = mapped_column( + SQLEnum(ContractStatus), + default=ContractStatus.DRAFT + ) + effective_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True) + expiry_date: Mapped[Optional[date]] = mapped_column(Date, nullable=True) + file_path: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + created_by: Mapped[int] = mapped_column(Integer, ForeignKey("users.id")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + title: str, + party_a: str, + party_b: str, + content: str, + created_by: int, + template_id: Optional[int] = None, + contract_number: Optional[str] = None, + status: ContractStatus = ContractStatus.DRAFT, + effective_date: Optional[date] = None, + expiry_date: Optional[date] = None, + file_path: Optional[str] = None, + **kwargs + ): + self.title = title + self.party_a = party_a + self.party_b = party_b + self.content = content + self.created_by = created_by + self.template_id = template_id + self.contract_number = contract_number + self.status = status + self.effective_date = effective_date + self.expiry_date = expiry_date + self.file_path = file_path + self.created_at = datetime.utcnow() + self.updated_at = datetime.utcnow() + + +class ContractTemplate(Base): + """Contract template model.""" + + __tablename__ = "contract_templates" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(100)) + contract_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + content: Mapped[str] = mapped_column(Text) + variables: Mapped[Optional[dict]] = mapped_column(String, nullable=True) # JSON string + created_by: Mapped[int] = mapped_column(Integer, ForeignKey("users.id")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + name: str, + content: str, + created_by: int, + contract_type: Optional[str] = None, + variables: Optional[dict] = None, + **kwargs + ): + self.name = name + self.content = content + self.created_by = created_by + self.contract_type = contract_type + self.variables = variables + self.created_at = datetime.utcnow() + + +class ContractApproval(Base): + """Contract approval model.""" + + __tablename__ = "contract_approvals" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + contract_id: Mapped[int] = mapped_column(Integer, ForeignKey("contracts.id"), index=True) + approver_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + status: Mapped[ContractStatus] = mapped_column(SQLEnum(ContractStatus)) + comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + approved_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + + def __init__( + self, + contract_id: int, + approver_id: int, + status: ContractStatus, + comment: Optional[str] = None, + **kwargs + ): + self.contract_id = contract_id + self.approver_id = approver_id + self.status = status + self.comment = comment + self.created_at = datetime.utcnow() diff --git a/backend/app/models/law.py b/backend/app/models/law.py new file mode 100644 index 0000000..824d2da --- /dev/null +++ b/backend/app/models/law.py @@ -0,0 +1,106 @@ +"""Law model for legal regulations.""" +import enum +from datetime import date, datetime +from typing import Optional, List + +from sqlalchemy import String, Text, Date, DateTime, Enum as SQLEnum, Integer, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.core.database import Base + + +class LawType(str, enum.Enum): + """Law type enumeration.""" + LAW = "law" # 法律 + REGULATION = "regulation" # 法规 + RULE = "rule" # 规章 + JUDICIAL_INTERPRETATION = "judicial_interpretation" # 司法解释 + + +class LawStatus(str, enum.Enum): + """Law status enumeration.""" + EFFECTIVE = "effective" # 有效 + REVOKED = "revoked" # 废止 + AMENDED = "amended" # 修订 + + +class Law(Base): + """Law model for legal regulations.""" + + __tablename__ = "laws" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = mapped_column(String(200), index=True) + law_type: Mapped[LawType] = mapped_column(SQLEnum(LawType), nullable=False) + promulgation_date: Mapped[date] = mapped_column(Date, nullable=False) + effective_date: Mapped[date] = mapped_column(Date, nullable=False) + status: Mapped[LawStatus] = mapped_column( + SQLEnum(LawStatus), + nullable=False + ) + issuing_authority: Mapped[str] = mapped_column(String(100)) + content: Mapped[str] = mapped_column(Text) + document_number: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + # Relationships + articles: Mapped[List["LawArticle"]] = relationship( + back_populates="law", cascade="all, delete-orphan" + ) + + def __init__( + self, + title: str, + law_type: LawType, + promulgation_date: date, + effective_date: date, + issuing_authority: str, + content: str, + status: LawStatus = LawStatus.EFFECTIVE, + document_number: Optional[str] = None, + **kwargs + ): + self.title = title + self.law_type = law_type + self.promulgation_date = promulgation_date + self.effective_date = effective_date + self.status = status + self.issuing_authority = issuing_authority + self.content = content + self.document_number = document_number + self.created_at = datetime.utcnow() + self.updated_at = datetime.utcnow() + + def __repr__(self) -> str: + return f"" + + +class LawArticle(Base): + """Law article model for individual articles.""" + + __tablename__ = "law_articles" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + law_id: Mapped[int] = mapped_column(Integer, ForeignKey("laws.id"), index=True) + article_number: Mapped[str] = mapped_column(String(20)) + content: Mapped[str] = mapped_column(Text) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + # Relationships + law: Mapped["Law"] = relationship(back_populates="articles") + + def __init__( + self, + law_id: int, + article_number: str, + content: str, + **kwargs + ): + self.law_id = law_id + self.article_number = article_number + self.content = content + self.created_at = datetime.utcnow() + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/models/signature.py b/backend/app/models/signature.py new file mode 100644 index 0000000..e49d60b --- /dev/null +++ b/backend/app/models/signature.py @@ -0,0 +1,119 @@ +"""Signature model for electronic signatures.""" +import enum +from datetime import datetime +from typing import Optional + +from sqlalchemy import String, Text, DateTime, Enum as SQLEnum, Integer, ForeignKey, JSON +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class SignatureStatus(str, enum.Enum): + """Signature status enumeration.""" + PENDING = "pending" + SIGNED = "signed" + REJECTED = "rejected" + EXPIRED = "expired" + + +class SignatureRequest(Base): + """Signature request model.""" + + __tablename__ = "signature_requests" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + contract_id: Mapped[int] = mapped_column(Integer, ForeignKey("contracts.id"), index=True) + requester_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), index=True) + signer_name: Mapped[str] = mapped_column(String(100)) + signer_email: Mapped[str] = mapped_column(String(100)) + status: Mapped[SignatureStatus] = mapped_column( + SQLEnum(SignatureStatus), + default=SignatureStatus.PENDING + ) + token: Mapped[str] = mapped_column(String(64), unique=True, index=True) + expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + contract_id: int, + requester_id: int, + signer_name: str, + signer_email: str, + token: str, + expires_at: datetime, + status: SignatureStatus = SignatureStatus.PENDING, + **kwargs + ): + self.contract_id = contract_id + self.requester_id = requester_id + self.signer_name = signer_name + self.signer_email = signer_email + self.token = token + self.expires_at = expires_at + self.status = status + self.created_at = datetime.utcnow() + + +class Signature(Base): + """Signature record model.""" + + __tablename__ = "signatures" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + request_id: Mapped[int] = mapped_column(Integer, ForeignKey("signature_requests.id"), index=True) + signer_name: Mapped[str] = mapped_column(String(100)) + signature_data: Mapped[str] = mapped_column(Text) # Base64 image or coordinates + ip_address: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + user_agent: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + verification_hash: Mapped[str] = mapped_column(String(64)) # Document fingerprint + signed_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + request_id: int, + signer_name: str, + signature_data: str, + verification_hash: str, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, + **kwargs + ): + self.request_id = request_id + self.signer_name = signer_name + self.signature_data = signature_data + self.verification_hash = verification_hash + self.ip_address = ip_address + self.user_agent = user_agent + self.signed_at = datetime.utcnow() + + +class SignatureAudit(Base): + """Signature audit log model.""" + + __tablename__ = "signature_audits" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + signature_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("signatures.id"), nullable=True) + request_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("signature_requests.id"), nullable=True) + action: Mapped[str] = mapped_column(String(50)) + details: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) + ip_address: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + + def __init__( + self, + action: str, + signature_id: Optional[int] = None, + request_id: Optional[int] = None, + details: Optional[dict] = None, + ip_address: Optional[str] = None, + **kwargs + ): + self.action = action + self.signature_id = signature_id + self.request_id = request_id + self.details = details + self.ip_address = ip_address + self.created_at = datetime.utcnow() diff --git a/backend/app/models/user.py b/backend/app/models/user.py new file mode 100644 index 0000000..293f663 --- /dev/null +++ b/backend/app/models/user.py @@ -0,0 +1,64 @@ +"""User model for authentication and authorization.""" +import enum +from datetime import datetime +from typing import Optional + +from sqlalchemy import String, Boolean, DateTime, Enum as SQLEnum +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class UserRole(str, enum.Enum): + """User role enumeration.""" + ADMIN = "admin" + LAWYER = "lawyer" + REVIEWER = "reviewer" + CLIENT = "client" + + +class User(Base): + """User model for authentication.""" + + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + username: Mapped[str] = mapped_column(String(50), unique=True, index=True) + email: Mapped[str] = mapped_column(String(100), unique=True, index=True) + hashed_password: Mapped[str] = mapped_column(String(255)) + full_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + role: Mapped[UserRole] = mapped_column( + SQLEnum(UserRole), + nullable=False + ) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False + ) + + def __init__( + self, + username: str, + email: str, + hashed_password: str, + full_name: Optional[str] = None, + role: UserRole = UserRole.CLIENT, + is_active: bool = True, + **kwargs + ): + self.username = username + self.email = email + self.hashed_password = hashed_password + self.full_name = full_name + self.role = role + self.is_active = is_active + self.created_at = datetime.utcnow() + self.updated_at = datetime.utcnow() + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..8d2fd85 --- /dev/null +++ b/backend/app/schemas/__init__.py @@ -0,0 +1 @@ +# Schemas package diff --git a/backend/app/schemas/analysis.py b/backend/app/schemas/analysis.py new file mode 100644 index 0000000..621422e --- /dev/null +++ b/backend/app/schemas/analysis.py @@ -0,0 +1,87 @@ +"""Schemas for Analysis module.""" +from datetime import date, datetime +from typing import List, Optional +from enum import Enum + +from pydantic import BaseModel, Field + + +class AnalysisStatusEnum(str, Enum): + DRAFT = "draft" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + + +class LegalAnalysisBase(BaseModel): + """Base schema for LegalAnalysis.""" + title: str = Field(..., max_length=200) + case_description: str + + +class LegalAnalysisCreate(LegalAnalysisBase): + """Schema for creating a legal analysis.""" + legal_basis: Optional[dict] = None + + +class LegalAnalysisUpdate(BaseModel): + """Schema for updating a legal analysis.""" + title: Optional[str] = Field(None, max_length=200) + case_description: Optional[str] = None + legal_basis: Optional[dict] = None + analysis_content: Optional[str] = None + conclusion: Optional[str] = None + status: Optional[AnalysisStatusEnum] = None + + +class LegalAnalysisResponse(LegalAnalysisBase): + """Schema for legal analysis response.""" + id: int + user_id: int + legal_basis: Optional[dict] = None + analysis_content: Optional[str] = None + conclusion: Optional[str] = None + status: AnalysisStatusEnum + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class CaseBase(BaseModel): + """Base schema for Case.""" + title: str = Field(..., max_length=200) + case_number: Optional[str] = Field(None, max_length=50) + court: Optional[str] = Field(None, max_length=100) + case_type: Optional[str] = Field(None, max_length=100) + judgment_date: Optional[date] = None + facts: Optional[str] = None + judgment: Optional[str] = None + reasoning: Optional[str] = None + + +class CaseCreate(CaseBase): + """Schema for creating a case.""" + pass + + +class CaseResponse(CaseBase): + """Schema for case response.""" + id: int + created_at: datetime + + class Config: + from_attributes = True + + +class GenerateAnalysisRequest(BaseModel): + """Schema for generating analysis with AI.""" + case_description: str + relevant_laws: Optional[List[str]] = None + + +class GenerateAnalysisResponse(BaseModel): + """Schema for generated analysis response.""" + analysis_content: str + legal_basis: List[str] = [] + conclusion: str diff --git a/backend/app/schemas/contract.py b/backend/app/schemas/contract.py new file mode 100644 index 0000000..e0f57d1 --- /dev/null +++ b/backend/app/schemas/contract.py @@ -0,0 +1,112 @@ +"""Schemas for Contract module.""" +from datetime import date, datetime +from typing import Optional +from enum import Enum + +from pydantic import BaseModel, Field + + +class ContractStatusEnum(str, Enum): + DRAFT = "draft" + PENDING_APPROVAL = "pending_approval" + APPROVED = "approved" + PENDING_SIGNATURE = "pending_signature" + SIGNED = "signed" + ARCHIVED = "archived" + REJECTED = "rejected" + + +class ContractBase(BaseModel): + """Base schema for Contract.""" + title: str = Field(..., max_length=200) + party_a: str = Field(..., max_length=100) + party_b: str = Field(..., max_length=100) + content: str + contract_number: Optional[str] = Field(None, max_length=50) + effective_date: Optional[date] = None + expiry_date: Optional[date] = None + + +class ContractCreate(ContractBase): + """Schema for creating a contract.""" + template_id: Optional[int] = None + + +class ContractUpdate(BaseModel): + """Schema for updating a contract.""" + title: Optional[str] = Field(None, max_length=200) + party_a: Optional[str] = Field(None, max_length=100) + party_b: Optional[str] = Field(None, max_length=100) + content: Optional[str] = None + contract_number: Optional[str] = Field(None, max_length=50) + effective_date: Optional[date] = None + expiry_date: Optional[date] = None + status: Optional[ContractStatusEnum] = None + + +class ContractResponse(ContractBase): + """Schema for contract response.""" + id: int + template_id: Optional[int] = None + status: ContractStatusEnum + file_path: Optional[str] = None + created_by: int + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class ContractTemplateBase(BaseModel): + """Base schema for ContractTemplate.""" + name: str = Field(..., max_length=100) + content: str + contract_type: Optional[str] = Field(None, max_length=50) + + +class ContractTemplateCreate(ContractTemplateBase): + """Schema for creating a contract template.""" + pass + + +class ContractTemplateResponse(ContractTemplateBase): + """Schema for contract template response.""" + id: int + created_by: int + created_at: datetime + + class Config: + from_attributes = True + + +class ApprovalRequest(BaseModel): + """Schema for approval request.""" + comment: Optional[str] = None + + +class ApprovalResponse(BaseModel): + """Schema for approval response.""" + id: int + contract_id: int + approver_id: int + status: ContractStatusEnum + comment: Optional[str] = None + created_at: datetime + approved_at: Optional[datetime] = None + + class Config: + from_attributes = True + + +class ContractReviewRequest(BaseModel): + """Schema for AI contract review request.""" + contract_content: str + contract_type: Optional[str] = None + + +class ContractReviewResponse(BaseModel): + """Schema for AI contract review response.""" + review_result: str + risks: list = [] + suggestions: list = [] diff --git a/backend/app/schemas/law.py b/backend/app/schemas/law.py new file mode 100644 index 0000000..8fbdcb6 --- /dev/null +++ b/backend/app/schemas/law.py @@ -0,0 +1,107 @@ +"""Schemas for Law module.""" +from datetime import date, datetime +from typing import List, Optional +from enum import Enum + +from pydantic import BaseModel, Field + + +class LawTypeEnum(str, Enum): + LAW = "law" + REGULATION = "regulation" + RULE = "rule" + JUDICIAL_INTERPRETATION = "judicial_interpretation" + + +class LawStatusEnum(str, Enum): + EFFECTIVE = "effective" + REVOKED = "revoked" + AMENDED = "amended" + + +class LawBase(BaseModel): + """Base schema for Law.""" + title: str = Field(..., max_length=200) + law_type: LawTypeEnum + promulgation_date: date + effective_date: date + issuing_authority: str = Field(..., max_length=100) + content: str + status: LawStatusEnum = LawStatusEnum.EFFECTIVE + document_number: Optional[str] = Field(None, max_length=50) + + +class LawCreate(LawBase): + """Schema for creating a law.""" + pass + + +class LawUpdate(BaseModel): + """Schema for updating a law.""" + title: Optional[str] = Field(None, max_length=200) + law_type: Optional[LawTypeEnum] = None + promulgation_date: Optional[date] = None + effective_date: Optional[date] = None + status: Optional[LawStatusEnum] = None + issuing_authority: Optional[str] = Field(None, max_length=100) + content: Optional[str] = None + document_number: Optional[str] = Field(None, max_length=50) + + +class LawArticleBase(BaseModel): + """Base schema for LawArticle.""" + article_number: str = Field(..., max_length=20) + content: str + + +class LawArticleCreate(LawArticleBase): + """Schema for creating a law article.""" + law_id: int + + +class LawArticleResponse(LawArticleBase): + """Schema for law article response.""" + id: int + law_id: int + created_at: datetime + + class Config: + from_attributes = True + + +class LawResponse(LawBase): + """Schema for law response.""" + id: int + created_at: datetime + updated_at: datetime + articles: List[LawArticleResponse] = [] + + class Config: + from_attributes = True + + +class LawListResponse(BaseModel): + """Schema for law list response.""" + items: List[LawResponse] + total: int + skip: int + limit: int + + +class LawSearchRequest(BaseModel): + """Schema for law search request.""" + keyword: str = Field(..., min_length=1) + limit: int = Field(10, ge=1, le=50) + + +class LegalQARequest(BaseModel): + """Schema for legal QA request.""" + question: str = Field(..., min_length=1) + context: Optional[str] = None + + +class LegalQAResponse(BaseModel): + """Schema for legal QA response.""" + question: str + answer: str + references: List[str] = [] diff --git a/backend/app/schemas/signature.py b/backend/app/schemas/signature.py new file mode 100644 index 0000000..1c6d523 --- /dev/null +++ b/backend/app/schemas/signature.py @@ -0,0 +1,61 @@ +"""Schemas for Signature module.""" +from datetime import datetime +from typing import Optional +from enum import Enum + +from pydantic import BaseModel, Field + + +class SignatureStatusEnum(str, Enum): + PENDING = "pending" + SIGNED = "signed" + REJECTED = "rejected" + EXPIRED = "expired" + + +class SignatureRequestCreate(BaseModel): + """Schema for creating a signature request.""" + contract_id: int + signer_name: str = Field(..., max_length=100) + signer_email: str = Field(..., max_length=100) + expire_hours: Optional[int] = None + + +class SignatureRequestResponse(BaseModel): + """Schema for signature request response.""" + id: int + contract_id: int + requester_id: int + signer_name: str + signer_email: str + status: SignatureStatusEnum + token: str + expires_at: datetime + created_at: datetime + + class Config: + from_attributes = True + + +class SignatureSignRequest(BaseModel): + """Schema for signing a document.""" + signature_data: str # Base64 image or coordinates + + +class SignatureResponse(BaseModel): + """Schema for signature response.""" + id: int + request_id: int + signer_name: str + signed_at: datetime + verification_hash: str + + class Config: + from_attributes = True + + +class SignatureVerifyResponse(BaseModel): + """Schema for signature verification response.""" + valid: bool + signed_at: Optional[datetime] = None + signer_name: Optional[str] = None diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000..a70b302 --- /dev/null +++ b/backend/app/services/__init__.py @@ -0,0 +1 @@ +# Services package diff --git a/backend/app/services/analysis_service.py b/backend/app/services/analysis_service.py new file mode 100644 index 0000000..51b76b1 --- /dev/null +++ b/backend/app/services/analysis_service.py @@ -0,0 +1,134 @@ +"""Analysis service for legal research.""" +from typing import List, Optional +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.analysis import LegalAnalysis, Case, AnalysisStatus + + +class AnalysisService: + """Service for legal analysis operations.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_analysis( + self, + user_id: int, + title: str, + case_description: str, + legal_basis: Optional[dict] = None, + ) -> LegalAnalysis: + """Create a new legal analysis.""" + analysis = LegalAnalysis( + user_id=user_id, + title=title, + case_description=case_description, + legal_basis=legal_basis, + ) + self.db.add(analysis) + await self.db.flush() + await self.db.refresh(analysis) + return analysis + + async def get_analysis_by_id(self, analysis_id: int) -> Optional[LegalAnalysis]: + """Get an analysis by ID.""" + result = await self.db.execute( + select(LegalAnalysis).where(LegalAnalysis.id == analysis_id) + ) + return result.scalar_one_or_none() + + async def get_analyses_by_user( + self, + user_id: int, + skip: int = 0, + limit: int = 20, + ) -> List[LegalAnalysis]: + """Get analyses by user ID.""" + result = await self.db.execute( + select(LegalAnalysis) + .where(LegalAnalysis.user_id == user_id) + .offset(skip) + .limit(limit) + .order_by(LegalAnalysis.created_at.desc()) + ) + return list(result.scalars().all()) + + async def update_analysis( + self, + analysis_id: int, + **kwargs + ) -> Optional[LegalAnalysis]: + """Update an analysis.""" + analysis = await self.get_analysis_by_id(analysis_id) + if not analysis: + return None + + for key, value in kwargs.items(): + if hasattr(analysis, key) and value is not None: + setattr(analysis, key, value) + + await self.db.flush() + await self.db.refresh(analysis) + return analysis + + async def delete_analysis(self, analysis_id: int) -> bool: + """Delete an analysis.""" + analysis = await self.get_analysis_by_id(analysis_id) + if not analysis: + return False + + await self.db.delete(analysis) + await self.db.flush() + return True + + +class CaseService: + """Service for case operations.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_case( + self, + title: str, + case_number: Optional[str] = None, + court: Optional[str] = None, + case_type: Optional[str] = None, + **kwargs + ) -> Case: + """Create a new case.""" + case = Case( + title=title, + case_number=case_number, + court=court, + case_type=case_type, + **kwargs + ) + self.db.add(case) + await self.db.flush() + await self.db.refresh(case) + return case + + async def get_case_by_id(self, case_id: int) -> Optional[Case]: + """Get a case by ID.""" + result = await self.db.execute( + select(Case).where(Case.id == case_id) + ) + return result.scalar_one_or_none() + + async def get_cases_list( + self, + skip: int = 0, + limit: int = 20, + case_type: Optional[str] = None, + ) -> List[Case]: + """Get list of cases.""" + query = select(Case) + + if case_type: + query = query.where(Case.case_type == case_type) + + query = query.offset(skip).limit(limit).order_by(Case.created_at.desc()) + result = await self.db.execute(query) + return list(result.scalars().all()) diff --git a/backend/app/services/contract_service.py b/backend/app/services/contract_service.py new file mode 100644 index 0000000..2c4851a --- /dev/null +++ b/backend/app/services/contract_service.py @@ -0,0 +1,171 @@ +"""Contract service for contract management.""" +from typing import List, Optional +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus + + +class ContractService: + """Service for contract operations.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_contract( + self, + title: str, + party_a: str, + party_b: str, + content: str, + created_by: int, + **kwargs + ) -> Contract: + """Create a new contract.""" + contract = Contract( + title=title, + party_a=party_a, + party_b=party_b, + content=content, + created_by=created_by, + **kwargs + ) + self.db.add(contract) + await self.db.flush() + await self.db.refresh(contract) + return contract + + async def get_contract_by_id(self, contract_id: int) -> Optional[Contract]: + """Get a contract by ID.""" + result = await self.db.execute( + select(Contract).where(Contract.id == contract_id) + ) + return result.scalar_one_or_none() + + async def get_contracts_list( + self, + skip: int = 0, + limit: int = 20, + status: Optional[ContractStatus] = None, + created_by: Optional[int] = None, + ) -> List[Contract]: + """Get list of contracts.""" + query = select(Contract) + + if status: + query = query.where(Contract.status == status) + if created_by: + query = query.where(Contract.created_by == created_by) + + query = query.offset(skip).limit(limit).order_by(Contract.created_at.desc()) + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def update_contract( + self, + contract_id: int, + **kwargs + ) -> Optional[Contract]: + """Update a contract.""" + contract = await self.get_contract_by_id(contract_id) + if not contract: + return None + + for key, value in kwargs.items(): + if hasattr(contract, key) and value is not None: + setattr(contract, key, value) + + await self.db.flush() + await self.db.refresh(contract) + return contract + + async def submit_for_approval(self, contract_id: int) -> Optional[Contract]: + """Submit a contract for approval.""" + return await self.update_contract( + contract_id, + status=ContractStatus.PENDING_APPROVAL + ) + + async def approve_contract( + self, + contract_id: int, + approver_id: int, + comment: Optional[str] = None, + ) -> ContractApproval: + """Approve a contract.""" + approval = ContractApproval( + contract_id=contract_id, + approver_id=approver_id, + status=ContractStatus.APPROVED, + comment=comment, + ) + self.db.add(approval) + + # Update contract status + await self.update_contract(contract_id, status=ContractStatus.APPROVED) + + await self.db.flush() + await self.db.refresh(approval) + return approval + + async def reject_contract( + self, + contract_id: int, + approver_id: int, + comment: Optional[str] = None, + ) -> ContractApproval: + """Reject a contract.""" + approval = ContractApproval( + contract_id=contract_id, + approver_id=approver_id, + status=ContractStatus.REJECTED, + comment=comment, + ) + self.db.add(approval) + + # Update contract status + await self.update_contract(contract_id, status=ContractStatus.REJECTED) + + await self.db.flush() + await self.db.refresh(approval) + return approval + + +class ContractTemplateService: + """Service for contract template operations.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_template( + self, + name: str, + content: str, + created_by: int, + contract_type: Optional[str] = None, + ) -> ContractTemplate: + """Create a contract template.""" + template = ContractTemplate( + name=name, + content=content, + created_by=created_by, + contract_type=contract_type, + ) + self.db.add(template) + await self.db.flush() + await self.db.refresh(template) + return template + + async def get_templates_list( + self, + skip: int = 0, + limit: int = 20, + ) -> List[ContractTemplate]: + """Get list of templates.""" + result = await self.db.execute( + select(ContractTemplate) + .offset(skip) + .limit(limit) + .order_by(ContractTemplate.created_at.desc()) + ) + return list(result.scalars().all()) diff --git a/backend/app/services/law_service.py b/backend/app/services/law_service.py new file mode 100644 index 0000000..f29e96f --- /dev/null +++ b/backend/app/services/law_service.py @@ -0,0 +1,137 @@ +"""Law service for CRUD operations.""" +from datetime import date +from typing import List, Optional +from sqlalchemy import select, or_ +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.law import Law, LawArticle, LawType, LawStatus + + +class LawService: + """Service for law CRUD operations.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_law( + self, + title: str, + law_type: LawType, + promulgation_date: date, + effective_date: date, + issuing_authority: str, + content: str, + status: LawStatus = LawStatus.EFFECTIVE, + document_number: Optional[str] = None, + ) -> Law: + """Create a new law.""" + law = Law( + title=title, + law_type=law_type, + promulgation_date=promulgation_date, + effective_date=effective_date, + status=status, + issuing_authority=issuing_authority, + content=content, + document_number=document_number, + ) + self.db.add(law) + await self.db.flush() + await self.db.refresh(law) + return law + + async def get_law_by_id(self, law_id: int) -> Optional[Law]: + """Get a law by ID.""" + result = await self.db.execute( + select(Law).where(Law.id == law_id) + ) + return result.scalar_one_or_none() + + async def get_laws_list( + self, + skip: int = 0, + limit: int = 20, + law_type: Optional[LawType] = None, + status: Optional[LawStatus] = None, + ) -> List[Law]: + """Get list of laws with optional filters.""" + query = select(Law) + + if law_type: + query = query.where(Law.law_type == law_type) + if status: + query = query.where(Law.status == status) + + query = query.offset(skip).limit(limit).order_by(Law.created_at.desc()) + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def update_law( + self, + law_id: int, + **kwargs + ) -> Optional[Law]: + """Update a law.""" + law = await self.get_law_by_id(law_id) + if not law: + return None + + for key, value in kwargs.items(): + if hasattr(law, key) and value is not None: + setattr(law, key, value) + + await self.db.flush() + await self.db.refresh(law) + return law + + async def delete_law(self, law_id: int) -> bool: + """Delete a law.""" + law = await self.get_law_by_id(law_id) + if not law: + return False + + await self.db.delete(law) + await self.db.flush() + return True + + async def search_laws_by_keyword( + self, + keyword: str, + limit: int = 10 + ) -> List[Law]: + """Search laws by keyword in title or content.""" + query = select(Law).where( + or_( + Law.title.contains(keyword), + Law.content.contains(keyword), + ) + ).limit(limit) + + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def create_article( + self, + law_id: int, + article_number: str, + content: str, + ) -> LawArticle: + """Create a law article.""" + article = LawArticle( + law_id=law_id, + article_number=article_number, + content=content, + ) + self.db.add(article) + await self.db.flush() + await self.db.refresh(article) + return article + + async def get_articles_by_law_id(self, law_id: int) -> List[LawArticle]: + """Get articles for a law.""" + result = await self.db.execute( + select(LawArticle) + .where(LawArticle.law_id == law_id) + .order_by(LawArticle.id) + ) + return list(result.scalars().all()) diff --git a/backend/app/services/llm_service.py b/backend/app/services/llm_service.py new file mode 100644 index 0000000..44cbb2d --- /dev/null +++ b/backend/app/services/llm_service.py @@ -0,0 +1,123 @@ +"""LLM service for AI-powered features.""" +from typing import List, Dict, Any, Optional +import httpx + +from app.core.config import settings + + +class LLMService: + """Service for LLM API interactions.""" + + def __init__(self): + self.api_base = settings.LLM_API_BASE + self.api_key = settings.LLM_API_KEY + self.model = settings.LLM_MODEL + + async def chat_completion( + self, + messages: List[Dict[str, str]], + temperature: float = 0.7, + max_tokens: int = 2000, + ) -> str: + """Get chat completion from LLM.""" + if not self.api_key: + # Return mock response for testing + return "这是一个模拟的法律回复。" + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.api_base}/chat/completions", + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + json={ + "model": self.model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + }, + timeout=60.0, + ) + response.raise_for_status() + data = response.json() + return data["choices"][0]["message"]["content"] + + async def legal_qa( + self, + question: str, + context: Optional[str] = None, + ) -> str: + """Answer a legal question.""" + system_prompt = """你是一个专业的法律助手。请根据以下原则回答问题: +1. 准确引用相关法律条文 +2. 解释法律条文的含义和适用条件 +3. 提供实用的法律建议 +4. 如不确定,请明确说明""" + + messages = [ + {"role": "system", "content": system_prompt}, + ] + + if context: + messages.append({ + "role": "user", + "content": f"参考以下法律内容:\n{context}\n\n问题:{question}" + }) + else: + messages.append({"role": "user", "content": question}) + + return await self.chat_completion(messages) + + async def analyze_legal_issue( + self, + issue_description: str, + relevant_laws: Optional[List[str]] = None, + ) -> str: + """Analyze a legal issue and provide insights.""" + system_prompt = """你是一个专业的法律分析师。请根据以下结构分析法律问题: +1. 问题定性 +2. 相关法律依据 +3. 法律分析 +4. 风险提示 +5. 建议""" + + user_content = f"请分析以下法律问题:\n{issue_description}" + + if relevant_laws: + user_content += f"\n\n相关法律:\n" + "\n".join(relevant_laws) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + return await self.chat_completion(messages, temperature=0.5) + + async def review_contract( + self, + contract_content: str, + contract_type: Optional[str] = None, + ) -> str: + """Review a contract and identify potential issues.""" + system_prompt = """你是一个专业的合同审查专家。请审查合同并: +1. 识别潜在风险条款 +2. 指出不明确的条款 +3. 提出修改建议 +4. 检查法律合规性""" + + user_content = f"请审查以下合同内容:\n{contract_content}" + + if contract_type: + user_content = f"合同类型:{contract_type}\n\n{user_content}" + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + return await self.chat_completion(messages, max_tokens=3000) + + +# Singleton instance +llm_service = LLMService() diff --git a/backend/app/services/signature_service.py b/backend/app/services/signature_service.py new file mode 100644 index 0000000..c4735b3 --- /dev/null +++ b/backend/app/services/signature_service.py @@ -0,0 +1,174 @@ +"""Signature service for electronic signatures.""" +import hashlib +import secrets +from datetime import datetime, timedelta +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings +from app.models.signature import ( + SignatureRequest, + Signature, + SignatureAudit, + SignatureStatus, +) +from app.models.contract import Contract + + +class SignatureService: + """Service for electronic signature operations.""" + + def __init__(self, db: AsyncSession): + self.db = db + + def generate_token(self) -> str: + """Generate a secure random token.""" + return secrets.token_urlsafe(32) + + def generate_verification_hash(self, content: str) -> str: + """Generate a verification hash for document content.""" + return hashlib.sha256(content.encode()).hexdigest() + + async def create_signature_request( + self, + contract_id: int, + requester_id: int, + signer_name: str, + signer_email: str, + expire_hours: int = None, + ) -> SignatureRequest: + """Create a signature request.""" + expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS + + request = SignatureRequest( + contract_id=contract_id, + requester_id=requester_id, + signer_name=signer_name, + signer_email=signer_email, + token=self.generate_token(), + expires_at=datetime.utcnow() + timedelta(hours=expire_hours), + ) + self.db.add(request) + + # Create audit log + await self.create_audit_log( + action="request_created", + request_id=request.id, + details={"signer_email": signer_email} + ) + + await self.db.flush() + await self.db.refresh(request) + return request + + async def get_signature_request_by_token( + self, + token: str + ) -> Optional[SignatureRequest]: + """Get a signature request by token.""" + result = await self.db.execute( + select(SignatureRequest).where(SignatureRequest.token == token) + ) + return result.scalar_one_or_none() + + async def get_signature_request_by_id( + self, + request_id: int + ) -> Optional[SignatureRequest]: + """Get a signature request by ID.""" + result = await self.db.execute( + select(SignatureRequest).where(SignatureRequest.id == request_id) + ) + return result.scalar_one_or_none() + + async def is_request_valid(self, request: SignatureRequest) -> bool: + """Check if a signature request is valid.""" + if request.status != SignatureStatus.PENDING: + return False + if request.expires_at < datetime.utcnow(): + return False + return True + + async def sign_document( + self, + request: SignatureRequest, + signature_data: str, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, + ) -> Signature: + """Sign a document.""" + # Get contract content for hash + result = await self.db.execute( + select(Contract).where(Contract.id == request.contract_id) + ) + contract = result.scalar_one_or_none() + + if not contract: + raise ValueError("Contract not found") + + verification_hash = self.generate_verification_hash(contract.content) + + signature = Signature( + request_id=request.id, + signer_name=request.signer_name, + signature_data=signature_data, + verification_hash=verification_hash, + ip_address=ip_address, + user_agent=user_agent, + ) + self.db.add(signature) + + # Update request status + request.status = SignatureStatus.SIGNED + + # Create audit log + await self.create_audit_log( + action="document_signed", + signature_id=signature.id, + request_id=request.id, + ip_address=ip_address + ) + + await self.db.flush() + await self.db.refresh(signature) + return signature + + async def verify_signature( + self, + signature_id: int, + content: str + ) -> bool: + """Verify a signature against document content.""" + result = await self.db.execute( + select(Signature).where(Signature.id == signature_id) + ) + signature = result.scalar_one_or_none() + + if not signature: + return False + + current_hash = self.generate_verification_hash(content) + return current_hash == signature.verification_hash + + async def create_audit_log( + self, + action: str, + signature_id: Optional[int] = None, + request_id: Optional[int] = None, + details: Optional[dict] = None, + ip_address: Optional[str] = None, + ) -> SignatureAudit: + """Create an audit log entry.""" + audit = SignatureAudit( + action=action, + signature_id=signature_id, + request_id=request_id, + details=details, + ip_address=ip_address, + ) + self.db.add(audit) + await self.db.flush() + await self.db.refresh(audit) + return audit diff --git a/backend/app/services/vector_service.py b/backend/app/services/vector_service.py new file mode 100644 index 0000000..0070572 --- /dev/null +++ b/backend/app/services/vector_service.py @@ -0,0 +1,107 @@ +"""Vector service for embedding and similarity search.""" +import numpy as np +from typing import List, Dict, Any, Optional +import httpx + +from app.core.config import settings + + +class VectorService: + """Service for vector embeddings and similarity search.""" + + def __init__(self): + self.api_base = settings.EMBEDDING_API_BASE + self.api_key = settings.EMBEDDING_API_KEY or settings.LLM_API_KEY + self.model = settings.EMBEDDING_MODEL + self.dimension = settings.EMBEDDING_DIMENSION + + async def get_embedding(self, text: str) -> List[float]: + """Get embedding for a text using external API.""" + if not self.api_key: + # Return mock embedding for testing + return [0.0] * self.dimension + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.api_base}/embeddings", + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + json={ + "model": self.model, + "input": text, + }, + timeout=30.0, + ) + response.raise_for_status() + data = response.json() + return data["data"][0]["embedding"] + + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get embeddings for multiple texts.""" + if not self.api_key: + return [[0.0] * self.dimension for _ in texts] + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.api_base}/embeddings", + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + json={ + "model": self.model, + "input": texts, + }, + timeout=60.0, + ) + response.raise_for_status() + data = response.json() + return [item["embedding"] for item in data["data"]] + + def cosine_similarity( + self, + vec1: List[float], + vec2: List[float] + ) -> float: + """Calculate cosine similarity between two vectors.""" + arr1 = np.array(vec1) + arr2 = np.array(vec2) + + dot_product = np.dot(arr1, arr2) + norm1 = np.linalg.norm(arr1) + norm2 = np.linalg.norm(arr2) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + return float(dot_product / (norm1 * norm2)) + + def search_similar( + self, + query_embedding: List[float], + stored_vectors: List[Dict[str, Any]], + top_k: int = 5 + ) -> List[Dict[str, Any]]: + """Search for similar vectors.""" + results = [] + + for item in stored_vectors: + similarity = self.cosine_similarity( + query_embedding, + item["embedding"] + ) + results.append({ + "id": item["id"], + "similarity": similarity, + }) + + # Sort by similarity descending + results.sort(key=lambda x: x["similarity"], reverse=True) + + return results[:top_k] + + +# Singleton instance +vector_service = VectorService() diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 0000000..0afdb1b --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,16 @@ +[tool:pytest] +asyncio_mode = auto +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +[coverage:run] +source = app +omit = app/__init__.py + +[coverage:report] +exclude_lines = + pragma: no cover + def __repr__ + raise NotImplementedError diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000..0fc2600 --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,31 @@ +# FastAPI and dependencies +fastapi==0.110.0 +uvicorn[standard]==0.27.1 +python-multipart==0.0.9 + +# Database +sqlalchemy==2.0.25 +aiosqlite==0.19.0 + +# Vector search +sqlite-vss==0.1.2 + +# Pydantic +pydantic==2.6.0 +pydantic-settings==2.1.0 + +# Authentication +python-jose[cryptography]==3.3.0 +bcrypt==4.0.1 + +# HTTP client for external APIs +httpx==0.26.0 + +# Testing +pytest==7.4.4 +pytest-asyncio==0.23.4 +pytest-cov==4.1.0 + +# Utilities +python-dotenv==1.0.1 +numpy==1.26.4 diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/backend/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..7f6878e --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,36 @@ +"""Pytest configuration and fixtures.""" +import asyncio +import pytest +import pytest_asyncio +from typing import AsyncGenerator +from sqlalchemy.ext.asyncio import ( + AsyncSession, + create_async_engine, + async_sessionmaker, +) + +from app.core.database import Base +from app.models.user import User +from app.models.law import Law, LawArticle + + +@pytest_asyncio.fixture +async def db_session(tmp_path) -> AsyncGenerator[AsyncSession, None]: + """Create a test database session.""" + db_file = tmp_path / "test.db" + engine = create_async_engine( + f"sqlite+aiosqlite:///{db_file}", + echo=False, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async_session = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + + async with async_session() as session: + yield session + + await engine.dispose() diff --git a/backend/tests/unit/__init__.py b/backend/tests/unit/__init__.py new file mode 100644 index 0000000..4a5d263 --- /dev/null +++ b/backend/tests/unit/__init__.py @@ -0,0 +1 @@ +# Unit tests package diff --git a/backend/tests/unit/test_config.py b/backend/tests/unit/test_config.py new file mode 100644 index 0000000..c1b514b --- /dev/null +++ b/backend/tests/unit/test_config.py @@ -0,0 +1,60 @@ +"""Unit tests for core configuration.""" +import pytest +from app.core.config import Settings + + +class TestSettings: + """Test cases for Settings configuration.""" + + def test_default_settings(self): + """Test default configuration values.""" + settings = Settings() + + assert settings.APP_NAME == "AI Legal Assistant" + assert settings.APP_VERSION == "1.0.0" + assert settings.DEBUG is False + assert settings.DATABASE_URL == "sqlite+aiosqlite:///./data/legal_assistant.db" + + def test_custom_settings_from_env(self, monkeypatch): + """Test configuration from environment variables.""" + monkeypatch.setenv("DEBUG", "true") + monkeypatch.setenv("DATABASE_URL", "sqlite+aiosqlite:///./test.db") + monkeypatch.setenv("LLM_API_KEY", "test-key") + + settings = Settings() + + assert settings.DEBUG is True + assert settings.DATABASE_URL == "sqlite+aiosqlite:///./test.db" + assert settings.LLM_API_KEY == "test-key" + + def test_llm_settings(self, monkeypatch): + """Test LLM configuration.""" + monkeypatch.setenv("LLM_API_BASE", "https://api.example.com/v1") + monkeypatch.setenv("LLM_MODEL", "gpt-4") + + settings = Settings() + + assert settings.LLM_API_BASE == "https://api.example.com/v1" + assert settings.LLM_MODEL == "gpt-4" + + def test_embedding_settings(self, monkeypatch): + """Test embedding configuration.""" + monkeypatch.setenv("EMBEDDING_MODEL", "text-embedding-3-small") + monkeypatch.setenv("EMBEDDING_DIMENSION", "1536") + + settings = Settings() + + assert settings.EMBEDDING_MODEL == "text-embedding-3-small" + assert settings.EMBEDDING_DIMENSION == 1536 + + def test_jwt_settings(self, monkeypatch): + """Test JWT configuration.""" + monkeypatch.setenv("JWT_SECRET_KEY", "test-secret-key") + monkeypatch.setenv("JWT_ALGORITHM", "HS256") + monkeypatch.setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60") + + settings = Settings() + + assert settings.JWT_SECRET_KEY == "test-secret-key" + assert settings.JWT_ALGORITHM == "HS256" + assert settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES == 60 diff --git a/backend/tests/unit/test_contract_service.py b/backend/tests/unit/test_contract_service.py new file mode 100644 index 0000000..afaae2b --- /dev/null +++ b/backend/tests/unit/test_contract_service.py @@ -0,0 +1,100 @@ +"""Unit tests for Contract service.""" +import pytest +from datetime import date + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.contract import ContractStatus +from app.services.contract_service import ContractService, ContractTemplateService + + +class TestContractService: + """Test cases for Contract service.""" + + @pytest.mark.asyncio + async def test_create_contract(self, db_session: AsyncSession): + """Test creating a contract.""" + service = ContractService(db_session) + + contract = await service.create_contract( + title="测试合同", + party_a="甲方", + party_b="乙方", + content="合同内容", + created_by=1, + ) + + assert contract.id is not None + assert contract.title == "测试合同" + assert contract.status == ContractStatus.DRAFT + + @pytest.mark.asyncio + async def test_get_contract_by_id(self, db_session: AsyncSession): + """Test getting a contract by ID.""" + service = ContractService(db_session) + + created = await service.create_contract( + title="测试合同", + party_a="甲方", + party_b="乙方", + content="合同内容", + created_by=1, + ) + + contract = await service.get_contract_by_id(created.id) + + assert contract is not None + assert contract.title == "测试合同" + + @pytest.mark.asyncio + async def test_submit_for_approval(self, db_session: AsyncSession): + """Test submitting a contract for approval.""" + service = ContractService(db_session) + + contract = await service.create_contract( + title="测试合同", + party_a="甲方", + party_b="乙方", + content="合同内容", + created_by=1, + ) + + updated = await service.submit_for_approval(contract.id) + + assert updated.status == ContractStatus.PENDING_APPROVAL + + @pytest.mark.asyncio + async def test_approve_contract(self, db_session: AsyncSession): + """Test approving a contract.""" + service = ContractService(db_session) + + contract = await service.create_contract( + title="测试合同", + party_a="甲方", + party_b="乙方", + content="合同内容", + created_by=1, + ) + + approval = await service.approve_contract(contract.id, 1, "同意") + + assert approval.status == ContractStatus.APPROVED + + +class TestContractTemplateService: + """Test cases for ContractTemplate service.""" + + @pytest.mark.asyncio + async def test_create_template(self, db_session: AsyncSession): + """Test creating a contract template.""" + service = ContractTemplateService(db_session) + + template = await service.create_template( + name="测试模板", + content="模板内容", + created_by=1, + contract_type="销售合同", + ) + + assert template.id is not None + assert template.name == "测试模板" diff --git a/backend/tests/unit/test_database.py b/backend/tests/unit/test_database.py new file mode 100644 index 0000000..e9e8b05 --- /dev/null +++ b/backend/tests/unit/test_database.py @@ -0,0 +1,51 @@ +"""Unit tests for database configuration.""" +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy import text + +from app.core.database import Base, get_db, init_db + + +class TestDatabase: + """Test cases for database configuration.""" + + @pytest.mark.asyncio + async def test_create_tables(self, tmp_path): + """Test creating database tables.""" + db_path = tmp_path / "test.db" + engine = create_async_engine( + f"sqlite+aiosqlite:///{db_path}", + echo=False + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Verify tables exist + async with engine.begin() as conn: + result = await conn.execute(text( + "SELECT name FROM sqlite_master WHERE type='table'" + )) + tables = [row[0] for row in result.fetchall()] + # At minimum, alembic_version table should exist after migration + # For now, just verify the connection works + assert len(tables) >= 0 + + await engine.dispose() + + @pytest.mark.asyncio + async def test_get_db_session(self, tmp_path): + """Test getting database session.""" + db_path = tmp_path / "test.db" + engine = create_async_engine( + f"sqlite+aiosqlite:///{db_path}", + echo=False + ) + async_session = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + + async with async_session() as session: + assert isinstance(session, AsyncSession) + + await engine.dispose() diff --git a/backend/tests/unit/test_law_model.py b/backend/tests/unit/test_law_model.py new file mode 100644 index 0000000..6cd4c03 --- /dev/null +++ b/backend/tests/unit/test_law_model.py @@ -0,0 +1,70 @@ +"""Unit tests for Law model.""" +import pytest +from datetime import date + +from app.models.law import Law, LawArticle, LawType, LawStatus + + +class TestLawModel: + """Test cases for Law model.""" + + def test_law_creation(self): + """Test creating a law instance.""" + law = Law( + title="中华人民共和国民法典", + law_type=LawType.LAW, + promulgation_date=date(2020, 5, 28), + effective_date=date(2021, 1, 1), + status=LawStatus.EFFECTIVE, + issuing_authority="全国人民代表大会", + content="民法典全文...", + ) + + assert law.title == "中华人民共和国民法典" + assert law.law_type == LawType.LAW + assert law.promulgation_date == date(2020, 5, 28) + assert law.effective_date == date(2021, 1, 1) + assert law.status == LawStatus.EFFECTIVE + assert law.issuing_authority == "全国人民代表大会" + + def test_law_type_enum(self): + """Test law type enum values.""" + assert LawType.LAW.value == "law" + assert LawType.REGULATION.value == "regulation" + assert LawType.RULE.value == "rule" + assert LawType.JUDICIAL_INTERPRETATION.value == "judicial_interpretation" + + def test_law_status_enum(self): + """Test law status enum values.""" + assert LawStatus.EFFECTIVE.value == "effective" + assert LawStatus.REVOKED.value == "revoked" + assert LawStatus.AMENDED.value == "amended" + + def test_law_default_values(self): + """Test law default values.""" + law = Law( + title="测试法规", + law_type=LawType.REGULATION, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="测试机关", + content="测试内容", + ) + + assert law.status == LawStatus.EFFECTIVE + + +class TestLawArticleModel: + """Test cases for LawArticle model.""" + + def test_article_creation(self): + """Test creating a law article instance.""" + article = LawArticle( + law_id=1, + article_number="第一条", + content="为了保护民事主体的合法权益...", + ) + + assert article.law_id == 1 + assert article.article_number == "第一条" + assert "民事主体" in article.content diff --git a/backend/tests/unit/test_law_service.py b/backend/tests/unit/test_law_service.py new file mode 100644 index 0000000..e5f10cd --- /dev/null +++ b/backend/tests/unit/test_law_service.py @@ -0,0 +1,144 @@ +"""Unit tests for Law service.""" +import pytest +from datetime import date +from unittest.mock import AsyncMock, patch, MagicMock + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.law import Law, LawArticle, LawType, LawStatus +from app.services.law_service import LawService + + +class TestLawService: + """Test cases for Law service.""" + + @pytest.mark.asyncio + async def test_create_law(self, db_session: AsyncSession): + """Test creating a law.""" + service = LawService(db_session) + + law = await service.create_law( + title="测试法律", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="测试机关", + content="测试内容", + ) + + assert law.id is not None + assert law.title == "测试法律" + assert law.law_type == LawType.LAW + + @pytest.mark.asyncio + async def test_get_law_by_id(self, db_session: AsyncSession): + """Test getting a law by ID.""" + service = LawService(db_session) + + # Create a law first + created = await service.create_law( + title="测试法律", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="测试机关", + content="测试内容", + ) + + # Get the law + law = await service.get_law_by_id(created.id) + + assert law is not None + assert law.title == "测试法律" + + @pytest.mark.asyncio + async def test_get_laws_list(self, db_session: AsyncSession): + """Test getting list of laws.""" + service = LawService(db_session) + + # Create multiple laws + for i in range(3): + await service.create_law( + title=f"测试法律{i}", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="测试机关", + content=f"测试内容{i}", + ) + + laws = await service.get_laws_list(skip=0, limit=10) + + assert len(laws) == 3 + + @pytest.mark.asyncio + async def test_update_law(self, db_session: AsyncSession): + """Test updating a law.""" + service = LawService(db_session) + + # Create a law + law = await service.create_law( + title="原标题", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="测试机关", + content="原内容", + ) + + # Update the law + updated = await service.update_law(law.id, title="新标题") + + assert updated.title == "新标题" + + @pytest.mark.asyncio + async def test_delete_law(self, db_session: AsyncSession): + """Test deleting a law.""" + service = LawService(db_session) + + # Create a law + law = await service.create_law( + title="待删除法律", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="测试机关", + content="测试内容", + ) + + # Delete the law + result = await service.delete_law(law.id) + + assert result is True + + # Verify it's deleted + deleted = await service.get_law_by_id(law.id) + assert deleted is None + + @pytest.mark.asyncio + async def test_search_laws_by_keyword(self, db_session: AsyncSession): + """Test searching laws by keyword.""" + service = LawService(db_session) + + # Create laws with different content + await service.create_law( + title="民法典", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="全国人大", + content="民法典全文内容", + ) + await service.create_law( + title="刑法", + law_type=LawType.LAW, + promulgation_date=date(2024, 1, 1), + effective_date=date(2024, 2, 1), + issuing_authority="全国人大", + content="刑法全文内容", + ) + + results = await service.search_laws_by_keyword("民法") + + assert len(results) >= 1 + assert any("民法典" in law.title for law in results) diff --git a/backend/tests/unit/test_llm_service.py b/backend/tests/unit/test_llm_service.py new file mode 100644 index 0000000..3a34cfc --- /dev/null +++ b/backend/tests/unit/test_llm_service.py @@ -0,0 +1,82 @@ +"""Unit tests for LLM service.""" +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from app.services.llm_service import LLMService + + +class TestLLMService: + """Test cases for LLM service.""" + + @pytest.mark.asyncio + async def test_chat_completion_no_api_key(self): + """Test chat completion without API key returns mock.""" + service = LLMService() + service.api_key = None # Ensure no API key + + response = await service.chat_completion( + messages=[{"role": "user", "content": "你好"}] + ) + + # Should return mock response + assert "模拟" in response or "法律" in response + + @pytest.mark.asyncio + async def test_chat_completion_with_api_key(self): + """Test chat completion with API key.""" + service = LLMService() + service.api_key = "test-key" + + with patch('httpx.AsyncClient') as mock_client: + mock_response = MagicMock() + mock_response.json.return_value = { + "choices": [{ + "message": { + "role": "assistant", + "content": "这是测试回复" + } + }] + } + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client.return_value = mock_client_instance + + response = await service.chat_completion( + messages=[{"role": "user", "content": "你好"}] + ) + + assert "这是测试回复" in response + + @pytest.mark.asyncio + async def test_legal_qa(self): + """Test legal QA.""" + service = LLMService() + + with patch.object(service, 'chat_completion') as mock_chat: + mock_chat.return_value = "根据《民法典》第一条规定..." + + response = await service.legal_qa( + question="民法典的立法目的是什么?", + context="《民法典》第一条 为了保护民事主体的合法权益..." + ) + + assert "民法典" in response + + @pytest.mark.asyncio + async def test_analyze_legal_issue(self): + """Test legal issue analysis.""" + service = LLMService() + + with patch.object(service, 'chat_completion') as mock_chat: + mock_chat.return_value = "分析结果:该问题涉及..." + + response = await service.analyze_legal_issue( + issue_description="甲方未按合同约定支付货款", + relevant_laws=["《民法典》第五百零九条"] + ) + + assert "分析结果" in response diff --git a/backend/tests/unit/test_remaining_models.py b/backend/tests/unit/test_remaining_models.py new file mode 100644 index 0000000..2a1bf32 --- /dev/null +++ b/backend/tests/unit/test_remaining_models.py @@ -0,0 +1,115 @@ +"""Unit tests for remaining models and services.""" +import pytest +from datetime import date, datetime + +from app.models.analysis import LegalAnalysis, Case, AnalysisStatus +from app.models.case import CaseReview, ReviewStatus, ReviewType +from app.models.contract import Contract, ContractTemplate, ContractStatus +from app.models.signature import SignatureRequest, Signature, SignatureStatus + + +class TestAnalysisModel: + """Test cases for Analysis model.""" + + def test_legal_analysis_creation(self): + """Test creating a legal analysis instance.""" + analysis = LegalAnalysis( + user_id=1, + title="合同纠纷分析", + case_description="甲方未按合同约定支付货款...", + ) + + assert analysis.user_id == 1 + assert analysis.title == "合同纠纷分析" + assert analysis.status == AnalysisStatus.DRAFT + + def test_case_creation(self): + """Test creating a case instance.""" + case = Case( + title="买卖合同纠纷案", + case_number="(2024)京01民初1号", + court="北京市第一中级人民法院", + case_type="买卖合同纠纷", + ) + + assert case.title == "买卖合同纠纷案" + assert case.case_number == "(2024)京01民初1号" + + +class TestCaseReviewModel: + """Test cases for CaseReview model.""" + + def test_case_review_creation(self): + """Test creating a case review instance.""" + review = CaseReview( + case_id=1, + reviewer_id=1, + review_type=ReviewType.INITIAL, + opinion="同意立案", + score=90, + ) + + assert review.case_id == 1 + assert review.review_type == ReviewType.INITIAL + assert review.status == ReviewStatus.PENDING + + +class TestContractModel: + """Test cases for Contract model.""" + + def test_contract_creation(self): + """Test creating a contract instance.""" + contract = Contract( + title="销售合同", + party_a="甲方公司", + party_b="乙方公司", + content="合同内容...", + created_by=1, + ) + + assert contract.title == "销售合同" + assert contract.party_a == "甲方公司" + assert contract.status == ContractStatus.DRAFT + + def test_contract_template_creation(self): + """Test creating a contract template instance.""" + template = ContractTemplate( + name="标准销售合同模板", + content="模板内容...", + created_by=1, + contract_type="销售合同", + ) + + assert template.name == "标准销售合同模板" + assert template.contract_type == "销售合同" + + +class TestSignatureModel: + """Test cases for Signature model.""" + + def test_signature_request_creation(self): + """Test creating a signature request instance.""" + request = SignatureRequest( + contract_id=1, + requester_id=1, + signer_name="张三", + signer_email="zhangsan@example.com", + token="abc123token", + expires_at=datetime.utcnow(), + ) + + assert request.contract_id == 1 + assert request.signer_name == "张三" + assert request.status == SignatureStatus.PENDING + + def test_signature_creation(self): + """Test creating a signature instance.""" + signature = Signature( + request_id=1, + signer_name="张三", + signature_data="base64encodedimage", + verification_hash="hash123", + ) + + assert signature.request_id == 1 + assert signature.signer_name == "张三" diff --git a/backend/tests/unit/test_security.py b/backend/tests/unit/test_security.py new file mode 100644 index 0000000..442037a --- /dev/null +++ b/backend/tests/unit/test_security.py @@ -0,0 +1,84 @@ +"""Unit tests for security module.""" +import pytest +from datetime import datetime, timedelta + +from app.core.security import ( + get_password_hash, + verify_password, + create_access_token, + decode_access_token, +) + + +class TestPasswordHashing: + """Test cases for password hashing.""" + + def test_password_hash(self): + """Test password hashing.""" + password = "test_password_123" + hashed = get_password_hash(password) + + assert hashed != password + assert len(hashed) > 0 + assert hashed.startswith("$2b$") + + def test_verify_password_correct(self): + """Test verifying correct password.""" + password = "test_password_123" + hashed = get_password_hash(password) + + assert verify_password(password, hashed) is True + + def test_verify_password_incorrect(self): + """Test verifying incorrect password.""" + password = "test_password_123" + hashed = get_password_hash(password) + + assert verify_password("wrong_password", hashed) is False + + def test_different_passwords_different_hashes(self): + """Test that same password produces different hashes.""" + password = "test_password_123" + hash1 = get_password_hash(password) + hash2 = get_password_hash(password) + + assert hash1 != hash2 + + +class TestJWTTokens: + """Test cases for JWT token handling.""" + + def test_create_access_token(self): + """Test creating access token.""" + data = {"sub": "user123", "role": "lawyer"} + token = create_access_token(data) + + assert token is not None + assert isinstance(token, str) + assert len(token) > 0 + + def test_decode_access_token(self): + """Test decoding access token.""" + data = {"sub": "user123", "role": "lawyer"} + token = create_access_token(data) + + decoded = decode_access_token(token) + + assert decoded is not None + assert decoded["sub"] == "user123" + assert decoded["role"] == "lawyer" + + def test_decode_invalid_token(self): + """Test decoding invalid token.""" + decoded = decode_access_token("invalid.token.here") + + assert decoded is None + + def test_token_contains_exp(self): + """Test that token contains expiration.""" + data = {"sub": "user123"} + token = create_access_token(data) + + decoded = decode_access_token(token) + + assert "exp" in decoded diff --git a/backend/tests/unit/test_signature_service.py b/backend/tests/unit/test_signature_service.py new file mode 100644 index 0000000..c891c93 --- /dev/null +++ b/backend/tests/unit/test_signature_service.py @@ -0,0 +1,88 @@ +"""Unit tests for Signature service.""" +import pytest +from datetime import datetime, timedelta + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.signature import SignatureStatus +from app.models.contract import Contract +from app.services.signature_service import SignatureService +from app.services.contract_service import ContractService + + +class TestSignatureService: + """Test cases for Signature service.""" + + @pytest.mark.asyncio + async def test_generate_token(self, db_session: AsyncSession): + """Test token generation.""" + service = SignatureService(db_session) + + token = service.generate_token() + + assert token is not None + assert len(token) > 20 + + @pytest.mark.asyncio + async def test_generate_verification_hash(self, db_session: AsyncSession): + """Test verification hash generation.""" + service = SignatureService(db_session) + + hash1 = service.generate_verification_hash("test content") + hash2 = service.generate_verification_hash("test content") + hash3 = service.generate_verification_hash("different content") + + assert hash1 == hash2 + assert hash1 != hash3 + assert len(hash1) == 64 # SHA-256 hex length + + @pytest.mark.asyncio + async def test_create_signature_request(self, db_session: AsyncSession): + """Test creating a signature request.""" + # First create a contract + contract_service = ContractService(db_session) + contract = await contract_service.create_contract( + title="测试合同", + party_a="甲方", + party_b="乙方", + content="合同内容", + created_by=1, + ) + + service = SignatureService(db_session) + request = await service.create_signature_request( + contract_id=contract.id, + requester_id=1, + signer_name="张三", + signer_email="zhangsan@example.com", + ) + + assert request.id is not None + assert request.signer_name == "张三" + assert request.status == SignatureStatus.PENDING + assert request.token is not None + + @pytest.mark.asyncio + async def test_is_request_valid(self, db_session: AsyncSession): + """Test checking if request is valid.""" + # Create a contract first + contract_service = ContractService(db_session) + contract = await contract_service.create_contract( + title="测试合同", + party_a="甲方", + party_b="乙方", + content="合同内容", + created_by=1, + ) + + service = SignatureService(db_session) + request = await service.create_signature_request( + contract_id=contract.id, + requester_id=1, + signer_name="张三", + signer_email="zhangsan@example.com", + expire_hours=72, + ) + + is_valid = await service.is_request_valid(request) + assert is_valid is True diff --git a/backend/tests/unit/test_user_model.py b/backend/tests/unit/test_user_model.py new file mode 100644 index 0000000..53e5ee3 --- /dev/null +++ b/backend/tests/unit/test_user_model.py @@ -0,0 +1,52 @@ +"""Unit tests for User model.""" +import pytest +from datetime import datetime + +from app.models.user import User, UserRole + + +class TestUserModel: + """Test cases for User model.""" + + def test_user_creation(self): + """Test creating a user instance.""" + user = User( + username="testuser", + email="test@example.com", + hashed_password="hashed_password", + role=UserRole.LAWYER, + ) + + assert user.username == "testuser" + assert user.email == "test@example.com" + assert user.hashed_password == "hashed_password" + assert user.role == UserRole.LAWYER + assert user.is_active is True + + def test_user_role_enum(self): + """Test user role enum values.""" + assert UserRole.ADMIN.value == "admin" + assert UserRole.LAWYER.value == "lawyer" + assert UserRole.REVIEWER.value == "reviewer" + assert UserRole.CLIENT.value == "client" + + def test_user_default_values(self): + """Test user default values.""" + user = User( + username="testuser", + email="test@example.com", + hashed_password="hashed_password", + ) + + assert user.role == UserRole.CLIENT + assert user.is_active is True + + def test_user_repr(self): + """Test user string representation.""" + user = User( + username="testuser", + email="test@example.com", + hashed_password="hashed_password", + ) + + assert "testuser" in repr(user) diff --git a/backend/tests/unit/test_vector_service.py b/backend/tests/unit/test_vector_service.py new file mode 100644 index 0000000..35b043e --- /dev/null +++ b/backend/tests/unit/test_vector_service.py @@ -0,0 +1,81 @@ +"""Unit tests for Vector service.""" +import pytest +import numpy as np +from unittest.mock import AsyncMock, patch, MagicMock + +from app.services.vector_service import VectorService + + +class TestVectorService: + """Test cases for Vector service.""" + + @pytest.mark.asyncio + async def test_get_embedding_no_api_key(self): + """Test getting embedding without API key returns mock.""" + service = VectorService() + service.api_key = None # Ensure no API key + + embedding = await service.get_embedding("测试文本") + + # Should return mock embedding + assert len(embedding) == service.dimension + assert all(x == 0.0 for x in embedding) + + @pytest.mark.asyncio + async def test_get_embedding_with_api_key(self): + """Test getting embedding with API key.""" + service = VectorService() + service.api_key = "test-key" + + with patch('httpx.AsyncClient') as mock_client: + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [{"embedding": [0.1] * 1536}] + } + mock_response.raise_for_status = MagicMock() + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client.return_value = mock_client_instance + + embedding = await service.get_embedding("测试文本") + + assert len(embedding) == 1536 + assert all(x == 0.1 for x in embedding) + + @pytest.mark.asyncio + async def test_cosine_similarity(self): + """Test cosine similarity calculation.""" + service = VectorService() + + vec1 = [1.0, 0.0, 0.0] + vec2 = [1.0, 0.0, 0.0] + vec3 = [0.0, 1.0, 0.0] + + # Same vectors should have similarity 1.0 + sim1 = service.cosine_similarity(vec1, vec2) + assert abs(sim1 - 1.0) < 0.001 + + # Orthogonal vectors should have similarity 0.0 + sim2 = service.cosine_similarity(vec1, vec3) + assert abs(sim2 - 0.0) < 0.001 + + @pytest.mark.asyncio + async def test_search_similar(self): + """Test searching similar vectors.""" + service = VectorService() + + # Mock vectors: first is similar, second is different + stored_vectors = [ + {"id": 1, "embedding": [1.0, 0.0, 0.0]}, + {"id": 2, "embedding": [0.0, 1.0, 0.0]}, + ] + + query_vec = [0.9, 0.1, 0.0] + + results = service.search_similar(query_vec, stored_vectors, top_k=2) + + assert len(results) == 2 + assert results[0]["id"] == 1 # Most similar should be first diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..df9f2be --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,25 @@ +version: '3.8' + +services: + web: + build: ./backend + ports: + - "8000:8000" + volumes: + - ./data:/app/data + - ./uploads:/app/uploads + environment: + - DATABASE_URL=sqlite+aiosqlite:///./data/legal_assistant.db + - LLM_API_KEY=${LLM_API_KEY} + - LLM_API_BASE=${LLM_API_BASE:-https://api.openai.com/v1} + - LLM_MODEL=${LLM_MODEL:-gpt-4o-mini} + - EMBEDDING_API_KEY=${EMBEDDING_API_KEY} + - EMBEDDING_API_BASE=${EMBEDDING_API_BASE:-https://api.openai.com/v1} + - EMBEDDING_MODEL=${EMBEDDING_MODEL:-text-embedding-3-small} + - JWT_SECRET_KEY=${JWT_SECRET_KEY:-change-me-in-production} + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3