fix: improve error handling and validation

- Add centralized exception classes
- Validate signature request status before signing
- Check contract status before approval/rejection
- Add exception handlers to FastAPI app
- Update tests for new validation logic

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
root 2026-05-01 03:38:20 +08:00
parent 656f596d7e
commit ae72b180e5
6 changed files with 194 additions and 9 deletions

View File

@ -0,0 +1,29 @@
"""Exception handlers for FastAPI."""
from fastapi import Request, status
from fastapi.responses import JSONResponse
from app.core.exceptions import AppException
async def app_exception_handler(request: Request, exc: AppException):
"""Handle application exceptions."""
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
"code": exc.code,
"message": exc.message,
"details": exc.details,
}
)
async def generic_exception_handler(request: Request, exc: Exception):
"""Handle generic exceptions."""
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"code": 50000,
"message": "Internal server error",
"details": {},
}
)

View File

@ -0,0 +1,87 @@
"""Core exceptions for the application."""
class AppException(Exception):
"""Base application exception."""
def __init__(self, code: int, message: str, details: dict = None):
self.code = code
self.message = message
self.details = details or {}
class NotFoundError(AppException):
"""Resource not found exception."""
def __init__(self, resource: str, resource_id: int = None):
super().__init__(
code=10002,
message=f"{resource} not found",
details={"resource": resource, "id": resource_id}
)
class ValidationError(AppException):
"""Validation error exception."""
def __init__(self, message: str, details: dict = None):
super().__init__(code=10001, message=message, details=details)
class AuthenticationError(AppException):
"""Authentication error exception."""
def __init__(self, message: str = "Authentication failed"):
super().__init__(code=20001, message=message)
class AuthorizationError(AppException):
"""Authorization error exception."""
def __init__(self, message: str = "Permission denied"):
super().__init__(code=20002, message=message)
class SignatureError(AppException):
"""Signature related error exception."""
def __init__(self, message: str, details: dict = None):
super().__init__(code=40000, message=message, details=details)
class SignatureExpiredError(SignatureError):
"""Signature request expired exception."""
def __init__(self):
super().__init__(code=40001, message="Signature request has expired")
class SignatureAlreadySignedError(SignatureError):
"""Signature already completed exception."""
def __init__(self):
super().__init__(code=40002, message="Signature request already completed")
class ContractError(AppException):
"""Contract related error exception."""
def __init__(self, message: str, details: dict = None):
super().__init__(code=50000, message=message, details=details)
class InvalidContractStatusError(ContractError):
"""Invalid contract status for operation."""
def __init__(self, current_status: str, required_status: str):
super().__init__(
message=f"Invalid contract status: {current_status}, expected: {required_status}",
details={"current_status": current_status, "required_status": required_status}
)
class LLMError(AppException):
"""LLM service error exception."""
def __init__(self, message: str = "LLM service error"):
super().__init__(code=30001, message=message)

View File

@ -5,6 +5,8 @@ from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings from app.core.config import settings
from app.core.database import init_db from app.core.database import init_db
from app.core.exceptions import AppException
from app.core.exception_handlers import app_exception_handler
from app.api.v1 import laws, analyses, contracts, signatures from app.api.v1 import laws, analyses, contracts, signatures
@ -24,6 +26,9 @@ app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
) )
# Register exception handlers
app.add_exception_handler(AppException, app_exception_handler)
# CORS middleware # CORS middleware
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import InvalidContractStatusError, NotFoundError
from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus
@ -42,6 +43,13 @@ class ContractService:
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def get_contract_or_raise(self, contract_id: int) -> Contract:
"""Get a contract by ID or raise exception."""
contract = await self.get_contract_by_id(contract_id)
if not contract:
raise NotFoundError("Contract", contract_id)
return contract
async def get_contracts_list( async def get_contracts_list(
self, self,
skip: int = 0, skip: int = 0,
@ -81,6 +89,17 @@ class ContractService:
async def submit_for_approval(self, contract_id: int) -> Optional[Contract]: async def submit_for_approval(self, contract_id: int) -> Optional[Contract]:
"""Submit a contract for approval.""" """Submit a contract for approval."""
contract = await self.get_contract_by_id(contract_id)
if not contract:
return None
# Only draft or rejected contracts can be submitted
if contract.status not in [ContractStatus.DRAFT, ContractStatus.REJECTED]:
raise InvalidContractStatusError(
current_status=contract.status.value,
required_status="draft or rejected"
)
return await self.update_contract( return await self.update_contract(
contract_id, contract_id,
status=ContractStatus.PENDING_APPROVAL status=ContractStatus.PENDING_APPROVAL
@ -93,6 +112,15 @@ class ContractService:
comment: Optional[str] = None, comment: Optional[str] = None,
) -> ContractApproval: ) -> ContractApproval:
"""Approve a contract.""" """Approve a contract."""
contract = await self.get_contract_or_raise(contract_id)
# Only pending approval contracts can be approved
if contract.status != ContractStatus.PENDING_APPROVAL:
raise InvalidContractStatusError(
current_status=contract.status.value,
required_status="pending_approval"
)
approval = ContractApproval( approval = ContractApproval(
contract_id=contract_id, contract_id=contract_id,
approver_id=approver_id, approver_id=approver_id,
@ -102,7 +130,7 @@ class ContractService:
self.db.add(approval) self.db.add(approval)
# Update contract status # Update contract status
await self.update_contract(contract_id, status=ContractStatus.APPROVED) contract.status = ContractStatus.APPROVED
await self.db.flush() await self.db.flush()
await self.db.refresh(approval) await self.db.refresh(approval)
@ -115,6 +143,15 @@ class ContractService:
comment: Optional[str] = None, comment: Optional[str] = None,
) -> ContractApproval: ) -> ContractApproval:
"""Reject a contract.""" """Reject a contract."""
contract = await self.get_contract_or_raise(contract_id)
# Only pending approval contracts can be rejected
if contract.status != ContractStatus.PENDING_APPROVAL:
raise InvalidContractStatusError(
current_status=contract.status.value,
required_status="pending_approval"
)
approval = ContractApproval( approval = ContractApproval(
contract_id=contract_id, contract_id=contract_id,
approver_id=approver_id, approver_id=approver_id,
@ -124,7 +161,7 @@ class ContractService:
self.db.add(approval) self.db.add(approval)
# Update contract status # Update contract status
await self.update_contract(contract_id, status=ContractStatus.REJECTED) contract.status = ContractStatus.REJECTED
await self.db.flush() await self.db.flush()
await self.db.refresh(approval) await self.db.refresh(approval)

View File

@ -8,6 +8,11 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import (
NotFoundError,
SignatureExpiredError,
SignatureAlreadySignedError,
)
from app.models.signature import ( from app.models.signature import (
SignatureRequest, SignatureRequest,
Signature, Signature,
@ -40,6 +45,14 @@ class SignatureService:
expire_hours: int = None, expire_hours: int = None,
) -> SignatureRequest: ) -> SignatureRequest:
"""Create a signature request.""" """Create a signature request."""
# Verify contract exists
result = await self.db.execute(
select(Contract).where(Contract.id == contract_id)
)
contract = result.scalar_one_or_none()
if not contract:
raise NotFoundError("Contract", contract_id)
expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS
request = SignatureRequest( request = SignatureRequest(
@ -51,6 +64,8 @@ class SignatureService:
expires_at=datetime.utcnow() + timedelta(hours=expire_hours), expires_at=datetime.utcnow() + timedelta(hours=expire_hours),
) )
self.db.add(request) self.db.add(request)
await self.db.flush()
await self.db.refresh(request)
# Create audit log # Create audit log
await self.create_audit_log( await self.create_audit_log(
@ -59,8 +74,6 @@ class SignatureService:
details={"signer_email": signer_email} details={"signer_email": signer_email}
) )
await self.db.flush()
await self.db.refresh(request)
return request return request
async def get_signature_request_by_token( async def get_signature_request_by_token(
@ -83,13 +96,21 @@ class SignatureService:
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def validate_request(self, request: SignatureRequest) -> None:
"""Validate a signature request. Raises exception if invalid."""
if request.status != SignatureStatus.PENDING:
raise SignatureAlreadySignedError()
if request.expires_at < datetime.utcnow():
raise SignatureExpiredError()
async def is_request_valid(self, request: SignatureRequest) -> bool: async def is_request_valid(self, request: SignatureRequest) -> bool:
"""Check if a signature request is valid.""" """Check if a signature request is valid."""
if request.status != SignatureStatus.PENDING: try:
return False await self.validate_request(request)
if request.expires_at < datetime.utcnow():
return False
return True return True
except Exception:
return False
async def sign_document( async def sign_document(
self, self,
@ -99,6 +120,9 @@ class SignatureService:
user_agent: Optional[str] = None, user_agent: Optional[str] = None,
) -> Signature: ) -> Signature:
"""Sign a document.""" """Sign a document."""
# Validate request first
await self.validate_request(request)
# Get contract content for hash # Get contract content for hash
result = await self.db.execute( result = await self.db.execute(
select(Contract).where(Contract.id == request.contract_id) select(Contract).where(Contract.id == request.contract_id)
@ -106,7 +130,7 @@ class SignatureService:
contract = result.scalar_one_or_none() contract = result.scalar_one_or_none()
if not contract: if not contract:
raise ValueError("Contract not found") raise NotFoundError("Contract", request.contract_id)
verification_hash = self.generate_verification_hash(contract.content) verification_hash = self.generate_verification_hash(contract.content)

View File

@ -76,6 +76,9 @@ class TestContractService:
created_by=1, created_by=1,
) )
# Submit for approval first
await service.submit_for_approval(contract.id)
approval = await service.approve_contract(contract.id, 1, "同意") approval = await service.approve_contract(contract.id, 1, "同意")
assert approval.status == ContractStatus.APPROVED assert approval.status == ContractStatus.APPROVED