root 315326d0a2 feat(middleware): add auth, logging, and audit middleware
- Add authentication middleware with API key validation
- Add request logging middleware for observability
- Add audit logging middleware for admin operations
- Refactor API endpoints to use centralized auth middleware
- Add comprehensive unit tests for all middleware
- Add API documentation and deployment guide
- Update README with health endpoints and documentation links
- Fix test data isolation in router tests

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-03 03:23:37 +08:00

88 lines
2.9 KiB
Python

"""Request logging middleware."""
import time
from decimal import Decimal
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.db.database import AsyncSessionLocal
from app.models.usage import RequestLog
from app.utils.logging import get_logger
logger = get_logger(__name__)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that records every ``/v1/`` API request as a ``RequestLog`` row.

    Per-request details (API key id, project id, provider, model, token
    counts, cost) are read from ``request.state`` when present and fall back
    to defaults otherwise. Presumably route handlers or earlier middleware
    set these attributes; verify this against the handlers. Recording is
    best-effort: a failure while writing the log is reported through the
    module logger and never changes the response returned to the client.

    NOTE(review): for streaming responses, ``call_next`` returns as soon as
    the response starts. ``latency_ms`` then reflects time-to-first-byte, and
    any ``request.state`` values set while the body streams are not yet
    visible here. Confirm this is acceptable.
    """

    # Only requests whose path starts with this prefix are recorded, so
    # health checks, docs, etc. are skipped.
    _LOGGED_PATH_PREFIX = "/v1/"

    # (path fragment, request_type) pairs tested in order. The first fragment
    # found anywhere in the path wins; otherwise the type is "unknown".
    _REQUEST_TYPES = (
        ("/chat/completions", "chat"),
        ("/messages", "messages"),
        ("/responses", "responses"),
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the downstream app, then record the request if it is an API call.

        Args:
            request: The incoming request.
            call_next: Continuation that invokes the rest of the ASGI stack.

        Returns:
            The downstream response, unchanged.

        Raises:
            Whatever ``call_next`` raises. For API paths, the failed request
            is first recorded (best-effort) with status 500.
        """
        # perf_counter is monotonic, so wall-clock adjustments (NTP, manual
        # changes) cannot yield negative or inflated latencies like time.time().
        start = time.perf_counter()
        should_log = request.url.path.startswith(self._LOGGED_PATH_PREFIX)
        try:
            response = await call_next(request)
        except Exception:
            # Unhandled errors propagate out of call_next. Without this branch,
            # failed requests would be missing from the request log entirely.
            # The synthetic 500 presumably matches what the outer error
            # handler sends; confirm if custom exception handlers differ.
            if should_log:
                latency_ms = int((time.perf_counter() - start) * 1000)
                await self._log_request_safely(
                    request, Response(status_code=500), latency_ms
                )
            raise
        if should_log:
            latency_ms = int((time.perf_counter() - start) * 1000)
            await self._log_request_safely(request, response, latency_ms)
        return response

    async def _log_request_safely(
        self,
        request: Request,
        response: Response,
        latency_ms: int,
    ) -> None:
        """Call ``_log_request``, reporting (not raising) any failure."""
        try:
            await self._log_request(request, response, latency_ms)
        except Exception as e:
            # Best-effort: a logging failure must never break the API call.
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Failed to log request: {e}")

    @classmethod
    def _classify_request(cls, path: str) -> str:
        """Map a URL path to a coarse request type, or ``"unknown"``."""
        return next(
            (kind for fragment, kind in cls._REQUEST_TYPES if fragment in path),
            "unknown",
        )

    async def _log_request(
        self,
        request: Request,
        response: Response,
        latency_ms: int,
    ) -> None:
        """Persist one ``RequestLog`` row describing *request* and *response*.

        Args:
            request: The handled request. Its ``state`` supplies the key,
                project, provider, model, token and cost attributes when set.
            response: The response sent; only ``status_code`` is used.
            latency_ms: Elapsed handling time in whole milliseconds.

        Raises:
            Any exception from building the row or committing it. Callers
            treat logging as best-effort.
        """
        state = request.state
        model = getattr(state, "model", "unknown")
        log = RequestLog(
            virtual_key_id=getattr(state, "api_key_id", None),
            project_id=getattr(state, "project_id", None),
            provider=getattr(state, "provider", "unknown"),
            model=model,
            # NOTE(review): the alias always mirrors the model name here;
            # confirm whether the real alias should be recorded separately.
            model_alias=model,
            request_type=self._classify_request(request.url.path),
            input_tokens=getattr(state, "input_tokens", 0),
            output_tokens=getattr(state, "output_tokens", 0),
            total_tokens=getattr(state, "total_tokens", 0),
            status_code=response.status_code,
            latency_ms=latency_ms,
            cost_usd=getattr(state, "cost", Decimal("0")),
        )
        async with AsyncSessionLocal() as session:
            session.add(log)
            await session.commit()
def setup_request_logging(app) -> None:
    """Register ``RequestLoggingMiddleware`` on *app*.

    Args:
        app: The application to instrument. Presumably a FastAPI/Starlette
            app; anything exposing ``add_middleware`` is accepted.
    """
    middleware_cls = RequestLoggingMiddleware
    app.add_middleware(middleware_cls)