"""Contract service for contract management."""

from typing import List, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import InvalidContractStatusError, NotFoundError
from app.models.contract import Contract, ContractTemplate, ContractApproval, ContractStatus


class ContractService:
    """Service for contract operations.

    Every write in this service is flushed (and refreshed) on the injected
    session but never committed here; presumably the caller owns the
    transaction boundary.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_contract(
        self,
        title: str,
        party_a: str,
        party_b: str,
        content: str,
        created_by: int,
        **kwargs,
    ) -> Contract:
        """Create a new contract and return it after flushing.

        Args:
            title: Contract title.
            party_a: First contracting party.
            party_b: Second contracting party.
            content: Contract body.
            created_by: ID of the user creating the contract.
            **kwargs: Extra attributes forwarded verbatim to ``Contract``.

        Returns:
            The new contract, refreshed from the database after the flush so
            database-assigned values (such as its ``id``) are loaded.
        """
        contract = Contract(
            title=title,
            party_a=party_a,
            party_b=party_b,
            content=content,
            created_by=created_by,
            **kwargs,
        )
        self.db.add(contract)
        await self.db.flush()
        await self.db.refresh(contract)
        return contract

    async def get_contract_by_id(self, contract_id: int) -> Optional[Contract]:
        """Return the contract with ``contract_id``, or ``None`` if absent."""
        result = await self.db.execute(
            select(Contract).where(Contract.id == contract_id)
        )
        return result.scalar_one_or_none()

    async def get_contract_or_raise(self, contract_id: int) -> Contract:
        """Return the contract with ``contract_id``.

        Raises:
            NotFoundError: If no such contract exists.
        """
        contract = await self.get_contract_by_id(contract_id)
        if contract is None:
            raise NotFoundError("Contract", contract_id)
        return contract

    async def get_contracts_list(
        self,
        skip: int = 0,
        limit: int = 20,
        status: Optional[ContractStatus] = None,
        created_by: Optional[int] = None,
    ) -> List[Contract]:
        """Return a page of contracts, newest first.

        Args:
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            status: If given, only contracts in this status are returned.
            created_by: If given, only contracts created by this user ID.

        Returns:
            The matching contracts ordered by ``created_at`` descending.
        """
        query = select(Contract)
        # Compare against None explicitly: a falsy-but-valid filter (e.g. a
        # user ID of 0) must still be applied.
        if status is not None:
            query = query.where(Contract.status == status)
        if created_by is not None:
            query = query.where(Contract.created_by == created_by)
        # ``created_at`` is not unique; the ``id`` tiebreaker keeps offset
        # pagination stable so rows are neither skipped nor repeated.
        query = (
            query.order_by(Contract.created_at.desc(), Contract.id.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def update_contract(
        self, contract_id: int, **kwargs
    ) -> Optional[Contract]:
        """Update attributes of a contract.

        Only keys that exist as attributes on the contract are applied, and
        keys whose value is ``None`` are skipped (so this cannot clear a
        column).

        NOTE(review): any existing attribute, including e.g. ``status`` or
        ``id``, can be overwritten here. If ``kwargs`` ever comes straight
        from request data, consider whitelisting the editable columns.

        Returns:
            The updated contract, or ``None`` if it does not exist.
        """
        contract = await self.get_contract_by_id(contract_id)
        if contract is None:
            return None
        for key, value in kwargs.items():
            if value is not None and hasattr(contract, key):
                setattr(contract, key, value)
        await self.db.flush()
        await self.db.refresh(contract)
        return contract

    async def submit_for_approval(self, contract_id: int) -> Optional[Contract]:
        """Move a draft or rejected contract to ``PENDING_APPROVAL``.

        Returns:
            The updated contract, or ``None`` if it does not exist.

        Raises:
            InvalidContractStatusError: If the contract is neither draft nor
                rejected.
        """
        contract = await self.get_contract_by_id(contract_id)
        if contract is None:
            return None
        # Only draft or rejected contracts can be submitted.
        if contract.status not in (ContractStatus.DRAFT, ContractStatus.REJECTED):
            raise InvalidContractStatusError(
                current_status=contract.status.value,
                required_status="draft or rejected",
            )
        # Update the already-loaded instance directly instead of going through
        # update_contract, which would re-SELECT the same row.
        contract.status = ContractStatus.PENDING_APPROVAL
        await self.db.flush()
        await self.db.refresh(contract)
        return contract

    async def approve_contract(
        self,
        contract_id: int,
        approver_id: int,
        comment: Optional[str] = None,
    ) -> ContractApproval:
        """Approve a pending contract and record the approval.

        Raises:
            NotFoundError: If the contract does not exist.
            InvalidContractStatusError: If it is not pending approval.
        """
        return await self._record_decision(
            contract_id, approver_id, ContractStatus.APPROVED, comment
        )

    async def reject_contract(
        self,
        contract_id: int,
        approver_id: int,
        comment: Optional[str] = None,
    ) -> ContractApproval:
        """Reject a pending contract and record the rejection.

        Raises:
            NotFoundError: If the contract does not exist.
            InvalidContractStatusError: If it is not pending approval.
        """
        return await self._record_decision(
            contract_id, approver_id, ContractStatus.REJECTED, comment
        )

    async def _record_decision(
        self,
        contract_id: int,
        approver_id: int,
        decision: ContractStatus,
        comment: Optional[str],
    ) -> ContractApproval:
        """Store an approval record with ``decision`` and apply it to the contract.

        NOTE(review): the status check below reads the row without a lock, so
        two concurrent decisions on the same contract may both pass it,
        depending on the database isolation level. Consider
        ``select(...).with_for_update()`` if that matters.
        """
        contract = await self.get_contract_or_raise(contract_id)
        # Only contracts awaiting approval can be approved or rejected.
        if contract.status != ContractStatus.PENDING_APPROVAL:
            raise InvalidContractStatusError(
                current_status=contract.status.value,
                required_status="pending_approval",
            )
        approval = ContractApproval(
            contract_id=contract_id,
            approver_id=approver_id,
            status=decision,
            comment=comment,
        )
        self.db.add(approval)
        contract.status = decision
        await self.db.flush()
        await self.db.refresh(approval)
        return approval


class ContractTemplateService:
    """Service for contract template operations.

    Like ``ContractService``, writes are flushed but not committed here.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_template(
        self,
        name: str,
        content: str,
        created_by: int,
        contract_type: Optional[str] = None,
    ) -> ContractTemplate:
        """Create a contract template and return it after flushing.

        Args:
            name: Template name.
            content: Template body.
            created_by: ID of the user creating the template.
            contract_type: Optional free-form contract type label.
        """
        template = ContractTemplate(
            name=name,
            content=content,
            created_by=created_by,
            contract_type=contract_type,
        )
        self.db.add(template)
        await self.db.flush()
        await self.db.refresh(template)
        return template

    async def get_templates_list(
        self,
        skip: int = 0,
        limit: int = 20,
    ) -> List[ContractTemplate]:
        """Return a page of templates ordered by ``created_at`` descending."""
        result = await self.db.execute(
            select(ContractTemplate)
            .order_by(ContractTemplate.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return list(result.scalars().all())