fix: improve error handling and validation
- Add centralized exception classes - Validate signature request status before signing - Check contract status before approval/rejection - Add exception handlers to FastAPI app - Update tests for new validation logic Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
656f596d7e
commit
ae72b180e5
29
backend/app/core/exception_handlers.py
Normal file
29
backend/app/core/exception_handlers.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Exception handlers for FastAPI."""
|
||||||
|
from fastapi import Request, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.core.exceptions import AppException
|
||||||
|
|
||||||
|
|
||||||
|
async def app_exception_handler(request: Request, exc: AppException):
|
||||||
|
"""Handle application exceptions."""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={
|
||||||
|
"code": exc.code,
|
||||||
|
"message": exc.message,
|
||||||
|
"details": exc.details,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def generic_exception_handler(request: Request, exc: Exception):
|
||||||
|
"""Handle generic exceptions."""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content={
|
||||||
|
"code": 50000,
|
||||||
|
"message": "Internal server error",
|
||||||
|
"details": {},
|
||||||
|
}
|
||||||
|
)
|
||||||
87
backend/app/core/exceptions.py
Normal file
87
backend/app/core/exceptions.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""Core exceptions for the application."""
|
||||||
|
|
||||||
|
|
||||||
|
class AppException(Exception):
|
||||||
|
"""Base application exception."""
|
||||||
|
|
||||||
|
def __init__(self, code: int, message: str, details: dict = None):
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(AppException):
|
||||||
|
"""Resource not found exception."""
|
||||||
|
|
||||||
|
def __init__(self, resource: str, resource_id: int = None):
|
||||||
|
super().__init__(
|
||||||
|
code=10002,
|
||||||
|
message=f"{resource} not found",
|
||||||
|
details={"resource": resource, "id": resource_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(AppException):
|
||||||
|
"""Validation error exception."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, details: dict = None):
|
||||||
|
super().__init__(code=10001, message=message, details=details)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(AppException):
|
||||||
|
"""Authentication error exception."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Authentication failed"):
|
||||||
|
super().__init__(code=20001, message=message)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationError(AppException):
|
||||||
|
"""Authorization error exception."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Permission denied"):
|
||||||
|
super().__init__(code=20002, message=message)
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureError(AppException):
|
||||||
|
"""Signature related error exception."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, details: dict = None):
|
||||||
|
super().__init__(code=40000, message=message, details=details)
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureExpiredError(SignatureError):
|
||||||
|
"""Signature request expired exception."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(code=40001, message="Signature request has expired")
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureAlreadySignedError(SignatureError):
|
||||||
|
"""Signature already completed exception."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(code=40002, message="Signature request already completed")
|
||||||
|
|
||||||
|
|
||||||
|
class ContractError(AppException):
|
||||||
|
"""Contract related error exception."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, details: dict = None):
|
||||||
|
super().__init__(code=50000, message=message, details=details)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidContractStatusError(ContractError):
|
||||||
|
"""Invalid contract status for operation."""
|
||||||
|
|
||||||
|
def __init__(self, current_status: str, required_status: str):
|
||||||
|
super().__init__(
|
||||||
|
message=f"Invalid contract status: {current_status}, expected: {required_status}",
|
||||||
|
details={"current_status": current_status, "required_status": required_status}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMError(AppException):
|
||||||
|
"""LLM service error exception."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "LLM service error"):
|
||||||
|
super().__init__(code=30001, message=message)
|
||||||
@ -5,6 +5,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import init_db
|
from app.core.database import init_db
|
||||||
|
from app.core.exceptions import AppException
|
||||||
|
from app.core.exception_handlers import app_exception_handler
|
||||||
from app.api.v1 import laws, analyses, contracts, signatures
|
from app.api.v1 import laws, analyses, contracts, signatures
|
||||||
|
|
||||||
|
|
||||||
@ -24,6 +26,9 @@ app = FastAPI(
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register exception handlers
|
||||||
|
app.add_exception_handler(AppException, app_exception_handler)
|
||||||
|
|
||||||
# CORS middleware
|
# CORS middleware
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import List, Optional
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.exceptions import InvalidContractStatusError, NotFoundError
|
||||||
from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus
|
from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus
|
||||||
|
|
||||||
|
|
||||||
@ -42,6 +43,13 @@ class ContractService:
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_contract_or_raise(self, contract_id: int) -> Contract:
|
||||||
|
"""Get a contract by ID or raise exception."""
|
||||||
|
contract = await self.get_contract_by_id(contract_id)
|
||||||
|
if not contract:
|
||||||
|
raise NotFoundError("Contract", contract_id)
|
||||||
|
return contract
|
||||||
|
|
||||||
async def get_contracts_list(
|
async def get_contracts_list(
|
||||||
self,
|
self,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
@ -81,6 +89,17 @@ class ContractService:
|
|||||||
|
|
||||||
async def submit_for_approval(self, contract_id: int) -> Optional[Contract]:
|
async def submit_for_approval(self, contract_id: int) -> Optional[Contract]:
|
||||||
"""Submit a contract for approval."""
|
"""Submit a contract for approval."""
|
||||||
|
contract = await self.get_contract_by_id(contract_id)
|
||||||
|
if not contract:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Only draft or rejected contracts can be submitted
|
||||||
|
if contract.status not in [ContractStatus.DRAFT, ContractStatus.REJECTED]:
|
||||||
|
raise InvalidContractStatusError(
|
||||||
|
current_status=contract.status.value,
|
||||||
|
required_status="draft or rejected"
|
||||||
|
)
|
||||||
|
|
||||||
return await self.update_contract(
|
return await self.update_contract(
|
||||||
contract_id,
|
contract_id,
|
||||||
status=ContractStatus.PENDING_APPROVAL
|
status=ContractStatus.PENDING_APPROVAL
|
||||||
@ -93,6 +112,15 @@ class ContractService:
|
|||||||
comment: Optional[str] = None,
|
comment: Optional[str] = None,
|
||||||
) -> ContractApproval:
|
) -> ContractApproval:
|
||||||
"""Approve a contract."""
|
"""Approve a contract."""
|
||||||
|
contract = await self.get_contract_or_raise(contract_id)
|
||||||
|
|
||||||
|
# Only pending approval contracts can be approved
|
||||||
|
if contract.status != ContractStatus.PENDING_APPROVAL:
|
||||||
|
raise InvalidContractStatusError(
|
||||||
|
current_status=contract.status.value,
|
||||||
|
required_status="pending_approval"
|
||||||
|
)
|
||||||
|
|
||||||
approval = ContractApproval(
|
approval = ContractApproval(
|
||||||
contract_id=contract_id,
|
contract_id=contract_id,
|
||||||
approver_id=approver_id,
|
approver_id=approver_id,
|
||||||
@ -102,7 +130,7 @@ class ContractService:
|
|||||||
self.db.add(approval)
|
self.db.add(approval)
|
||||||
|
|
||||||
# Update contract status
|
# Update contract status
|
||||||
await self.update_contract(contract_id, status=ContractStatus.APPROVED)
|
contract.status = ContractStatus.APPROVED
|
||||||
|
|
||||||
await self.db.flush()
|
await self.db.flush()
|
||||||
await self.db.refresh(approval)
|
await self.db.refresh(approval)
|
||||||
@ -115,6 +143,15 @@ class ContractService:
|
|||||||
comment: Optional[str] = None,
|
comment: Optional[str] = None,
|
||||||
) -> ContractApproval:
|
) -> ContractApproval:
|
||||||
"""Reject a contract."""
|
"""Reject a contract."""
|
||||||
|
contract = await self.get_contract_or_raise(contract_id)
|
||||||
|
|
||||||
|
# Only pending approval contracts can be rejected
|
||||||
|
if contract.status != ContractStatus.PENDING_APPROVAL:
|
||||||
|
raise InvalidContractStatusError(
|
||||||
|
current_status=contract.status.value,
|
||||||
|
required_status="pending_approval"
|
||||||
|
)
|
||||||
|
|
||||||
approval = ContractApproval(
|
approval = ContractApproval(
|
||||||
contract_id=contract_id,
|
contract_id=contract_id,
|
||||||
approver_id=approver_id,
|
approver_id=approver_id,
|
||||||
@ -124,7 +161,7 @@ class ContractService:
|
|||||||
self.db.add(approval)
|
self.db.add(approval)
|
||||||
|
|
||||||
# Update contract status
|
# Update contract status
|
||||||
await self.update_contract(contract_id, status=ContractStatus.REJECTED)
|
contract.status = ContractStatus.REJECTED
|
||||||
|
|
||||||
await self.db.flush()
|
await self.db.flush()
|
||||||
await self.db.refresh(approval)
|
await self.db.refresh(approval)
|
||||||
|
|||||||
@ -8,6 +8,11 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.exceptions import (
|
||||||
|
NotFoundError,
|
||||||
|
SignatureExpiredError,
|
||||||
|
SignatureAlreadySignedError,
|
||||||
|
)
|
||||||
from app.models.signature import (
|
from app.models.signature import (
|
||||||
SignatureRequest,
|
SignatureRequest,
|
||||||
Signature,
|
Signature,
|
||||||
@ -40,6 +45,14 @@ class SignatureService:
|
|||||||
expire_hours: int = None,
|
expire_hours: int = None,
|
||||||
) -> SignatureRequest:
|
) -> SignatureRequest:
|
||||||
"""Create a signature request."""
|
"""Create a signature request."""
|
||||||
|
# Verify contract exists
|
||||||
|
result = await self.db.execute(
|
||||||
|
select(Contract).where(Contract.id == contract_id)
|
||||||
|
)
|
||||||
|
contract = result.scalar_one_or_none()
|
||||||
|
if not contract:
|
||||||
|
raise NotFoundError("Contract", contract_id)
|
||||||
|
|
||||||
expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS
|
expire_hours = expire_hours or settings.SIGNATURE_TOKEN_EXPIRE_HOURS
|
||||||
|
|
||||||
request = SignatureRequest(
|
request = SignatureRequest(
|
||||||
@ -51,6 +64,8 @@ class SignatureService:
|
|||||||
expires_at=datetime.utcnow() + timedelta(hours=expire_hours),
|
expires_at=datetime.utcnow() + timedelta(hours=expire_hours),
|
||||||
)
|
)
|
||||||
self.db.add(request)
|
self.db.add(request)
|
||||||
|
await self.db.flush()
|
||||||
|
await self.db.refresh(request)
|
||||||
|
|
||||||
# Create audit log
|
# Create audit log
|
||||||
await self.create_audit_log(
|
await self.create_audit_log(
|
||||||
@ -59,8 +74,6 @@ class SignatureService:
|
|||||||
details={"signer_email": signer_email}
|
details={"signer_email": signer_email}
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.db.flush()
|
|
||||||
await self.db.refresh(request)
|
|
||||||
return request
|
return request
|
||||||
|
|
||||||
async def get_signature_request_by_token(
|
async def get_signature_request_by_token(
|
||||||
@ -83,13 +96,21 @@ class SignatureService:
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def validate_request(self, request: SignatureRequest) -> None:
|
||||||
|
"""Validate a signature request. Raises exception if invalid."""
|
||||||
|
if request.status != SignatureStatus.PENDING:
|
||||||
|
raise SignatureAlreadySignedError()
|
||||||
|
|
||||||
|
if request.expires_at < datetime.utcnow():
|
||||||
|
raise SignatureExpiredError()
|
||||||
|
|
||||||
async def is_request_valid(self, request: SignatureRequest) -> bool:
|
async def is_request_valid(self, request: SignatureRequest) -> bool:
|
||||||
"""Check if a signature request is valid."""
|
"""Check if a signature request is valid."""
|
||||||
if request.status != SignatureStatus.PENDING:
|
try:
|
||||||
return False
|
await self.validate_request(request)
|
||||||
if request.expires_at < datetime.utcnow():
|
|
||||||
return False
|
|
||||||
return True
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
async def sign_document(
|
async def sign_document(
|
||||||
self,
|
self,
|
||||||
@ -99,6 +120,9 @@ class SignatureService:
|
|||||||
user_agent: Optional[str] = None,
|
user_agent: Optional[str] = None,
|
||||||
) -> Signature:
|
) -> Signature:
|
||||||
"""Sign a document."""
|
"""Sign a document."""
|
||||||
|
# Validate request first
|
||||||
|
await self.validate_request(request)
|
||||||
|
|
||||||
# Get contract content for hash
|
# Get contract content for hash
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
select(Contract).where(Contract.id == request.contract_id)
|
select(Contract).where(Contract.id == request.contract_id)
|
||||||
@ -106,7 +130,7 @@ class SignatureService:
|
|||||||
contract = result.scalar_one_or_none()
|
contract = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not contract:
|
if not contract:
|
||||||
raise ValueError("Contract not found")
|
raise NotFoundError("Contract", request.contract_id)
|
||||||
|
|
||||||
verification_hash = self.generate_verification_hash(contract.content)
|
verification_hash = self.generate_verification_hash(contract.content)
|
||||||
|
|
||||||
|
|||||||
@ -76,6 +76,9 @@ class TestContractService:
|
|||||||
created_by=1,
|
created_by=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Submit for approval first
|
||||||
|
await service.submit_for_approval(contract.id)
|
||||||
|
|
||||||
approval = await service.approve_contract(contract.id, 1, "同意")
|
approval = await service.approve_contract(contract.id, 1, "同意")
|
||||||
|
|
||||||
assert approval.status == ContractStatus.APPROVED
|
assert approval.status == ContractStatus.APPROVED
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user