diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py new file mode 100644 index 0000000..865e4af --- /dev/null +++ b/backend/app/core/exception_handlers.py @@ -0,0 +1,29 @@ +"""Exception handlers for FastAPI.""" +from fastapi import Request, status +from fastapi.responses import JSONResponse + +from app.core.exceptions import AppException + + +async def app_exception_handler(request: Request, exc: AppException): + """Handle application exceptions.""" + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + "code": exc.code, + "message": exc.message, + "details": exc.details, + } + ) + + +async def generic_exception_handler(request: Request, exc: Exception): + """Handle generic exceptions.""" + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + "code": 50000, + "message": "Internal server error", + "details": {}, + } + ) diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py new file mode 100644 index 0000000..fc673af --- /dev/null +++ b/backend/app/core/exceptions.py @@ -0,0 +1,87 @@ +"""Core exceptions for the application.""" + + +class AppException(Exception): + """Base application exception.""" + + def __init__(self, code: int, message: str, details: dict = None): + self.code = code + self.message = message + self.details = details or {} + + +class NotFoundError(AppException): + """Resource not found exception.""" + + def __init__(self, resource: str, resource_id: int = None): + super().__init__( + code=10002, + message=f"{resource} not found", + details={"resource": resource, "id": resource_id} + ) + + +class ValidationError(AppException): + """Validation error exception.""" + + def __init__(self, message: str, details: dict = None): + super().__init__(code=10001, message=message, details=details) + + +class AuthenticationError(AppException): + """Authentication error exception.""" + + def __init__(self, message: str = "Authentication failed"): + super().__init__(code=20001, message=message) + + +class AuthorizationError(AppException): + """Authorization error exception.""" + + def __init__(self, message: str = "Permission denied"): + super().__init__(code=20002, message=message) + + +class SignatureError(AppException): + """Signature related error exception.""" + + def __init__(self, message: str, details: dict = None): + super().__init__(code=40000, message=message, details=details) + + +class SignatureExpiredError(SignatureError): + """Signature request expired exception.""" + + def __init__(self): + super().__init__(code=40001, message="Signature request has expired") + + +class SignatureAlreadySignedError(SignatureError): + """Signature already completed exception.""" + + def __init__(self): + super().__init__(code=40002, message="Signature request already completed") + + +class ContractError(AppException): + """Contract related error exception.""" + + def __init__(self, message: str, details: dict = None): + super().__init__(code=50000, message=message, details=details) + + +class InvalidContractStatusError(ContractError): + """Invalid contract status for operation.""" + + def __init__(self, current_status: str, required_status: str): + super().__init__( + message=f"Invalid contract status: {current_status}, expected: {required_status}", + details={"current_status": current_status, "required_status": required_status} + ) + + +class LLMError(AppException): + """LLM service error exception.""" + + def __init__(self, message: str = "LLM service error"): + super().__init__(code=30001, message=message) diff --git a/backend/app/main.py b/backend/app/main.py index 4850c94..53b2c66 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -5,6 +5,8 @@ from fastapi.middleware.cors import CORSMiddleware from app.core.config import settings from app.core.database import init_db +from app.core.exceptions import AppException +from app.core.exception_handlers import app_exception_handler from app.api.v1 import laws, analyses, contracts, signatures @@ -24,6 +26,9 @@ app = FastAPI( lifespan=lifespan, ) +# Register exception handlers +app.add_exception_handler(AppException, app_exception_handler) + # CORS middleware app.add_middleware( CORSMiddleware, diff --git a/backend/app/services/contract_service.py b/backend/app/services/contract_service.py index 2c4851a..ebeedbe 100644 --- a/backend/app/services/contract_service.py +++ b/backend/app/services/contract_service.py @@ -3,6 +3,7 @@ from typing import List, Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.core.exceptions import InvalidContractStatusError, NotFoundError from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus @@ -42,6 +43,13 @@ class ContractService: ) return result.scalar_one_or_none() + async def get_contract_or_raise(self, contract_id: int) -> Contract: + """Get a contract by ID or raise exception.""" + contract = await self.get_contract_by_id(contract_id) + if not contract: + raise NotFoundError("Contract", contract_id) + return contract + async def get_contracts_list( self, skip: int = 0, @@ -81,6 +89,17 @@ class ContractService: async def submit_for_approval(self, contract_id: int) -> Optional[Contract]: """Submit a contract for approval.""" + contract = await self.get_contract_by_id(contract_id) + if not contract: + return None + + # Only draft or rejected contracts can be submitted + if contract.status not in [ContractStatus.DRAFT, ContractStatus.REJECTED]: + raise InvalidContractStatusError( + current_status=contract.status.value, + required_status="draft or rejected" + ) + return await self.update_contract( contract_id, status=ContractStatus.PENDING_APPROVAL @@ -93,6 +112,15 @@ class ContractService: comment: Optional[str] = None, ) -> ContractApproval: """Approve a contract.""" + contract = await self.get_contract_or_raise(contract_id) + + # Only pending approval contracts can be approved + if contract.status != ContractStatus.PENDING_APPROVAL: + raise InvalidContractStatusError( + current_status=contract.status.value, + required_status="pending_approval" + ) + approval = ContractApproval( contract_id=contract_id, approver_id=approver_id, @@ -102,7 +130,7 @@ class ContractService: self.db.add(approval) # Update contract status - await self.update_contract(contract_id, status=ContractStatus.APPROVED) + contract.status = ContractStatus.APPROVED await self.db.flush() await self.db.refresh(approval) @@ -115,6 +143,15 @@ class ContractService: comment: Optional[str] = None, ) -> ContractApproval: """Reject a contract.""" + contract = await self.get_contract_or_raise(contract_id) + + # Only pending approval contracts can be rejected + if contract.status != ContractStatus.PENDING_APPROVAL: + raise InvalidContractStatusError( + current_status=contract.status.value, + required_status="pending_approval" + ) + approval = ContractApproval( contract_id=contract_id, approver_id=approver_id, @@ -124,7 +161,7 @@ class ContractService: self.db.add(approval) # Update contract status - await self.update_contract(contract_id, status=ContractStatus.REJECTED) + contract.status = ContractStatus.REJECTED await self.db.flush() await self.db.refresh(approval) diff --git a/backend/app/services/signature_service.py b/backend/app/services/signature_service.py index c4735b3..c0df269 100644 --- a/backend/app/services/signature_service.py +++ b/backend/app/services/signature_service.py @@ -8,6 +8,11 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings +from app.core.exceptions import ( + NotFoundError, + SignatureExpiredError, + SignatureAlreadySignedError, +) from app.models.signature import ( SignatureRequest, Signature, @@ -40,6 +45,14 @@ class SignatureService: expire_hours: int = None, ) -> SignatureRequest: """Create a signature request.""" + # Verify contract exists + result = await self.db.execute( + select(Contract).where(Contract.id == contract_id) + ) + contract = result.scalar_one_or_none() + if not contract: + raise NotFoundError("Contract", contract_id) + expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS request = SignatureRequest( @@ -51,6 +64,8 @@ class SignatureService: expires_at=datetime.utcnow() + timedelta(hours=expire_hours), ) self.db.add(request) + await self.db.flush() + await self.db.refresh(request) # Create audit log await self.create_audit_log( @@ -59,8 +74,6 @@ class SignatureService: details={"signer_email": signer_email} ) - await self.db.flush() - await self.db.refresh(request) return request async def get_signature_request_by_token( @@ -83,13 +96,21 @@ class SignatureService: ) return result.scalar_one_or_none() + async def validate_request(self, request: SignatureRequest) -> None: + """Validate a signature request. Raises exception if invalid.""" + if request.status != SignatureStatus.PENDING: + raise SignatureAlreadySignedError() + + if request.expires_at < datetime.utcnow(): + raise SignatureExpiredError() + async def is_request_valid(self, request: SignatureRequest) -> bool: """Check if a signature request is valid.""" - if request.status != SignatureStatus.PENDING: + try: + await self.validate_request(request) + return True + except Exception: return False - if request.expires_at < datetime.utcnow(): - return False - return True async def sign_document( self, @@ -99,6 +120,9 @@ class SignatureService: user_agent: Optional[str] = None, ) -> Signature: """Sign a document.""" + # Validate request first + await self.validate_request(request) + # Get contract content for hash result = await self.db.execute( select(Contract).where(Contract.id == request.contract_id) @@ -106,7 +130,7 @@ class SignatureService: contract = result.scalar_one_or_none() if not contract: - raise ValueError("Contract not found") + raise NotFoundError("Contract", request.contract_id) verification_hash = self.generate_verification_hash(contract.content) diff --git a/backend/tests/unit/test_contract_service.py b/backend/tests/unit/test_contract_service.py index afaae2b..966579b 100644 --- a/backend/tests/unit/test_contract_service.py +++ b/backend/tests/unit/test_contract_service.py @@ -76,6 +76,9 @@ class TestContractService: created_by=1, ) + # Submit for approval first + await service.submit_for_approval(contract.id) + approval = await service.approve_contract(contract.id, 1, "同意") assert approval.status == ContractStatus.APPROVED