fix: improve error handling and validation
- Add centralized exception classes - Validate signature request status before signing - Check contract status before approval/rejection - Add exception handlers to FastAPI app - Update tests for new validation logic Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
656f596d7e
commit
ae72b180e5
29
backend/app/core/exception_handlers.py
Normal file
29
backend/app/core/exception_handlers.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Exception handlers for FastAPI."""
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.exceptions import AppException
|
||||
|
||||
|
||||
async def app_exception_handler(request: Request, exc: AppException):
|
||||
"""Handle application exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"code": exc.code,
|
||||
"message": exc.message,
|
||||
"details": exc.details,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def generic_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle generic exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"code": 50000,
|
||||
"message": "Internal server error",
|
||||
"details": {},
|
||||
}
|
||||
)
|
||||
87
backend/app/core/exceptions.py
Normal file
87
backend/app/core/exceptions.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Core exceptions for the application."""
|
||||
|
||||
|
||||
class AppException(Exception):
|
||||
"""Base application exception."""
|
||||
|
||||
def __init__(self, code: int, message: str, details: dict = None):
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class NotFoundError(AppException):
|
||||
"""Resource not found exception."""
|
||||
|
||||
def __init__(self, resource: str, resource_id: int = None):
|
||||
super().__init__(
|
||||
code=10002,
|
||||
message=f"{resource} not found",
|
||||
details={"resource": resource, "id": resource_id}
|
||||
)
|
||||
|
||||
|
||||
class ValidationError(AppException):
|
||||
"""Validation error exception."""
|
||||
|
||||
def __init__(self, message: str, details: dict = None):
|
||||
super().__init__(code=10001, message=message, details=details)
|
||||
|
||||
|
||||
class AuthenticationError(AppException):
|
||||
"""Authentication error exception."""
|
||||
|
||||
def __init__(self, message: str = "Authentication failed"):
|
||||
super().__init__(code=20001, message=message)
|
||||
|
||||
|
||||
class AuthorizationError(AppException):
|
||||
"""Authorization error exception."""
|
||||
|
||||
def __init__(self, message: str = "Permission denied"):
|
||||
super().__init__(code=20002, message=message)
|
||||
|
||||
|
||||
class SignatureError(AppException):
|
||||
"""Signature related error exception."""
|
||||
|
||||
def __init__(self, message: str, details: dict = None):
|
||||
super().__init__(code=40000, message=message, details=details)
|
||||
|
||||
|
||||
class SignatureExpiredError(SignatureError):
|
||||
"""Signature request expired exception."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(code=40001, message="Signature request has expired")
|
||||
|
||||
|
||||
class SignatureAlreadySignedError(SignatureError):
|
||||
"""Signature already completed exception."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(code=40002, message="Signature request already completed")
|
||||
|
||||
|
||||
class ContractError(AppException):
|
||||
"""Contract related error exception."""
|
||||
|
||||
def __init__(self, message: str, details: dict = None):
|
||||
super().__init__(code=50000, message=message, details=details)
|
||||
|
||||
|
||||
class InvalidContractStatusError(ContractError):
|
||||
"""Invalid contract status for operation."""
|
||||
|
||||
def __init__(self, current_status: str, required_status: str):
|
||||
super().__init__(
|
||||
message=f"Invalid contract status: {current_status}, expected: {required_status}",
|
||||
details={"current_status": current_status, "required_status": required_status}
|
||||
)
|
||||
|
||||
|
||||
class LLMError(AppException):
|
||||
"""LLM service error exception."""
|
||||
|
||||
def __init__(self, message: str = "LLM service error"):
|
||||
super().__init__(code=30001, message=message)
|
||||
@ -5,6 +5,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
from app.core.exceptions import AppException
|
||||
from app.core.exception_handlers import app_exception_handler
|
||||
from app.api.v1 import laws, analyses, contracts, signatures
|
||||
|
||||
|
||||
@ -24,6 +26,9 @@ app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Register exception handlers
|
||||
app.add_exception_handler(AppException, app_exception_handler)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import InvalidContractStatusError, NotFoundError
|
||||
from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus
|
||||
|
||||
|
||||
@ -42,6 +43,13 @@ class ContractService:
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_contract_or_raise(self, contract_id: int) -> Contract:
|
||||
"""Get a contract by ID or raise exception."""
|
||||
contract = await self.get_contract_by_id(contract_id)
|
||||
if not contract:
|
||||
raise NotFoundError("Contract", contract_id)
|
||||
return contract
|
||||
|
||||
async def get_contracts_list(
|
||||
self,
|
||||
skip: int = 0,
|
||||
@ -81,6 +89,17 @@ class ContractService:
|
||||
|
||||
async def submit_for_approval(self, contract_id: int) -> Optional[Contract]:
|
||||
"""Submit a contract for approval."""
|
||||
contract = await self.get_contract_by_id(contract_id)
|
||||
if not contract:
|
||||
return None
|
||||
|
||||
# Only draft or rejected contracts can be submitted
|
||||
if contract.status not in [ContractStatus.DRAFT, ContractStatus.REJECTED]:
|
||||
raise InvalidContractStatusError(
|
||||
current_status=contract.status.value,
|
||||
required_status="draft or rejected"
|
||||
)
|
||||
|
||||
return await self.update_contract(
|
||||
contract_id,
|
||||
status=ContractStatus.PENDING_APPROVAL
|
||||
@ -93,6 +112,15 @@ class ContractService:
|
||||
comment: Optional[str] = None,
|
||||
) -> ContractApproval:
|
||||
"""Approve a contract."""
|
||||
contract = await self.get_contract_or_raise(contract_id)
|
||||
|
||||
# Only pending approval contracts can be approved
|
||||
if contract.status != ContractStatus.PENDING_APPROVAL:
|
||||
raise InvalidContractStatusError(
|
||||
current_status=contract.status.value,
|
||||
required_status="pending_approval"
|
||||
)
|
||||
|
||||
approval = ContractApproval(
|
||||
contract_id=contract_id,
|
||||
approver_id=approver_id,
|
||||
@ -102,7 +130,7 @@ class ContractService:
|
||||
self.db.add(approval)
|
||||
|
||||
# Update contract status
|
||||
await self.update_contract(contract_id, status=ContractStatus.APPROVED)
|
||||
contract.status = ContractStatus.APPROVED
|
||||
|
||||
await self.db.flush()
|
||||
await self.db.refresh(approval)
|
||||
@ -115,6 +143,15 @@ class ContractService:
|
||||
comment: Optional[str] = None,
|
||||
) -> ContractApproval:
|
||||
"""Reject a contract."""
|
||||
contract = await self.get_contract_or_raise(contract_id)
|
||||
|
||||
# Only pending approval contracts can be rejected
|
||||
if contract.status != ContractStatus.PENDING_APPROVAL:
|
||||
raise InvalidContractStatusError(
|
||||
current_status=contract.status.value,
|
||||
required_status="pending_approval"
|
||||
)
|
||||
|
||||
approval = ContractApproval(
|
||||
contract_id=contract_id,
|
||||
approver_id=approver_id,
|
||||
@ -124,7 +161,7 @@ class ContractService:
|
||||
self.db.add(approval)
|
||||
|
||||
# Update contract status
|
||||
await self.update_contract(contract_id, status=ContractStatus.REJECTED)
|
||||
contract.status = ContractStatus.REJECTED
|
||||
|
||||
await self.db.flush()
|
||||
await self.db.refresh(approval)
|
||||
|
||||
@ -8,6 +8,11 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
SignatureExpiredError,
|
||||
SignatureAlreadySignedError,
|
||||
)
|
||||
from app.models.signature import (
|
||||
SignatureRequest,
|
||||
Signature,
|
||||
@ -40,6 +45,14 @@ class SignatureService:
|
||||
expire_hours: int = None,
|
||||
) -> SignatureRequest:
|
||||
"""Create a signature request."""
|
||||
# Verify contract exists
|
||||
result = await self.db.execute(
|
||||
select(Contract).where(Contract.id == contract_id)
|
||||
)
|
||||
contract = result.scalar_one_or_none()
|
||||
if not contract:
|
||||
raise NotFoundError("Contract", contract_id)
|
||||
|
||||
expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS
|
||||
|
||||
request = SignatureRequest(
|
||||
@ -51,6 +64,8 @@ class SignatureService:
|
||||
expires_at=datetime.utcnow() + timedelta(hours=expire_hours),
|
||||
)
|
||||
self.db.add(request)
|
||||
await self.db.flush()
|
||||
await self.db.refresh(request)
|
||||
|
||||
# Create audit log
|
||||
await self.create_audit_log(
|
||||
@ -59,8 +74,6 @@ class SignatureService:
|
||||
details={"signer_email": signer_email}
|
||||
)
|
||||
|
||||
await self.db.flush()
|
||||
await self.db.refresh(request)
|
||||
return request
|
||||
|
||||
async def get_signature_request_by_token(
|
||||
@ -83,13 +96,21 @@ class SignatureService:
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def validate_request(self, request: SignatureRequest) -> None:
|
||||
"""Validate a signature request. Raises exception if invalid."""
|
||||
if request.status != SignatureStatus.PENDING:
|
||||
raise SignatureAlreadySignedError()
|
||||
|
||||
if request.expires_at < datetime.utcnow():
|
||||
raise SignatureExpiredError()
|
||||
|
||||
async def is_request_valid(self, request: SignatureRequest) -> bool:
|
||||
"""Check if a signature request is valid."""
|
||||
if request.status != SignatureStatus.PENDING:
|
||||
try:
|
||||
await self.validate_request(request)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
if request.expires_at < datetime.utcnow():
|
||||
return False
|
||||
return True
|
||||
|
||||
async def sign_document(
|
||||
self,
|
||||
@ -99,6 +120,9 @@ class SignatureService:
|
||||
user_agent: Optional[str] = None,
|
||||
) -> Signature:
|
||||
"""Sign a document."""
|
||||
# Validate request first
|
||||
await self.validate_request(request)
|
||||
|
||||
# Get contract content for hash
|
||||
result = await self.db.execute(
|
||||
select(Contract).where(Contract.id == request.contract_id)
|
||||
@ -106,7 +130,7 @@ class SignatureService:
|
||||
contract = result.scalar_one_or_none()
|
||||
|
||||
if not contract:
|
||||
raise ValueError("Contract not found")
|
||||
raise NotFoundError("Contract", request.contract_id)
|
||||
|
||||
verification_hash = self.generate_verification_hash(contract.content)
|
||||
|
||||
|
||||
@ -76,6 +76,9 @@ class TestContractService:
|
||||
created_by=1,
|
||||
)
|
||||
|
||||
# Submit for approval first
|
||||
await service.submit_for_approval(contract.id)
|
||||
|
||||
approval = await service.approve_contract(contract.id, 1, "同意")
|
||||
|
||||
assert approval.status == ContractStatus.APPROVED
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user