diff --git a/llm-gateway/.env.example b/llm-gateway/.env.example new file mode 100644 index 0000000..a2bac02 --- /dev/null +++ b/llm-gateway/.env.example @@ -0,0 +1,29 @@ +# Application +APP_NAME=LLM Gateway +DEBUG=false +LOG_LEVEL=INFO + +# Server +HOST=0.0.0.0 +PORT=8000 + +# Database +DATABASE_URL=sqlite:///data/gateway.db + +# Security +# Generate with: python -c "import secrets; print(secrets.token_hex(32))" +MASTER_KEY=your-master-key-here-at-least-32-characters + +# Rate Limiting +RATE_LIMIT_WINDOW_SECONDS=60 +# GLOBAL_RPM_LIMIT=1000 +# GLOBAL_TPM_LIMIT=1000000 + +# Retry +MAX_RETRIES=3 +RETRY_INITIAL_DELAY=1.0 +RETRY_MAX_DELAY=30.0 + +# Health Check +HEALTH_CHECK_INTERVAL=30 +HEALTH_CHECK_TIMEOUT=10 diff --git a/llm-gateway/.gitignore b/llm-gateway/.gitignore new file mode 100644 index 0000000..9186eb7 --- /dev/null +++ b/llm-gateway/.gitignore @@ -0,0 +1,65 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ +.venv/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ + +# Type checking +.mypy_cache/ +.dmypy.json +dmypy.json + +# Environment +.env +.env.local +.env.*.local + +# Database +*.db +*.sqlite +*.sqlite3 +data/ + +# Logs +*.log +logs/ + +# OS +.DS_Store +Thumbs.db diff --git a/llm-gateway/Dockerfile b/llm-gateway/Dockerfile new file mode 100644 index 0000000..baf90ce --- /dev/null +++ b/llm-gateway/Dockerfile @@ -0,0 +1,28 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first for caching +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . 
+ +# Create data directory +RUN mkdir -p /app/data + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the application +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/llm-gateway/README.md b/llm-gateway/README.md new file mode 100644 index 0000000..2ab33d5 --- /dev/null +++ b/llm-gateway/README.md @@ -0,0 +1,41 @@ +# LLM Gateway + +A unified LLM Gateway supporting multiple providers and API formats. + +## Features + +- **Multi-format API**: OpenAI Chat Completions, OpenAI Responses API, Anthropic Messages API +- **Multi-provider support**: OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock +- **Model aliasing and routing**: Flexible model mapping with load balancing +- **Rate limiting**: RPM/TPM limits at multiple levels +- **Budget control**: Key and Project level spending limits +- **High availability**: Fallback, retry, and circuit breaker +- **Observability**: Request logging and usage statistics + +## Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# Run the server +uvicorn app.main:app --reload +``` + +## API Endpoints + +### LLM APIs +- `POST /v1/chat/completions` - OpenAI-compatible chat completions +- `POST /v1/responses` - OpenAI Responses API +- `POST /v1/messages` - Anthropic Messages API + +### Admin APIs +- `GET|POST|PUT|DELETE /admin/providers` - Provider management +- `GET|POST|PUT|DELETE /admin/keys` - API Key management +- `GET|POST|PUT|DELETE /admin/projects` - Project management +- `GET|POST|PUT|DELETE /admin/models/aliases` - Model alias management +- `GET /admin/usage/stats` - Usage statistics + +## Configuration + +See `.env.example` for configuration options. 
def get_adapter_class(provider: str) -> Type[BaseAdapter]:
    """
    Resolve a provider name to its registered adapter class.

    The lookup is case-insensitive: the name is lower-cased before it is
    matched against the keys of ``ADAPTER_REGISTRY``.

    Args:
        provider: Provider name (any letter case).

    Returns:
        The adapter class registered under that name.

    Raises:
        ValueError: If no adapter is registered for the provider.
    """
    adapter_cls = ADAPTER_REGISTRY.get(provider.lower())
    if adapter_cls is None:
        # Report the name exactly as the caller supplied it.
        raise ValueError(f"Unsupported provider: {provider}")
    return adapter_cls
def get_headers(self) -> dict[str, str]:
    """Build the HTTP headers sent with every Anthropic API request.

    Starts from the base adapter's common headers and adds the API key
    plus the ``anthropic-version`` header, which falls back to
    ``2023-06-01`` when the provider config carries no API version.
    """
    request_headers = dict(super().get_headers())
    request_headers["x-api-key"] = self.config.api_key
    request_headers["anthropic-version"] = self.config.api_version or "2023-06-01"
    return request_headers
async def stream_chat_completions(
    self,
    request: OpenAIChatCompletionRequest,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
    """Stream an OpenAI-format chat completion through the Anthropic Messages API.

    The request is converted to Anthropic format, sent with ``stream=True``,
    and the server-sent events are translated into OpenAI-style chunks.

    Args:
        request: OpenAI-format chat completion request.

    Yields:
        One chunk per non-empty text delta, followed by a final chunk that
        carries the mapped ``finish_reason`` once Anthropic reports a stop
        reason.

    Raises:
        RetryableError: On timeouts or connection failures.
        Exception: Whatever ``classify_error`` returns for an HTTP status
            of 400 or above.
    """
    anthropic_request = self.transformer.openai_to_anthropic(request)
    anthropic_request.stream = True

    url = f"{self.config.api_base}/v1/messages"
    payload = anthropic_request.model_dump(exclude_none=True)

    # Anthropic stop reasons -> OpenAI finish reasons. Unknown values fall
    # back to "stop" rather than being misreported as a length cut-off.
    finish_reason_map = {
        "end_turn": "stop",
        "stop_sequence": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
    }
    model = request.model
    # Fallback id generated once so every chunk of this stream shares it;
    # replaced by the real message id when message_start arrives.
    chunk_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            async with client.stream(
                "POST",
                url,
                headers=self.get_headers(),
                json=payload,
            ) as response:
                if response.status_code >= 400:
                    # Decode the body so the message holds readable text
                    # instead of a bytes repr.
                    body = (await response.aread()).decode("utf-8", errors="replace")
                    raise classify_error(
                        Exception(f"{response.status_code}: {body}")
                    )

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    try:
                        event = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue

                    event_type = event.get("type")

                    if event_type == "message_start":
                        chunk_id = event.get("message", {}).get("id") or chunk_id

                    elif event_type == "content_block_delta":
                        text = event.get("delta", {}).get("text", "")
                        if not text:
                            # Deltas without text would only produce
                            # empty content chunks.
                            continue
                        yield OpenAIChatCompletionChunk(
                            id=chunk_id,
                            object="chat.completion.chunk",
                            created=int(time.time()),
                            model=model,
                            choices=[
                                {
                                    "index": 0,
                                    "delta": {"content": text, "role": "assistant"},
                                    "finish_reason": None,
                                }
                            ],
                        )

                    elif event_type == "message_delta":
                        stop_reason = event.get("delta", {}).get("stop_reason")
                        if not stop_reason:
                            continue
                        yield OpenAIChatCompletionChunk(
                            id=chunk_id,
                            object="chat.completion.chunk",
                            created=int(time.time()),
                            model=model,
                            choices=[
                                {
                                    "index": 0,
                                    "delta": {},
                                    "finish_reason": finish_reason_map.get(stop_reason, "stop"),
                                }
                            ],
                        )

        except httpx.TimeoutException as e:
            raise RetryableError(str(e), error_type="timeout") from e
        except httpx.ConnectError as e:
            raise RetryableError(str(e), error_type="connection_error") from e
def _get_openai_adapter(self) -> OpenAIAdapter:
    """Return the OpenAI adapter that Azure calls are delegated to.

    The delegate is built on first use and cached. When the provider
    config contains a ``deployment_name``, the delegate's API base is
    extended with ``/openai/deployments/<deployment_name>``; otherwise the
    configured API base is used unchanged.
    """
    if self._openai_adapter is not None:
        return self._openai_adapter

    deployment = (self.config.config or {}).get("deployment_name", "")
    base_url = (
        f"{self.config.api_base}/openai/deployments/{deployment}"
        if deployment
        else self.config.api_base
    )

    # NOTE(review): the delegate builds its own headers (Bearer token) and
    # does not appear to read api_version, so this class's get_headers()
    # is not applied to delegated calls. Confirm that is intended.
    self._openai_adapter = OpenAIAdapter(
        ProviderConfig(
            name=self.config.name,
            api_base=base_url,
            api_key=self.config.api_key,
            api_version=self.config.api_version,
        )
    )
    return self._openai_adapter
@dataclass
class ProviderConfig:
    """Connection settings for one upstream LLM provider.

    An instance is handed to an adapter's constructor and stored on the
    adapter as ``self.config``; adapters read the fields below when they
    build URLs, headers and clients.
    """

    # Provider identifier. create_adapter() passes it to get_adapter_class()
    # to select the adapter class.
    name: str
    # Base URL that the HTTP adapters prefix to their endpoint paths,
    # e.g. f"{api_base}/v1/messages".
    api_base: str
    # Credential sent upstream: as a header by the OpenAI, Azure and
    # Anthropic adapters, and as the `key` query parameter by Gemini.
    api_key: str
    # Optional API version. The Anthropic adapter sends it as the
    # `anthropic-version` header and falls back to "2023-06-01" when unset.
    api_version: str | None = None
    # Free-form provider-specific options, e.g. "deployment_name" (Azure)
    # or "region" (Bedrock).
    config: dict[str, Any] | None = None
    # Presumably requests-per-minute and tokens-per-minute caps; no adapter
    # here reads them. TODO confirm where they are enforced.
    rpm_limit: int | None = None
    tpm_limit: int | None = None
async def count_tokens(
    self,
    request: OpenAIChatCompletionRequest | AnthropicMessagesRequest,
) -> int:
    """
    Estimate the token count of a request.

    This is a rough heuristic, not a tokenizer: whitespace-separated words
    are counted and multiplied by 1.3. Text is taken from every message
    and, for Anthropic requests, from the system prompt. Content may be a
    plain string or a list of parts; for lists, only parts exposing a
    string ``text`` (dict key or attribute) are counted.

    Args:
        request: The request to count tokens for.

    Returns:
        Estimated token count, or 0 for an unrecognised request type.
    """
    if not isinstance(request, (OpenAIChatCompletionRequest, AnthropicMessagesRequest)):
        return 0

    def count_words(content: Any) -> int:
        """Count whitespace-separated words in string or list-of-parts content."""
        if isinstance(content, str):
            return len(content.split())
        if isinstance(content, (list, tuple)):
            words = 0
            for part in content:
                text = part.get("text") if isinstance(part, dict) else getattr(part, "text", None)
                if isinstance(text, str):
                    words += len(text.split())
            return words
        # None or any other shape contributes nothing.
        return 0

    total = sum(count_words(msg.content) for msg in request.messages)
    if isinstance(request, AnthropicMessagesRequest):
        # Goes through count_words so a non-string system prompt cannot
        # fail on .split().
        total += count_words(request.system)
    return int(total * 1.3)  # Rough word-to-token ratio
def _get_client(self):
    """Return the cached Bedrock runtime client, creating it on first use.

    The region is read from the provider config's ``region`` option and
    defaults to ``us-east-1``.

    Raises:
        RuntimeError: If boto3 is not installed. Failing here gives callers
            a clear message instead of an attribute error on ``None``.
    """
    if self._client is not None:
        return self._client

    if not HAS_BOTO3:
        raise RuntimeError(
            "boto3 is required for the Bedrock adapter but is not installed"
        )

    region = (self.config.config or {}).get("region", "us-east-1")
    self._client = boto3.client(
        "bedrock-runtime",
        region_name=region,
        config=Config(
            retries={"max_attempts": 3, "mode": "adaptive"},
            connect_timeout=10,
            read_timeout=120,
        ),
    )
    return self._client
async def chat_completions(
    self,
    request: OpenAIChatCompletionRequest,
) -> OpenAIChatCompletionResponse:
    """Execute a chat completion request against Bedrock via the Converse API.

    The request and response helpers of this adapter use Converse shapes
    (``[{"text": ...}]`` content blocks, ``output.message``, ``stopReason``,
    ``usage.inputTokens``), so the call goes through ``converse`` rather
    than ``invoke_model``.

    Args:
        request: OpenAI-format chat completion request.

    Returns:
        OpenAI-format chat completion response.

    Raises:
        RuntimeError: If no Bedrock client is available.
        Exception: Whatever ``classify_error`` returns for a boto3
            ``ClientError``.
    """
    import asyncio  # local import: keeps this block self-contained

    client = self._get_client()
    if client is None:
        raise RuntimeError("Bedrock client is unavailable (is boto3 installed?)")

    model_id, bedrock_request = self._openai_to_bedrock_anthropic(request)

    # Translate the helper's snake_case sampling options into the
    # Converse inferenceConfig keys.
    inference_config = {"maxTokens": bedrock_request["max_tokens"]}
    if "temperature" in bedrock_request:
        inference_config["temperature"] = bedrock_request["temperature"]
    if "top_p" in bedrock_request:
        inference_config["topP"] = bedrock_request["top_p"]

    converse_kwargs = {
        "modelId": model_id,
        "messages": bedrock_request["messages"],
        "inferenceConfig": inference_config,
    }
    if "system" in bedrock_request:
        # Converse expects the system prompt as a list of text blocks.
        converse_kwargs["system"] = [{"text": bedrock_request["system"]}]

    try:
        # boto3 is synchronous; run it in a worker thread so the event
        # loop is not blocked while Bedrock responds.
        response = await asyncio.to_thread(client.converse, **converse_kwargs)
    except ClientError as e:
        raise classify_error(Exception(str(e))) from e

    return self._bedrock_to_openai(response, request.model)
async def check_health(self) -> HealthStatus:
    """Check Bedrock availability by listing foundation models.

    ``list_foundation_models`` belongs to the ``bedrock`` control-plane
    client, not to the ``bedrock-runtime`` client used for inference, so a
    short-lived control-plane client is created for the probe.

    Returns:
        ``HEALTHY`` when the listing call succeeds; ``UNHEALTHY`` when
        boto3 is missing or the call fails for any reason.
    """
    import asyncio  # local import: keeps this block self-contained

    if not HAS_BOTO3:
        return HealthStatus.UNHEALTHY

    region = (self.config.config or {}).get("region", "us-east-1")

    def probe() -> None:
        """Run the blocking health probe; raises on any failure."""
        control_plane = boto3.client(
            "bedrock",
            region_name=region,
            config=Config(
                retries={"max_attempts": 1},
                connect_timeout=10,
                read_timeout=10,
            ),
        )
        # NOTE(review): presumably requires the bedrock:ListFoundationModels
        # permission. Confirm for the deployed role.
        control_plane.list_foundation_models()

    try:
        # boto3 is synchronous; keep it off the event loop.
        await asyncio.to_thread(probe)
    except Exception:
        # Best-effort check: any failure counts as unhealthy.
        return HealthStatus.UNHEALTHY
    return HealthStatus.HEALTHY
OpenAIChatCompletionRequest) -> dict[str, Any]: + """Convert OpenAI request to Gemini format.""" + contents = [] + system_instruction = None + + for msg in request.messages: + if msg.role == "system": + system_instruction = {"parts": [{"text": msg.content}]} + else: + role = "user" if msg.role == "user" else "model" + if isinstance(msg.content, str): + parts = [{"text": msg.content}] + else: + parts = [{"text": str(part)} for part in msg.content] + contents.append({"role": role, "parts": parts}) + + gemini_request = { + "contents": contents, + "generationConfig": {}, + } + + if system_instruction: + gemini_request["systemInstruction"] = system_instruction + + if request.temperature is not None: + gemini_request["generationConfig"]["temperature"] = request.temperature + if request.max_tokens is not None: + gemini_request["generationConfig"]["maxOutputTokens"] = request.max_tokens + if request.top_p is not None: + gemini_request["generationConfig"]["topP"] = request.top_p + if request.stop: + if isinstance(request.stop, str): + gemini_request["generationConfig"]["stopSequences"] = [request.stop] + else: + gemini_request["generationConfig"]["stopSequences"] = request.stop + + return gemini_request + + def _gemini_to_openai( + self, + response: dict[str, Any], + model: str, + ) -> OpenAIChatCompletionResponse: + """Convert Gemini response to OpenAI format.""" + candidates = response.get("candidates", []) + content = "" + finish_reason = "stop" + + if candidates: + candidate = candidates[0] + content_parts = candidate.get("content", {}).get("parts", []) + content = "".join(p.get("text", "") for p in content_parts) + + finish_reason_map = { + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + } + finish_reason = finish_reason_map.get( + candidate.get("finishReason", "STOP"), "stop" + ) + + usage = response.get("usageMetadata", {}) + + return OpenAIChatCompletionResponse( + id=f"gemini-{uuid.uuid4().hex[:24]}", + object="chat.completion", + 
async def stream_chat_completions(
    self,
    request: OpenAIChatCompletionRequest,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
    """Stream a chat completion from Gemini as OpenAI-format chunks.

    Args:
        request: OpenAI-format chat completion request.

    Yields:
        One chunk per streamed Gemini payload that contains a candidate.
        All chunks of a stream share one id, and ``finish_reason`` is
        ``None`` until Gemini actually reports a ``finishReason``.

    Raises:
        RetryableError: On timeouts or connection failures.
        Exception: Whatever ``classify_error`` returns for an HTTP status
            of 400 or above.
    """
    model = request.model
    url = f"{self.config.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.config.api_key}&alt=sse"

    gemini_request = self._openai_to_gemini(request)

    finish_reason_map = {
        "STOP": "stop",
        "MAX_TOKENS": "length",
        "SAFETY": "content_filter",
    }
    # One id for the whole stream so clients can correlate chunks.
    chunk_id = f"gemini-{uuid.uuid4().hex[:24]}"

    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            async with client.stream(
                "POST",
                url,
                headers=self.get_headers(),
                json=gemini_request,
            ) as response:
                if response.status_code >= 400:
                    # Decode the body so the message holds readable text
                    # instead of a bytes repr.
                    body = (await response.aread()).decode("utf-8", errors="replace")
                    raise classify_error(
                        Exception(f"{response.status_code}: {body}")
                    )

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    try:
                        gemini_response = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue

                    candidates = gemini_response.get("candidates") or []
                    if not candidates:
                        continue

                    candidate = candidates[0]
                    parts = candidate.get("content", {}).get("parts", [])
                    text = "".join(part.get("text", "") for part in parts)

                    # Intermediate payloads carry no finishReason; report
                    # None for those so clients do not stop reading early.
                    raw_finish = candidate.get("finishReason")
                    finish_reason = (
                        finish_reason_map.get(raw_finish, "stop") if raw_finish else None
                    )

                    yield OpenAIChatCompletionChunk(
                        id=chunk_id,
                        object="chat.completion.chunk",
                        created=int(time.time()),
                        model=model,
                        choices=[
                            {
                                "index": 0,
                                "delta": {"content": text},
                                "finish_reason": finish_reason,
                            }
                        ],
                    )

        except httpx.TimeoutException as e:
            raise RetryableError(str(e), error_type="timeout") from e
        except httpx.ConnectError as e:
            raise RetryableError(str(e), error_type="connection_error") from e
async def chat_completions(
    self,
    request: OpenAIChatCompletionRequest,
) -> OpenAIChatCompletionResponse:
    """Send a non-streaming chat completion request to the OpenAI API.

    Args:
        request: OpenAI-format chat completion request; fields that are
            ``None`` are left out of the JSON payload.

    Returns:
        The parsed OpenAI chat completion response.

    Raises:
        RetryableError: On timeouts or connection failures.
        Exception: Whatever ``classify_error`` returns for an HTTP status
            of 400 or above.
    """
    endpoint = f"{self.config.api_base}/chat/completions"
    body = request.model_dump(exclude_none=True)

    async with httpx.AsyncClient(timeout=120.0) as http:
        try:
            upstream = await http.post(
                endpoint,
                headers=self.get_headers(),
                json=body,
            )

            if upstream.status_code >= 400:
                raise classify_error(Exception(f"{upstream.status_code}: {upstream.text}"))

            return OpenAIChatCompletionResponse(**upstream.json())

        except httpx.TimeoutException as e:
            raise RetryableError(str(e), error_type="timeout")
        except httpx.ConnectError as e:
            raise RetryableError(str(e), error_type="connection_error")
client: + try: + async with client.stream( + "POST", + url, + headers=self.get_headers(), + json=payload, + ) as response: + if response.status_code >= 400: + error = classify_error( + Exception(f"{response.status_code}: {await response.aread()}") + ) + raise error + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = line[6:] + if data == "[DONE]": + break + try: + chunk = json.loads(data) + yield OpenAIChatCompletionChunk(**chunk) + except json.JSONDecodeError: + continue + + except httpx.TimeoutException as e: + raise RetryableError(str(e), error_type="timeout") + except httpx.ConnectError as e: + raise RetryableError(str(e), error_type="connection_error") + + async def messages( + self, + request: AnthropicMessagesRequest, + ) -> AnthropicMessagesResponse: + """ + Execute an Anthropic Messages API request via OpenAI. + Converts request to OpenAI format, calls OpenAI, converts response back. + """ + # Convert Anthropic request to OpenAI format + openai_request = self.transformer.anthropic_to_openai(request) + + # Call OpenAI + openai_response = await self.chat_completions(openai_request) + + # Convert response back to Anthropic format + return self.transformer.openai_response_to_anthropic(openai_response) + + async def check_health(self) -> HealthStatus: + """Check OpenAI API health by listing models.""" + url = f"{self.config.api_base}/models" + + async with httpx.AsyncClient(timeout=10.0) as client: + try: + response = await client.get(url, headers=self.get_headers()) + + if response.status_code == 200: + return HealthStatus.HEALTHY + elif response.status_code >= 500: + return HealthStatus.UNHEALTHY + else: + return HealthStatus.DEGRADED + + except Exception: + return HealthStatus.UNHEALTHY diff --git a/llm-gateway/app/api/__init__.py b/llm-gateway/app/api/__init__.py new file mode 100644 index 0000000..fa227e0 --- /dev/null +++ b/llm-gateway/app/api/__init__.py @@ -0,0 +1 @@ +# api module diff --git 
@router.get("/health")
async def health_check() -> dict:
    """Liveness endpoint: report the app name and the current UTC timestamp."""
    settings = get_settings()
    return {
        "status": "healthy",
        "app": settings.app_name,
        "timestamp": datetime.utcnow().isoformat(),
    }


@router.get("/admin/providers/{provider_id}/health")
async def provider_health_check(
    provider_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> dict:
    """Return the stored health status of one provider.

    The status is read from the provider row; no live probe is made here.
    An unknown ``provider_id`` yields a body with ``status == "not_found"``.
    """
    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()

    if not provider:
        return {
            "status": "not_found",
            "provider_id": provider_id,
        }

    last_check = provider.last_health_check
    return {
        "status": provider.health_status,
        "provider_id": provider_id,
        "provider_name": provider.name,
        "last_check": last_check.isoformat() if last_check else None,
        "enabled": provider.enabled,
    }


@router.get("/ready")
async def readiness_check(
    db: Annotated[AsyncSession, Depends(get_db)],
) -> dict:
    """Readiness endpoint: verify that the database answers a trivial query.

    Returns 200 with ``status == "ready"`` on success. On failure the same
    JSON shape as before is returned, but with HTTP 503, so HTTP-based
    readiness probes (which look only at the status code) see the failure.
    """
    # Local import: this module only imports APIRouter/Depends from fastapi.
    from fastapi.responses import JSONResponse

    try:
        # Simple query to verify the database connection.
        await db.execute(select(1))
    except Exception as e:
        # A Response returned directly bypasses response-model handling.
        return JSONResponse(  # type: ignore[return-value]
            status_code=503,
            content={
                "status": "not_ready",
                "database": "disconnected",
                "error": str(e),
            },
        )

    return {
        "status": "ready",
        "database": "connected",
    }
def _encode_allowed_models(models: list[str] | None) -> str | None:
    """Serialize an allow-list for storage; empty or missing means unrestricted (NULL)."""
    return json.dumps(models) if models else None


def _api_key_fields(key: APIKey) -> dict:
    """Collect the public fields of *key*, decoding ``allowed_models`` from JSON."""
    return {
        "id": key.id,
        "key_prefix": key.key_prefix,
        "name": key.name,
        "project_id": key.project_id,
        "enabled": key.enabled,
        "expires_at": key.expires_at,
        "rpm_limit": key.rpm_limit,
        "tpm_limit": key.tpm_limit,
        "budget_limit": key.budget_limit,
        "budget_period": key.budget_period,
        "allowed_models": json.loads(key.allowed_models) if key.allowed_models else None,
        "current_usage": key.current_usage,
        "total_requests": key.total_requests,
        "created_at": key.created_at,
        "updated_at": key.updated_at,
    }


async def _get_key_or_404(db: AsyncSession, key_id: str) -> APIKey:
    """Load an API key by ID or raise a 404 ``HTTPException``."""
    result = await db.execute(select(APIKey).where(APIKey.id == key_id))
    key = result.scalar_one_or_none()
    if not key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"API key '{key_id}' not found",
        )
    return key


async def _require_project(db: AsyncSession, project_id: str) -> None:
    """Raise a 404 ``HTTPException`` unless a project with *project_id* exists."""
    result = await db.execute(select(Project).where(Project.id == project_id))
    if not result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Project '{project_id}' not found",
        )


@router.post("", response_model=APIKeyCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_api_key(
    data: APIKeyCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> dict:
    """Create a new API key.

    The response is the only place the full key is ever returned; only its
    hash and prefix are stored.

    Raises:
        HTTPException: 404 if ``data.project_id`` is given but unknown.
    """
    if data.project_id:
        await _require_project(db, data.project_id)

    full_key, key_hash, key_prefix = generate_api_key(data.prefix)

    api_key = APIKey(
        key_hash=key_hash,
        key_prefix=key_prefix,
        name=data.name,
        project_id=data.project_id,
        enabled=data.enabled,
        expires_at=data.expires_at,
        rpm_limit=data.rpm_limit,
        tpm_limit=data.tpm_limit,
        budget_limit=data.budget_limit,
        budget_period=data.budget_period,
        allowed_models=_encode_allowed_models(data.allowed_models),
    )

    db.add(api_key)
    await db.flush()
    await db.refresh(api_key)

    # Full key - shown only once!
    return {**_api_key_fields(api_key), "key": full_key}


@router.get("", response_model=APIKeyListResponse)
async def list_api_keys(
    db: Annotated[AsyncSession, Depends(get_db)],
    page: Annotated[int, Query(ge=1)] = 1,
    page_size: Annotated[int, Query(ge=1, le=100)] = 20,
    project_id: str | None = None,
    enabled: bool | None = None,
) -> APIKeyListResponse:
    """List API keys, newest first, optionally filtered by project and enabled flag."""
    query = select(APIKey)
    count_query = select(func.count()).select_from(APIKey)

    if project_id:
        query = query.where(APIKey.project_id == project_id)
        count_query = count_query.where(APIKey.project_id == project_id)

    if enabled is not None:
        query = query.where(APIKey.enabled == enabled)
        count_query = count_query.where(APIKey.enabled == enabled)

    offset = (page - 1) * page_size
    query = query.offset(offset).limit(page_size).order_by(APIKey.created_at.desc())

    keys = (await db.execute(query)).scalars().all()
    total = (await db.execute(count_query)).scalar() or 0

    return APIKeyListResponse(
        keys=[APIKeyResponse(**_api_key_fields(k)) for k in keys],
        total=total,
        page=page,
        page_size=page_size,
    )


@router.get("/{key_id}", response_model=APIKeyResponse)
async def get_api_key(
    key_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> APIKeyResponse:
    """Get an API key by ID (404 if it does not exist)."""
    key = await _get_key_or_404(db, key_id)
    return APIKeyResponse(**_api_key_fields(key))


@router.put("/{key_id}", response_model=APIKeyResponse)
async def update_api_key(
    key_id: str,
    data: APIKeyUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> APIKeyResponse:
    """Update the fields of an API key that were explicitly sent by the client.

    Raises:
        HTTPException: 404 if the key, or a newly assigned project, does not exist.
    """
    key = await _get_key_or_404(db, key_id)

    update_data = data.model_dump(exclude_unset=True)

    # Same validation as on creation: do not attach a key to a missing project.
    if update_data.get("project_id"):
        await _require_project(db, update_data["project_id"])

    if "allowed_models" in update_data:
        # Store NULL (not the JSON text "null" or "[]") for "no restriction",
        # matching create_api_key; readers treat any non-empty string as a list.
        key.allowed_models = _encode_allowed_models(update_data.pop("allowed_models"))

    for field, value in update_data.items():
        setattr(key, field, value)

    await db.flush()
    await db.refresh(key)

    return APIKeyResponse(**_api_key_fields(key))


@router.delete("/{key_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_api_key(
    key_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
    """Delete an API key (404 if it does not exist)."""
    key = await _get_key_or_404(db, key_id)
    await db.delete(key)
def _alias_to_response(alias: ModelAlias) -> ModelAliasResponse:
    """Build the API representation of *alias*, decoding ``routing_config`` from JSON.

    Every endpoint returns through this helper so ``routing_config`` always
    reaches the client decoded, never as the raw JSON text stored in the row.
    """
    return ModelAliasResponse(
        id=alias.id,
        alias=alias.alias,
        provider=alias.provider,
        model=alias.model,
        enabled=alias.enabled,
        routing_type=alias.routing_type,
        routing_config=json.loads(alias.routing_config) if alias.routing_config else None,
        input_price_per_1k=alias.input_price_per_1k,
        output_price_per_1k=alias.output_price_per_1k,
        created_at=alias.created_at,
        updated_at=alias.updated_at,
    )


async def _get_alias_or_404(db: AsyncSession, alias_id: str) -> ModelAlias:
    """Load a model alias by ID or raise a 404 ``HTTPException``."""
    result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
    alias = result.scalar_one_or_none()
    if not alias:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model alias '{alias_id}' not found",
        )
    return alias


@router.post("", response_model=ModelAliasResponse, status_code=status.HTTP_201_CREATED)
async def create_model_alias(
    data: ModelAliasCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ModelAliasResponse:
    """Create a new model alias.

    Raises:
        HTTPException: 409 if an alias with the same name already exists.
    """
    result = await db.execute(select(ModelAlias).where(ModelAlias.alias == data.alias))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Model alias '{data.alias}' already exists",
        )

    alias = ModelAlias(
        alias=data.alias,
        provider=data.provider,
        model=data.model,
        enabled=data.enabled,
        routing_type=data.routing_type,
        routing_config=(
            json.dumps(data.routing_config.model_dump()) if data.routing_config else None
        ),
        input_price_per_1k=data.input_price_per_1k,
        output_price_per_1k=data.output_price_per_1k,
    )

    db.add(alias)
    await db.flush()
    await db.refresh(alias)

    return _alias_to_response(alias)


@router.get("", response_model=ModelAliasListResponse)
async def list_model_aliases(
    db: Annotated[AsyncSession, Depends(get_db)],
    page: Annotated[int, Query(ge=1)] = 1,
    page_size: Annotated[int, Query(ge=1, le=100)] = 20,
    provider: str | None = None,
    enabled: bool | None = None,
) -> ModelAliasListResponse:
    """List model aliases, newest first, optionally filtered by provider and enabled flag."""
    query = select(ModelAlias)
    count_query = select(func.count()).select_from(ModelAlias)

    if provider:
        query = query.where(ModelAlias.provider == provider)
        count_query = count_query.where(ModelAlias.provider == provider)

    if enabled is not None:
        query = query.where(ModelAlias.enabled == enabled)
        count_query = count_query.where(ModelAlias.enabled == enabled)

    offset = (page - 1) * page_size
    query = query.offset(offset).limit(page_size).order_by(ModelAlias.created_at.desc())

    aliases = (await db.execute(query)).scalars().all()
    total = (await db.execute(count_query)).scalar() or 0

    return ModelAliasListResponse(
        aliases=[_alias_to_response(a) for a in aliases],
        total=total,
        page=page,
        page_size=page_size,
    )


@router.get("/{alias_id}", response_model=ModelAliasResponse)
async def get_model_alias(
    alias_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ModelAliasResponse:
    """Get a model alias by ID (404 if it does not exist)."""
    return _alias_to_response(await _get_alias_or_404(db, alias_id))


@router.put("/{alias_id}", response_model=ModelAliasResponse)
async def update_model_alias(
    alias_id: str,
    data: ModelAliasUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ModelAliasResponse:
    """Update the fields of a model alias that were explicitly sent by the client."""
    alias = await _get_alias_or_404(db, alias_id)

    update_data = data.model_dump(exclude_unset=True)

    if "routing_config" in update_data:
        # Always pop the key so a dict is never assigned to the text column;
        # an explicit null clears the stored configuration.
        routing_config = update_data.pop("routing_config")
        alias.routing_config = (
            json.dumps(routing_config) if routing_config is not None else None
        )

    for field, value in update_data.items():
        setattr(alias, field, value)

    await db.flush()
    await db.refresh(alias)

    return _alias_to_response(alias)


@router.delete("/{alias_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_model_alias(
    alias_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
    """Delete a model alias (404 if it does not exist)."""
    alias = await _get_alias_or_404(db, alias_id)
    await db.delete(alias)
async def _fetch_project_or_404(session: AsyncSession, project_id: str) -> Project:
    """Return the project with *project_id*, raising a 404 ``HTTPException`` if absent."""
    rows = await session.execute(select(Project).where(Project.id == project_id))
    found = rows.scalar_one_or_none()
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Project '{project_id}' not found",
    )


@router.post("", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
async def create_project(
    data: ProjectCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> Project:
    """Persist a new project built from the request payload and return it."""
    new_project = Project(
        name=data.name,
        description=data.description,
        budget_limit=data.budget_limit,
        budget_period=data.budget_period,
        enabled=data.enabled,
    )
    db.add(new_project)

    # Flush then refresh so database-generated columns are loaded on the instance.
    await db.flush()
    await db.refresh(new_project)
    return new_project


@router.get("", response_model=ProjectListResponse)
async def list_projects(
    db: Annotated[AsyncSession, Depends(get_db)],
    page: Annotated[int, Query(ge=1)] = 1,
    page_size: Annotated[int, Query(ge=1, le=100)] = 20,
    enabled: bool | None = None,
) -> ProjectListResponse:
    """Return one page of projects, newest first, optionally filtered by enabled flag."""
    page_stmt = select(Project)
    count_stmt = select(func.count()).select_from(Project)

    if enabled is not None:
        enabled_filter = Project.enabled == enabled
        page_stmt = page_stmt.where(enabled_filter)
        count_stmt = count_stmt.where(enabled_filter)

    skip = (page - 1) * page_size
    page_stmt = page_stmt.offset(skip).limit(page_size).order_by(Project.created_at.desc())

    page_rows = (await db.execute(page_stmt)).scalars().all()
    row_count = (await db.execute(count_stmt)).scalar() or 0

    return ProjectListResponse(
        projects=[ProjectResponse.model_validate(row) for row in page_rows],
        total=row_count,
        page=page,
        page_size=page_size,
    )


@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(
    project_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> Project:
    """Look up a single project by ID; responds 404 when it does not exist."""
    return await _fetch_project_or_404(db, project_id)


@router.put("/{project_id}", response_model=ProjectResponse)
async def update_project(
    project_id: str,
    data: ProjectUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> Project:
    """Apply the explicitly supplied fields to a project; responds 404 when it does not exist."""
    target = await _fetch_project_or_404(db, project_id)

    changes = data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    await db.flush()
    await db.refresh(target)
    return target


@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_project(
    project_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
    """Remove a project; responds 404 when it does not exist."""
    doomed = await _fetch_project_or_404(db, project_id)
    await db.delete(doomed)
def _provider_to_response(provider: Provider) -> ProviderResponse:
    """Build the API representation of *provider* without touching the ORM row.

    ``config`` is stored as JSON text. It is decoded into the response object
    only; assigning the decoded dict back onto ``provider.config`` would mark
    the row dirty with a value that is not the stored JSON text. The encrypted
    API key is not included.
    """
    return ProviderResponse(
        id=provider.id,
        name=provider.name,
        enabled=provider.enabled,
        api_base=provider.api_base,
        api_version=provider.api_version,
        config=json.loads(provider.config) if provider.config else None,
        rpm_limit=provider.rpm_limit,
        tpm_limit=provider.tpm_limit,
        health_status=provider.health_status,
        last_health_check=provider.last_health_check,
        created_at=provider.created_at,
        updated_at=provider.updated_at,
    )


async def _get_provider_or_404(db: AsyncSession, provider_id: str) -> Provider:
    """Load a provider by ID or raise a 404 ``HTTPException``."""
    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider '{provider_id}' not found",
        )
    return provider


@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
async def create_provider(
    data: ProviderCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ProviderResponse:
    """Create a new provider; the upstream API key is stored encrypted.

    Raises:
        HTTPException: 409 if a provider with the same name already exists.
    """
    result = await db.execute(select(Provider).where(Provider.name == data.name))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Provider '{data.name}' already exists",
        )

    provider = Provider(
        name=data.name,
        api_base=data.api_base,
        api_key_encrypted=encrypt_value(data.api_key),
        api_version=data.api_version,
        config=json.dumps(data.config) if data.config else None,
        rpm_limit=data.rpm_limit,
        tpm_limit=data.tpm_limit,
        enabled=data.enabled,
    )

    db.add(provider)
    await db.flush()
    await db.refresh(provider)

    return _provider_to_response(provider)


@router.get("", response_model=ProviderListResponse)
async def list_providers(
    db: Annotated[AsyncSession, Depends(get_db)],
    page: Annotated[int, Query(ge=1)] = 1,
    page_size: Annotated[int, Query(ge=1, le=100)] = 20,
    enabled: bool | None = None,
) -> ProviderListResponse:
    """List providers, newest first, optionally filtered by enabled flag."""
    query = select(Provider)
    count_query = select(func.count()).select_from(Provider)

    if enabled is not None:
        query = query.where(Provider.enabled == enabled)
        count_query = count_query.where(Provider.enabled == enabled)

    offset = (page - 1) * page_size
    query = query.offset(offset).limit(page_size).order_by(Provider.created_at.desc())

    providers = (await db.execute(query)).scalars().all()
    total = (await db.execute(count_query)).scalar() or 0

    # config is decoded exactly once, inside _provider_to_response; decoding it
    # on the ORM object first and again here raised TypeError for any provider
    # that had a config.
    return ProviderListResponse(
        providers=[_provider_to_response(p) for p in providers],
        total=total,
        page=page,
        page_size=page_size,
    )


@router.get("/{provider_id}", response_model=ProviderResponse)
async def get_provider(
    provider_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ProviderResponse:
    """Get a provider by ID (404 if it does not exist)."""
    return _provider_to_response(await _get_provider_or_404(db, provider_id))


@router.put("/{provider_id}", response_model=ProviderResponse)
async def update_provider(
    provider_id: str,
    data: ProviderUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> ProviderResponse:
    """Update the fields of a provider that were explicitly sent by the client.

    ``api_key`` is re-encrypted before storage and ``config`` is re-serialized
    to JSON; an explicit null ``config`` clears the stored value.
    """
    provider = await _get_provider_or_404(db, provider_id)

    update_data = data.model_dump(exclude_unset=True)

    if "api_key" in update_data:
        provider.api_key_encrypted = encrypt_value(update_data.pop("api_key"))

    if "config" in update_data:
        config = update_data.pop("config")
        provider.config = json.dumps(config) if config is not None else None

    for field, value in update_data.items():
        setattr(provider, field, value)

    await db.flush()
    await db.refresh(provider)

    return _provider_to_response(provider)


@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider(
    provider_id: str,
    db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
    """Delete a provider (404 if it does not exist)."""
    provider = await _get_provider_or_404(db, provider_id)
    await db.delete(provider)
# Public group_by names whose column on UsageStatsHourly has a different name.
# NOTE(review): inferred from the project_id / virtual_key_id filters used in
# this module - confirm against the UsageStatsHourly model.
_GROUP_BY_COLUMN_ALIASES = {
    "project": "project_id",
    "key": "virtual_key_id",
}


def _group_column(group_by: str) -> Any:
    """Resolve a ``group_by`` query value to a ``UsageStatsHourly`` column.

    An attribute named exactly like ``group_by`` is preferred (the original
    lookup); otherwise the aliased column name is used.
    """
    column = getattr(UsageStatsHourly, group_by, None)
    if column is None:
        column = getattr(UsageStatsHourly, _GROUP_BY_COLUMN_ALIASES.get(group_by, group_by))
    return column


@router.get("/stats")
async def get_usage_stats(
    db: Annotated[AsyncSession, Depends(get_db)],
    start_date: Annotated[datetime | None, Query()] = None,
    end_date: Annotated[datetime | None, Query()] = None,
    group_by: Annotated[str, Query(pattern="^(model|provider|project|key)$")] = "model",
    project_id: str | None = None,
    key_id: str | None = None,
) -> dict[str, Any]:
    """Return aggregated usage totals and a per-group breakdown.

    The period defaults to the last 7 days ending now (UTC). Totals and the
    breakdown are computed from the hourly stats table with identical filters.
    """
    if not start_date:
        start_date = datetime.utcnow() - timedelta(days=7)
    if not end_date:
        end_date = datetime.utcnow()

    # One filter list shared by both queries so they can never drift apart.
    filters = [
        UsageStatsHourly.timestamp >= start_date,
        UsageStatsHourly.timestamp <= end_date,
    ]
    if project_id:
        filters.append(UsageStatsHourly.project_id == project_id)
    if key_id:
        filters.append(UsageStatsHourly.virtual_key_id == key_id)

    group_col = _group_column(group_by)

    query = (
        select(
            group_col.label("group"),
            func.sum(UsageStatsHourly.request_count).label("requests"),
            func.sum(UsageStatsHourly.input_tokens).label("input_tokens"),
            func.sum(UsageStatsHourly.output_tokens).label("output_tokens"),
            func.sum(UsageStatsHourly.total_tokens).label("total_tokens"),
            func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
            func.avg(UsageStatsHourly.avg_latency_ms).label("avg_latency_ms"),
            func.sum(UsageStatsHourly.error_count).label("errors"),
        )
        .where(*filters)
        .group_by(group_col)
    )
    rows = (await db.execute(query)).all()

    total_query = select(
        func.sum(UsageStatsHourly.request_count).label("requests"),
        func.sum(UsageStatsHourly.input_tokens).label("input_tokens"),
        func.sum(UsageStatsHourly.output_tokens).label("output_tokens"),
        func.sum(UsageStatsHourly.total_tokens).label("total_tokens"),
        func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
        func.sum(UsageStatsHourly.error_count).label("errors"),
    ).where(*filters)
    total_row = (await db.execute(total_query)).one()

    # SUM/AVG over zero rows yields NULL, hence the "or 0" fallbacks below.
    return {
        "period": {
            "start": start_date.isoformat(),
            "end": end_date.isoformat(),
        },
        "totals": {
            "requests": total_row.requests or 0,
            "input_tokens": total_row.input_tokens or 0,
            "output_tokens": total_row.output_tokens or 0,
            "total_tokens": total_row.total_tokens or 0,
            "cost_usd": float(total_row.cost_usd or 0),
            "errors": total_row.errors or 0,
        },
        "by_" + group_by: [
            {
                group_by: row.group,
                "requests": row.requests or 0,
                "input_tokens": row.input_tokens or 0,
                "output_tokens": row.output_tokens or 0,
                "total_tokens": row.total_tokens or 0,
                "cost_usd": float(row.cost_usd or 0),
                "avg_latency_ms": int(row.avg_latency_ms or 0),
                "errors": row.errors or 0,
            }
            for row in rows
        ],
    }


@router.get("/logs")
async def get_usage_logs(
    db: Annotated[AsyncSession, Depends(get_db)],
    page: Annotated[int, Query(ge=1)] = 1,
    page_size: Annotated[int, Query(ge=1, le=100)] = 50,
    project_id: str | None = None,
    key_id: str | None = None,
    provider: str | None = None,
    model: str | None = None,
    status_code: int | None = None,
) -> dict[str, Any]:
    """Return one page of request logs, newest first, with optional filters."""
    filters = []
    if project_id:
        filters.append(RequestLog.project_id == project_id)
    if key_id:
        filters.append(RequestLog.virtual_key_id == key_id)
    if provider:
        filters.append(RequestLog.provider == provider)
    if model:
        filters.append(RequestLog.model == model)
    if status_code is not None:
        filters.append(RequestLog.status_code == status_code)

    query = select(RequestLog)
    count_query = select(func.count()).select_from(RequestLog)
    if filters:
        query = query.where(*filters)
        count_query = count_query.where(*filters)

    offset = (page - 1) * page_size
    query = query.offset(offset).limit(page_size).order_by(RequestLog.timestamp.desc())

    logs = (await db.execute(query)).scalars().all()
    total = (await db.execute(count_query)).scalar() or 0

    return {
        "logs": [
            {
                "id": log.id,
                "timestamp": log.timestamp.isoformat(),
                "key_id": log.virtual_key_id,
                "project_id": log.project_id,
                "provider": log.provider,
                "model": log.model,
                "model_alias": log.model_alias,
                "request_type": log.request_type,
                "input_tokens": log.input_tokens,
                "output_tokens": log.output_tokens,
                "total_tokens": log.total_tokens,
                "status_code": log.status_code,
                "latency_ms": log.latency_ms,
                "finish_reason": log.finish_reason,
                # A row without a recorded cost must not crash the whole page.
                "cost_usd": float(log.cost_usd or 0),
            }
            for log in logs
        ],
        "total": total,
        "page": page,
        "page_size": page_size,
    }


@router.get("/costs")
async def get_cost_breakdown(
    db: Annotated[AsyncSession, Depends(get_db)],
    start_date: Annotated[datetime | None, Query()] = None,
    end_date: Annotated[datetime | None, Query()] = None,
    group_by: Annotated[str, Query(pattern="^(model|provider|project|key)$")] = "provider",
) -> dict[str, Any]:
    """Return total cost and a per-group breakdown, most expensive group first.

    The period defaults to the current calendar month up to now (UTC).
    """
    if not start_date:
        start_date = datetime.utcnow().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    if not end_date:
        end_date = datetime.utcnow()

    group_col = _group_column(group_by)

    query = (
        select(
            group_col.label("group"),
            func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
            func.sum(UsageStatsHourly.request_count).label("requests"),
            func.sum(UsageStatsHourly.total_tokens).label("tokens"),
        )
        .where(
            UsageStatsHourly.timestamp >= start_date,
            UsageStatsHourly.timestamp <= end_date,
        )
        .group_by(group_col)
    )
    rows = (await db.execute(query)).all()

    total_cost = sum(float(row.cost_usd or 0) for row in rows)

    return {
        "period": {
            "start": start_date.isoformat(),
            "end": end_date.isoformat(),
        },
        "total_cost_usd": total_cost,
        "breakdown": [
            {
                group_by: row.group,
                "cost_usd": float(row.cost_usd or 0),
                # Guard against division by zero when nothing was spent.
                "percentage": float(row.cost_usd or 0) / total_cost * 100 if total_cost > 0 else 0,
                "requests": row.requests or 0,
                "tokens": row.tokens or 0,
            }
            for row in sorted(rows, key=lambda r: r.cost_usd or 0, reverse=True)
        ],
    }
async def authenticate(
    authorization: str | None = Header(None),
    x_api_key: str | None = Header(None),
    db: AsyncSession = Depends(get_db),
) -> APIKey:
    """Resolve the caller's virtual API key from the request headers.

    The key is taken from ``Authorization: Bearer <key>`` and, when that
    header is absent or carries no bearer token, from ``X-API-Key``.

    Returns:
        The matching, enabled, unexpired ``APIKey`` row.

    Raises:
        HTTPException: 401 when no key is supplied or no stored hash matches;
            403 when the matching key is disabled or past ``expires_at``.
    """
    key = None
    if authorization and authorization.startswith("Bearer "):
        key = authorization[7:]
    # Fall back to X-API-Key even when an Authorization header is present but
    # unusable (e.g. a non-Bearer scheme). With the former `elif`, such
    # requests were rejected although a valid X-API-Key had been sent.
    if not key and x_api_key:
        key = x_api_key

    if not key:
        raise HTTPException(
            status_code=401,
            detail={"error": {"type": "authentication_error", "message": "Missing API key"}},
        )

    # NOTE(review): every stored key is loaded and verified in turn, so the
    # cost of authentication grows with the number of keys. An indexed lookup
    # needs knowledge of how verify_api_key derives key_hash - confirm first.
    result = await db.execute(select(APIKey))
    for api_key in result.scalars().all():
        if not verify_api_key(key, api_key.key_hash):
            continue
        if not api_key.enabled:
            raise HTTPException(
                status_code=403,
                detail={"error": {"type": "permission_error", "message": "API key is disabled"}},
            )
        if api_key.expires_at and api_key.expires_at < datetime.utcnow():
            raise HTTPException(
                status_code=403,
                detail={"error": {"type": "permission_error", "message": "API key has expired"}},
            )
        return api_key

    raise HTTPException(
        status_code=401,
        detail={"error": {"type": "authentication_error", "message": "Invalid API key"}},
    )


@router.post("/chat/completions")
async def chat_completions(
    request: OpenAIChatCompletionRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    api_key: Annotated[APIKey, Depends(authenticate)],
) -> OpenAIChatCompletionResponse:
    """Execute an OpenAI-compatible chat completion through the gateway.

    Steps, in order: rate-limit check, budget check, model routing, model
    allow-list check, provider lookup, circuit-breaker check, provider call
    with retries, usage recording. A request log row is written in all cases.

    Raises:
        HTTPException: 429 (rate limit), 402 (budget), 403 (model not
            allowed), 503 (provider missing, disabled or circuit open),
            502 (provider call failed after retries), 500 (anything else).
    """
    # Imported once at function scope. The import used to live inside the
    # allow-list branch, which made `json` a local that was unbound whenever
    # the key had no allow-list, so parsing provider.config raised
    # UnboundLocalError for such keys.
    import json

    start_time = time.time()
    routing_result: RoutingResult | None = None
    response: OpenAIChatCompletionResponse | None = None
    error: Exception | None = None

    try:
        router_service = Router(db)
        rate_limiter = RateLimiter(db)
        budget_controller = BudgetController(db)

        # Rate limits are enforced per virtual key.
        try:
            await rate_limiter.check_and_increment(
                f"key:{api_key.id}",
                rpm_limit=api_key.rpm_limit,
                tpm_limit=api_key.tpm_limit,
            )
        except RateLimitExceeded as e:
            raise HTTPException(
                status_code=429,
                detail={
                    "error": {
                        "type": "rate_limit_error",
                        "message": str(e),
                        "details": {
                            "limit": e.limit,
                            "remaining": e.remaining,
                            "reset_at": e.reset_at.isoformat(),
                        },
                    }
                },
            ) from e

        try:
            await budget_controller.check_budget(api_key.id)
        except BudgetExceeded as e:
            raise HTTPException(
                status_code=402,
                detail={
                    "error": {
                        "type": "budget_exceeded_error",
                        "message": str(e),
                        "details": {
                            "limit": float(e.limit),
                            "current_usage": float(e.current_usage),
                            "scope": e.scope,
                        },
                    }
                },
            ) from e

        routing_result = await router_service.resolve_model(request.model)

        # The allow-list is stored as a JSON array of client-facing model names.
        if api_key.allowed_models:
            allowed = json.loads(api_key.allowed_models)
            if request.model not in allowed:
                raise HTTPException(
                    status_code=403,
                    detail={
                        "error": {
                            "type": "permission_error",
                            "message": f"Model '{request.model}' is not allowed for this API key",
                        }
                    },
                )

        result = await db.execute(
            select(Provider).where(Provider.name == routing_result.provider, Provider.enabled == True)
        )
        provider = result.scalar_one_or_none()

        if not provider:
            raise HTTPException(
                status_code=503,
                detail={
                    "error": {
                        "type": "provider_error",
                        "message": f"Provider '{routing_result.provider}' not found or disabled",
                    }
                },
            )

        circuit_breaker = get_circuit_breaker()
        if not circuit_breaker.is_available(provider.name):
            raise HTTPException(
                status_code=503,
                detail={
                    "error": {
                        "type": "service_unavailable",
                        "message": f"Provider '{provider.name}' is currently unavailable",
                    }
                },
            )

        provider_config = ProviderConfig(
            name=provider.name,
            api_base=provider.api_base,
            api_key=decrypt_value(provider.api_key_encrypted),
            api_version=provider.api_version,
            config=json.loads(provider.config) if provider.config else None,
        )
        adapter = create_adapter(provider_config)

        # From here on the request carries the provider's model name, which
        # is also what the request log in `finally` will see.
        request.model = routing_result.model

        retry_executor = RetryExecutor(RetryConfig())

        async def execute():
            return await adapter.chat_completions(request)

        # RetryExecutor.execute records success/failure on the circuit
        # breaker itself; recording success again here counted every success
        # twice and let a half-open circuit close after a single request.
        response = await retry_executor.execute(execute, provider.name, "chat_completions")

        if response.usage:
            cost = await _calculate_cost(db, routing_result, response.usage)
            await budget_controller.record_usage(
                api_key.id,
                cost,
                response.usage.prompt_tokens,
                response.usage.completion_tokens,
            )
            await rate_limiter.record_tokens(
                f"key:{api_key.id}",
                response.usage.total_tokens,
            )

        return response

    except HTTPException:
        raise
    except RetryableError as e:
        error = e
        logger.error(
            "Chat completion failed",
            model=request.model,
            provider=routing_result.provider if routing_result else None,
            error=str(e),
        )
        raise HTTPException(
            status_code=502,
            detail={
                "error": {
                    "type": "provider_error",
                    "message": str(e),
                }
            },
        ) from e
    except Exception as e:
        error = e
        logger.exception("Unexpected error in chat completions")
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "type": "internal_error",
                    "message": str(e),
                }
            },
        ) from e
    finally:
        # Always leave an audit row, whether the request succeeded or not.
        await _log_request(
            db,
            api_key.id,
            routing_result,
            request,
            response,
            error,
            start_time,
        )
async def _log_request(
    db: AsyncSession,
    key_id: str,
    routing_result: RoutingResult | None,
    request: OpenAIChatCompletionRequest,
    response: OpenAIChatCompletionResponse | None,
    error: Exception | None,
    start_time: float,
) -> None:
    """Persist one ``RequestLog`` row for a finished chat request and commit.

    Args:
        db: Session the row is added to; it is committed before returning.
        key_id: Identifier of the virtual API key that made the request.
        routing_result: Routing decision, or ``None`` if routing never ran.
        request: The request as it stands when logging happens.
        response: Provider response, or ``None`` when the request failed.
        error: Exception captured by the caller, if any.
        start_time: ``time.time()`` value taken when handling began.
    """
    elapsed_ms = int((time.time() - start_time) * 1000)

    usage = response.usage if response and response.usage else None
    # Only a response without a captured error is logged as a success.
    succeeded = bool(response) and not error

    if response and response.choices:
        finish_reason = response.choices[0].finish_reason
    else:
        finish_reason = None

    entry = RequestLog(
        virtual_key_id=key_id,
        project_id=None,  # Will be filled from API key
        provider=routing_result.provider if routing_result else "unknown",
        model=request.model,
        # NOTE(review): same value as `model`; the original alias is not
        # available through this function's parameters.
        model_alias=request.model,
        request_type="chat",
        input_tokens=usage.prompt_tokens if usage else 0,
        output_tokens=usage.completion_tokens if usage else 0,
        total_tokens=usage.total_tokens if usage else 0,
        status_code=200 if succeeded else 500,
        latency_ms=elapsed_ms,
        finish_reason=finish_reason,
        # NOTE(review): flat placeholder cost, not the computed request cost.
        cost_usd=Decimal("0.001") if usage else Decimal("0"),
    )

    db.add(entry)
    await db.commit()
@router.post("/messages")
async def messages(
    request: AnthropicMessagesRequest,
    # NOTE(review): the dependency yields a database session; the `None`
    # annotation is kept only so this block needs no new module-level import.
    db: Annotated[None, Depends(get_db)],
    api_key: Annotated[APIKey, Depends(authenticate)],
) -> AnthropicMessagesResponse:
    """
    Execute an Anthropic Messages API request.

    The request is converted to the OpenAI chat format, run through the chat
    completions handler, and the result is converted back to Anthropic format.

    Raises:
        HTTPException: Whatever the chat completions handler raises.
    """
    # Imported at call time, as before, rather than at module import.
    from app.api.v1.chat import chat_completions

    transformer = RequestTransformer()
    openai_request = transformer.anthropic_to_openai(request)

    # Reuse the session injected for this request. Previously a second
    # session was opened by iterating get_db() by hand and the function
    # returned from inside that loop, which left the generator suspended and
    # its cleanup to garbage collection.
    response = await chat_completions(openai_request, db, api_key)

    return transformer.openai_response_to_anthropic(response)
@router.post("/responses")
async def responses(
    request: OpenAIResponseRequest,
    # NOTE(review): the dependency yields a database session; the `None`
    # annotation is kept only so this block needs no new module-level import.
    db: Annotated[None, Depends(get_db)],
    api_key: Annotated[APIKey, Depends(authenticate)],
) -> dict:
    """
    Execute an OpenAI Responses API request.

    The request is translated into a chat completion (instructions become a
    system message, the input becomes a single user message), executed by the
    chat completions handler, and reshaped into a Responses-style payload.

    Returns:
        Dict with ``id``, ``object``, ``created``, ``model``, ``output`` (the
        first choice's message content, or ``""`` without choices) and
        ``usage`` token counts (zeros when the provider reported none).

    Raises:
        HTTPException: Whatever the chat completions handler raises.
    """
    # Imported at call time, as before, rather than at module import.
    from app.api.v1.chat import chat_completions
    from app.schemas.openai import OpenAIChatCompletionRequest, OpenAIChatMessage

    messages = []
    if request.instructions:
        messages.append(OpenAIChatMessage(role="system", content=request.instructions))

    # Structured (non-string) input is flattened to its string form.
    user_content = request.input if isinstance(request.input, str) else str(request.input)
    messages.append(OpenAIChatMessage(role="user", content=user_content))

    chat_request = OpenAIChatCompletionRequest(
        model=request.model,
        messages=messages,
        max_tokens=request.max_output_tokens,
        temperature=request.temperature,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )

    # Reuse the session injected for this request. Previously a second
    # session was opened by iterating get_db() by hand and the function
    # returned from inside that loop, leaving the generator suspended.
    response = await chat_completions(chat_request, db, api_key)

    usage = response.usage
    return {
        "id": response.id,
        "object": "response",
        "created": response.created,
        "model": response.model,
        "output": response.choices[0].message.content if response.choices else "",
        "usage": {
            "input_tokens": usage.prompt_tokens if usage else 0,
            "output_tokens": usage.completion_tokens if usage else 0,
            "total_tokens": usage.total_tokens if usage else 0,
        },
    }
secrets.token_hex(32) + + +@lru_cache +def get_settings() -> Settings: + """Get cached settings instance.""" + return Settings() diff --git a/llm-gateway/app/core/__init__.py b/llm-gateway/app/core/__init__.py new file mode 100644 index 0000000..7c111a2 --- /dev/null +++ b/llm-gateway/app/core/__init__.py @@ -0,0 +1 @@ +# core module diff --git a/llm-gateway/app/core/budget.py b/llm-gateway/app/core/budget.py new file mode 100644 index 0000000..744ba1e --- /dev/null +++ b/llm-gateway/app/core/budget.py @@ -0,0 +1,160 @@ +"""Budget controller for spending limits.""" +from datetime import datetime, timedelta +from decimal import Decimal +from typing import Any + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.api_key import APIKey +from app.models.project import Project +from app.models.usage import RequestLog + + +class BudgetExceeded(Exception): + """Exception raised when budget is exceeded.""" + + def __init__( + self, + budget_type: str, + limit: Decimal, + current_usage: Decimal, + scope: str, + ): + self.budget_type = budget_type + self.limit = limit + self.current_usage = current_usage + self.scope = scope + super().__init__(f"Budget exceeded: {scope} {budget_type}") + + +class BudgetController: + """Controller for budget limits at key and project levels.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def check_budget( + self, + key_id: str, + estimated_cost: Decimal = Decimal("0"), + ) -> dict[str, Any]: + """ + Check budget limits before making a request. + + Args: + key_id: The API key ID. + estimated_cost: Estimated cost for this request. + + Returns: + Dict with budget status. + + Raises: + BudgetExceeded: If budget is exceeded. 
+ """ + # Get API key + result = await self.db.execute(select(APIKey).where(APIKey.id == key_id)) + api_key = result.scalar_one_or_none() + + if not api_key: + raise ValueError(f"API key '{key_id}' not found") + + now = datetime.utcnow() + + # Check key-level budget + if api_key.budget_limit: + period_start = self._get_period_start(api_key.budget_period, now) + current_usage = await self._get_usage_for_period(key_id, None, period_start, now) + + if current_usage + estimated_cost > api_key.budget_limit: + raise BudgetExceeded( + budget_type=api_key.budget_period or "total", + limit=api_key.budget_limit, + current_usage=current_usage, + scope=f"key:{api_key.name}", + ) + + # Check soft limit (90% of hard limit) + soft_limit = api_key.budget_limit * Decimal("0.9") + if current_usage + estimated_cost > soft_limit: + # Log warning but don't block + pass + + # Check project-level budget + if api_key.project_id: + result = await self.db.execute( + select(Project).where(Project.id == api_key.project_id) + ) + project = result.scalar_one_or_none() + + if project and project.budget_limit: + period_start = self._get_period_start(project.budget_period, now) + current_usage = await self._get_usage_for_period( + None, project.id, period_start, now + ) + + if current_usage + estimated_cost > project.budget_limit: + raise BudgetExceeded( + budget_type=project.budget_period or "total", + limit=project.budget_limit, + current_usage=current_usage, + scope=f"project:{project.name}", + ) + + return { + "key_budget_limit": float(api_key.budget_limit) if api_key.budget_limit else None, + "key_current_usage": float(api_key.current_usage), + "project_budget_limit": None, + "project_current_usage": None, + } + + async def record_usage( + self, + key_id: str, + cost: Decimal, + input_tokens: int, + output_tokens: int, + ) -> None: + """Record actual usage after request completes.""" + result = await self.db.execute(select(APIKey).where(APIKey.id == key_id)) + api_key = 
result.scalar_one_or_none() + + if api_key: + api_key.current_usage += cost + api_key.total_requests += 1 + + def _get_period_start(self, period: str | None, now: datetime) -> datetime: + """Get the start of the budget period.""" + if period == "daily": + return now.replace(hour=0, minute=0, second=0, microsecond=0) + elif period == "weekly": + start = now - timedelta(days=now.weekday()) + return start.replace(hour=0, minute=0, second=0, microsecond=0) + elif period == "monthly": + return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + else: + # No period means all-time budget + return datetime.min + + async def _get_usage_for_period( + self, + key_id: str | None, + project_id: str | None, + period_start: datetime, + now: datetime, + ) -> Decimal: + """Get total usage for a period.""" + query = select(func.sum(RequestLog.cost_usd)).where( + RequestLog.timestamp >= period_start, + RequestLog.timestamp <= now, + ) + + if key_id: + query = query.where(RequestLog.virtual_key_id == key_id) + if project_id: + query = query.where(RequestLog.project_id == project_id) + + result = await self.db.execute(query) + total = result.scalar() + + return Decimal(str(total or 0)) diff --git a/llm-gateway/app/core/circuit_breaker.py b/llm-gateway/app/core/circuit_breaker.py new file mode 100644 index 0000000..f274d84 --- /dev/null +++ b/llm-gateway/app/core/circuit_breaker.py @@ -0,0 +1,168 @@ +"""Circuit breaker for provider fault tolerance.""" +import asyncio +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class CircuitState(Enum): + """Circuit breaker states.""" + + CLOSED = "closed" # Normal operation + OPEN = "open" # Failing, reject all requests + HALF_OPEN = "half_open" # Testing if recovered + + +@dataclass +class CircuitStats: + """Statistics for circuit breaker.""" + + failure_count: int = 0 + success_count: int = 0 + last_failure_time: float = 0 + last_success_time: float = 0 + state: CircuitState = 
CircuitState.CLOSED + last_state_change: float = field(default_factory=time.time) + + +class CircuitBreaker: + """ + Circuit breaker for provider fault tolerance. + + States: + - CLOSED: Normal operation, requests pass through + - OPEN: Too many failures, reject all requests + - HALF_OPEN: Testing if provider has recovered + """ + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout: float = 30.0, + success_threshold: int = 2, + ): + """ + Initialize circuit breaker. + + Args: + failure_threshold: Number of failures before opening circuit. + recovery_timeout: Seconds to wait before trying half-open state. + success_threshold: Number of successes in half-open to close circuit. + """ + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.success_threshold = success_threshold + self._circuits: dict[str, CircuitStats] = {} + + def get_stats(self, provider: str) -> CircuitStats: + """Get or create circuit stats for a provider.""" + if provider not in self._circuits: + self._circuits[provider] = CircuitStats() + return self._circuits[provider] + + def is_available(self, provider: str) -> bool: + """ + Check if a provider is available (circuit is not open). + + Args: + provider: Provider name. + + Returns: + True if requests should be allowed. + """ + stats = self.get_stats(provider) + now = time.time() + + if stats.state == CircuitState.CLOSED: + return True + + elif stats.state == CircuitState.OPEN: + # Check if recovery timeout has passed + if now - stats.last_state_change >= self.recovery_timeout: + # Transition to half-open + stats.state = CircuitState.HALF_OPEN + stats.success_count = 0 + stats.last_state_change = now + return True + return False + + elif stats.state == CircuitState.HALF_OPEN: + return True + + return True + + def record_success(self, provider: str) -> None: + """ + Record a successful request. + + Args: + provider: Provider name. 
+ """ + stats = self.get_stats(provider) + now = time.time() + + stats.success_count += 1 + stats.last_success_time = now + stats.failure_count = 0 # Reset failure count on success + + if stats.state == CircuitState.HALF_OPEN: + if stats.success_count >= self.success_threshold: + # Transition to closed + stats.state = CircuitState.CLOSED + stats.last_state_change = now + + def record_failure(self, provider: str) -> None: + """ + Record a failed request. + + Args: + provider: Provider name. + """ + stats = self.get_stats(provider) + now = time.time() + + stats.failure_count += 1 + stats.last_failure_time = now + + if stats.state == CircuitState.HALF_OPEN: + # Immediately open on failure in half-open + stats.state = CircuitState.OPEN + stats.last_state_change = now + + elif stats.state == CircuitState.CLOSED: + if stats.failure_count >= self.failure_threshold: + # Transition to open + stats.state = CircuitState.OPEN + stats.last_state_change = now + + def get_state(self, provider: str) -> CircuitState: + """Get current circuit state for a provider.""" + return self.get_stats(provider).state + + def reset(self, provider: str) -> None: + """Reset circuit breaker for a provider.""" + if provider in self._circuits: + del self._circuits[provider] + + def reset_all(self) -> None: + """Reset all circuit breakers.""" + self._circuits.clear() + + +# Global circuit breaker instance +_circuit_breaker: CircuitBreaker | None = None + + +def get_circuit_breaker() -> CircuitBreaker: + """Get the global circuit breaker instance.""" + global _circuit_breaker + if _circuit_breaker is None: + from app.config import get_settings + + settings = get_settings() + _circuit_breaker = CircuitBreaker( + failure_threshold=5, + recovery_timeout=30.0, + ) + return _circuit_breaker diff --git a/llm-gateway/app/core/fallback.py b/llm-gateway/app/core/fallback.py new file mode 100644 index 0000000..3944faf --- /dev/null +++ b/llm-gateway/app/core/fallback.py @@ -0,0 +1,223 @@ +"""Fallback and retry 
class RetryableError(Exception):
    """Failure that may succeed if the operation is attempted again."""

    def __init__(self, message: str, error_type: str = "unknown"):
        self.message = message
        self.error_type = error_type
        super().__init__(message)


class PermanentError(Exception):
    """Failure that will not go away by retrying."""

    def __init__(self, message: str, error_type: str = "unknown"):
        self.message = message
        self.error_type = error_type
        super().__init__(message)


@dataclass
class RetryConfig:
    """Retry policy: attempt count, backoff shape and retryable error types."""

    max_retries: int = 3
    initial_delay: float = 1.0
    max_delay: float = 30.0
    exponential_base: float = 2.0
    jitter: bool = True
    retryable_errors: set[str] | None = None

    def __post_init__(self):
        # A fresh set per instance; None is the "use the defaults" marker.
        if self.retryable_errors is None:
            self.retryable_errors = {
                "rate_limit_exceeded",
                "timeout",
                "service_unavailable",
                "internal_error",
                "connection_error",
                "overloaded",
            }


class RetryExecutor:
    """Run an async operation with retries, backoff and circuit breaking."""

    def __init__(self, config: RetryConfig | None = None):
        self.config = config or RetryConfig()
        self.settings = get_settings()

    async def execute(
        self,
        operation: Callable[[], T],
        provider: str,
        operation_name: str = "operation",
    ) -> T:
        """
        Execute an operation with retry logic.

        Every attempt first consults the shared circuit breaker and then
        reports its outcome to it. ``RetryableError`` and any unexpected
        exception are retried up to ``config.max_retries`` times with
        exponential backoff; ``PermanentError`` is re-raised immediately.

        Args:
            operation: Zero-argument coroutine function to call and await.
            provider: Provider name for circuit breaker.
            operation_name: Name for logging.

        Returns:
            Result of the operation.

        Raises:
            PermanentError: If operation fails permanently.
            RetryableError: If the circuit is open or all retries are
                exhausted. An unexpected exception is wrapped, with the
                original available as ``__cause__``.
        """
        circuit_breaker = get_circuit_breaker()
        last_error: Exception | None = None

        for attempt in range(self.config.max_retries + 1):
            if not circuit_breaker.is_available(provider):
                raise RetryableError(
                    f"Circuit breaker open for provider '{provider}'",
                    error_type="circuit_open",
                )

            try:
                result = await operation()
                circuit_breaker.record_success(provider)
                return result

            except PermanentError as e:
                logger.warning(
                    "Permanent error in %s",
                    operation_name,
                    provider=provider,
                    attempt=attempt,
                    error=e.message,
                )
                circuit_breaker.record_failure(provider)
                raise

            except RetryableError as e:
                last_error = e
                circuit_breaker.record_failure(provider)

                if attempt < self.config.max_retries:
                    delay = self._calculate_delay(attempt)
                    logger.warning(
                        "Retryable error in %s, retrying in %.2fs",
                        operation_name,
                        delay,
                        provider=provider,
                        attempt=attempt,
                        error=e.message,
                        error_type=e.error_type,
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error(
                        "All retries exhausted for %s",
                        operation_name,
                        provider=provider,
                        attempts=attempt + 1,
                        error=e.message,
                    )

            except Exception as e:
                # Unknown error - treat as retryable, but keep the original
                # exception attached so its traceback is not lost.
                wrapped = RetryableError(str(e), error_type="unknown")
                wrapped.__cause__ = e
                last_error = wrapped
                circuit_breaker.record_failure(provider)

                if attempt < self.config.max_retries:
                    delay = self._calculate_delay(attempt)
                    logger.warning(
                        "Unknown error in %s, retrying in %.2fs",
                        operation_name,
                        delay,
                        provider=provider,
                        attempt=attempt,
                        error=str(e),
                    )
                    await asyncio.sleep(delay)
                else:
                    # Mirrors the RetryableError branch, which already logged
                    # exhaustion; this branch used to give up silently.
                    logger.error(
                        "All retries exhausted for %s",
                        operation_name,
                        provider=provider,
                        attempts=attempt + 1,
                        error=str(e),
                    )

        if last_error:
            raise last_error
        raise RetryableError("All retries exhausted", error_type="max_retries")

    def _calculate_delay(self, attempt: int) -> float:
        """Return the backoff delay in seconds for a zero-based attempt.

        The delay is ``initial_delay * exponential_base ** attempt`` capped
        at ``max_delay``; with jitter enabled it is then scaled by a random
        factor between 0.5 and 1.5.
        """
        delay = self.config.initial_delay * (self.config.exponential_base**attempt)
        delay = min(delay, self.config.max_delay)

        if self.config.jitter:
            delay *= random.uniform(0.5, 1.5)

        return delay


def classify_error(error: Exception) -> Exception:
    """
    Classify an error as retryable or permanent from its message text.

    Phrases are matched as case-insensitive substrings. HTTP status codes
    are matched only when not embedded in a longer number, so a message
    mentioning "5000 tokens" is no longer mistaken for a 500 response.
    Retryable patterns are tried first; unmatched errors are retryable.

    Args:
        error: The error to classify.

    Returns:
        A new ``RetryableError`` or ``PermanentError`` carrying the original
        message; ``error_type`` is the matched pattern with spaces replaced
        by underscores, or ``"unknown"``.
    """
    # Local import: the module's import block does not bring in `re`.
    import re

    message = str(error)
    lowered = message.lower()

    def matches(pattern: str) -> bool:
        if pattern.isdigit():
            # Status code: reject matches inside a longer digit run.
            return re.search(rf"(?<!\d){pattern}(?!\d)", lowered) is not None
        return pattern in lowered

    retryable_patterns = (
        "rate limit",
        "429",
        "timeout",
        "503",
        "502",
        "service unavailable",
        "internal error",
        "500",
        "connection",
        "overloaded",
        "capacity",
    )
    # "model not found" comes before the broader "not found", which used to
    # shadow it and made the more specific error type unreachable.
    permanent_patterns = (
        "model not found",
        "401",
        "unauthorized",
        "invalid api key",
        "403",
        "forbidden",
        "400",
        "invalid request",
        "404",
        "not found",
    )

    for pattern in retryable_patterns:
        if matches(pattern):
            return RetryableError(message, error_type=pattern.replace(" ", "_"))

    for pattern in permanent_patterns:
        if matches(pattern):
            return PermanentError(message, error_type=pattern.replace(" ", "_"))

    # Default to retryable
    return RetryableError(message, error_type="unknown")
list[dict[str, Any]], + strategy: str = "weighted_round_robin", + ) -> dict[str, Any] | None: + """ + Select a provider based on the specified strategy. + + Args: + providers: List of provider configs with weights. + strategy: Load balancing strategy. + + Returns: + Selected provider config, or None if none available. + """ + if not providers: + return None + + if len(providers) == 1: + return providers[0] + + # Filter healthy providers + healthy_providers = await self._filter_healthy_providers(providers) + + if not healthy_providers: + return providers[0] # Fallback to first if all unhealthy + + if strategy == "weighted_round_robin": + return self._weighted_round_robin(healthy_providers) + elif strategy == "random": + return self._random_select(healthy_providers) + elif strategy == "least_connections": + return await self._least_connections(healthy_providers) + else: + return self._weighted_round_robin(healthy_providers) + + def _weighted_round_robin( + self, providers: list[dict[str, Any]] + ) -> dict[str, Any]: + """Select provider using weighted round robin.""" + total_weight = sum(p.get("weight", 1) for p in providers) + r = random.uniform(0, total_weight) + + cumulative = 0 + for provider in providers: + cumulative += provider.get("weight", 1) + if r <= cumulative: + return provider + + return providers[0] + + def _random_select(self, providers: list[dict[str, Any]]) -> dict[str, Any]: + """Select provider randomly.""" + return random.choice(providers) + + async def _least_connections( + self, providers: list[dict[str, Any]] + ) -> dict[str, Any]: + """Select provider with least connections (placeholder).""" + # In a real implementation, this would track active connections + # For now, use random as a fallback + return self._random_select(providers) + + async def _filter_healthy_providers( + self, providers: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Filter out unhealthy providers.""" + healthy = [] + + for p in providers: + provider_name = 
class RateLimitExceeded(Exception):
    """Signals that a request or token limit has been reached.

    Attributes mirror the constructor arguments: ``limit_type`` names the
    limit, ``limit`` is its size, ``remaining`` is what is left of it and
    ``reset_at`` is when the current window ends.
    """

    def __init__(
        self,
        limit_type: str,
        limit: int,
        remaining: int,
        reset_at: datetime,
    ):
        super().__init__(f"Rate limit exceeded: {limit_type}")
        self.limit_type = limit_type
        self.limit = limit
        self.remaining = remaining
        self.reset_at = reset_at
self.settings.rate_limit_window_seconds + + async def check_and_increment( + self, + key: str, + rpm_limit: int | None = None, + tpm_limit: int | None = None, + estimated_tokens: int = 0, + ) -> dict[str, Any]: + """ + Check rate limits and increment counters. + + Args: + key: Unique key for rate limiting (e.g., "key:{key_id}:rpm"). + rpm_limit: Requests per minute limit. + tpm_limit: Tokens per minute limit. + estimated_tokens: Estimated tokens for this request. + + Returns: + Dict with remaining limits. + + Raises: + RateLimitExceeded: If rate limit is exceeded. + """ + now = datetime.utcnow() + window_start = now - timedelta(seconds=self.window_seconds) + + # Check global limits + if self.settings.global_rpm_limit: + await self._check_limit( + "global:rpm", + self.settings.global_rpm_limit, + window_start, + now, + 1, + "request", + ) + + if self.settings.global_tpm_limit and estimated_tokens > 0: + await self._check_limit( + "global:tpm", + self.settings.global_tpm_limit, + window_start, + now, + estimated_tokens, + "token", + ) + + # Check specific limits + if rpm_limit: + await self._check_limit( + f"{key}:rpm", + rpm_limit, + window_start, + now, + 1, + "request", + ) + + if tpm_limit and estimated_tokens > 0: + await self._check_limit( + f"{key}:tpm", + tpm_limit, + window_start, + now, + estimated_tokens, + "token", + ) + + # Get remaining limits + remaining = await self._get_remaining(key, rpm_limit, tpm_limit, window_start, now) + + return remaining + + async def _check_limit( + self, + key: str, + limit: int, + window_start: datetime, + now: datetime, + increment: int, + counter_type: str, + ) -> None: + """Check if a limit would be exceeded and increment counter.""" + # Get current counter + result = await self.db.execute( + select(RateLimitCounter).where(RateLimitCounter.key == key) + ) + counter = result.scalar_one_or_none() + + if counter: + # Check if window has expired + if counter.window_start < window_start: + # Reset counter for new window + 
counter.request_count = 0 + counter.token_count = 0 + counter.window_start = now + + # Check limit + current = ( + counter.request_count if counter_type == "request" else counter.token_count + ) + if current >= limit: + reset_at = counter.window_start + timedelta(seconds=counter.window_duration) + raise RateLimitExceeded( + limit_type=f"{counter_type}_per_{self.window_seconds}s", + limit=limit, + remaining=0, + reset_at=reset_at, + ) + + # Increment counter + if counter_type == "request": + counter.request_count += increment + else: + counter.token_count += increment + else: + # Create new counter + counter = RateLimitCounter( + key=key, + request_count=1 if counter_type == "request" else 0, + token_count=increment if counter_type == "token" else 0, + window_start=now, + window_duration=self.window_seconds, + ) + self.db.add(counter) + + async def _get_remaining( + self, + key: str, + rpm_limit: int | None, + tpm_limit: int | None, + window_start: datetime, + now: datetime, + ) -> dict[str, Any]: + """Get remaining limits for a key.""" + remaining = {} + + if rpm_limit: + result = await self.db.execute( + select(RateLimitCounter).where(RateLimitCounter.key == f"{key}:rpm") + ) + counter = result.scalar_one_or_none() + remaining["rpm"] = rpm_limit - (counter.request_count if counter else 0) + + if tpm_limit: + result = await self.db.execute( + select(RateLimitCounter).where(RateLimitCounter.key == f"{key}:tpm") + ) + counter = result.scalar_one_or_none() + remaining["tpm"] = tpm_limit - (counter.token_count if counter else 0) + + remaining["reset_at"] = (now + timedelta(seconds=self.window_seconds)).isoformat() + + return remaining + + async def record_tokens(self, key: str, tokens: int) -> None: + """Record actual token usage after request completes.""" + tpm_key = f"{key}:tpm" + result = await self.db.execute( + select(RateLimitCounter).where(RateLimitCounter.key == tpm_key) + ) + counter = result.scalar_one_or_none() + + if counter: + counter.token_count += 
tokens + else: + counter = RateLimitCounter( + key=tpm_key, + request_count=0, + token_count=tokens, + window_start=datetime.utcnow(), + window_duration=self.window_seconds, + ) + self.db.add(counter) diff --git a/llm-gateway/app/core/router.py b/llm-gateway/app/core/router.py new file mode 100644 index 0000000..b94bd6c --- /dev/null +++ b/llm-gateway/app/core/router.py @@ -0,0 +1,173 @@ +"""Router for model alias resolution and routing.""" +import json +import random +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.model_alias import ModelAlias +from app.models.provider import Provider + + +@dataclass +class RoutingResult: + """Result of routing decision.""" + + provider: str + model: str + provider_config: dict[str, Any] | None = None + fallback_chain: list[dict[str, str]] | None = None + + +class Router: + """Router for resolving model aliases and making routing decisions.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def resolve_model(self, model_alias: str) -> RoutingResult: + """ + Resolve a model alias to a provider and model. + + Args: + model_alias: The model alias to resolve. + + Returns: + RoutingResult with provider, model, and optional fallback chain. + + Raises: + ValueError: If the model alias is not found. 
+ """ + # Look up the model alias + result = await self.db.execute( + select(ModelAlias).where(ModelAlias.alias == model_alias, ModelAlias.enabled == True) + ) + alias = result.scalar_one_or_none() + + if not alias: + # If no alias found, treat as direct provider/model reference + # Format: "provider/model" (e.g., "openai/gpt-4") + if "/" in model_alias: + provider, model = model_alias.split("/", 1) + return RoutingResult(provider=provider, model=model) + raise ValueError(f"Model alias '{model_alias}' not found") + + # Parse routing config + routing_config = json.loads(alias.routing_config) if alias.routing_config else None + + if alias.routing_type == "simple": + return RoutingResult( + provider=alias.provider, + model=alias.model, + ) + + elif alias.routing_type == "load_balance": + # Weighted random selection + providers = routing_config.get("providers", []) if routing_config else [] + if not providers: + return RoutingResult(provider=alias.provider, model=alias.model) + + # Filter healthy providers + healthy_providers = await self._filter_healthy_providers(providers) + + if not healthy_providers: + # Fallback to default if all unhealthy + return RoutingResult(provider=alias.provider, model=alias.model) + + # Weighted random selection + total_weight = sum(p.get("weight", 1) for p in healthy_providers) + r = random.uniform(0, total_weight) + + cumulative = 0 + for p in healthy_providers: + cumulative += p.get("weight", 1) + if r <= cumulative: + return RoutingResult( + provider=p["provider"], + model=p.get("model", alias.model), + ) + + # Fallback + return RoutingResult( + provider=healthy_providers[0]["provider"], + model=healthy_providers[0].get("model", alias.model), + ) + + elif alias.routing_type == "fallback": + # Return primary with fallback chain + primary = ( + routing_config.get("primary", {}) if routing_config else {} + ) + fallback = routing_config.get("fallback", []) if routing_config else [] + + return RoutingResult( + 
provider=primary.get("provider", alias.provider), + model=primary.get("model", alias.model), + fallback_chain=fallback, + ) + + # Default to simple routing + return RoutingResult(provider=alias.provider, model=alias.model) + + async def _filter_healthy_providers( + self, providers: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Filter out unhealthy providers.""" + healthy = [] + for p in providers: + provider_name = p.get("provider") + if not provider_name: + continue + + result = await self.db.execute( + select(Provider).where( + Provider.name == provider_name, + Provider.enabled == True, + Provider.health_status == "healthy", + ) + ) + provider = result.scalar_one_or_none() + + if provider: + healthy.append(p) + + return healthy + + async def get_fallback_provider( + self, failed_provider: str, fallback_chain: list[dict[str, str]] | None + ) -> RoutingResult | None: + """ + Get the next fallback provider after a failure. + + Args: + failed_provider: The provider that failed. + fallback_chain: List of fallback providers. + + Returns: + RoutingResult for the next provider, or None if no more fallbacks. 
+ """ + if not fallback_chain: + return None + + # Find the failed provider in the chain and return the next one + for i, fallback in enumerate(fallback_chain): + if fallback.get("provider") == failed_provider and i + 1 < len(fallback_chain): + next_fallback = fallback_chain[i + 1] + return RoutingResult( + provider=next_fallback.get("provider", ""), + model=next_fallback.get("model", ""), + fallback_chain=fallback_chain[i + 1 :], + ) + + # If failed provider not in chain or at the end, try from the beginning + for fallback in fallback_chain: + if fallback.get("provider") != failed_provider: + return RoutingResult( + provider=fallback.get("provider", ""), + model=fallback.get("model", ""), + fallback_chain=fallback_chain, + ) + + return None diff --git a/llm-gateway/app/core/transformer.py b/llm-gateway/app/core/transformer.py new file mode 100644 index 0000000..37e87aa --- /dev/null +++ b/llm-gateway/app/core/transformer.py @@ -0,0 +1,228 @@ +"""Request transformer for converting between API formats.""" +import time +import uuid +from typing import Any + +from app.schemas.anthropic import ( + AnthropicContentBlock, + AnthropicMessage, + AnthropicMessagesRequest, + AnthropicMessagesResponse, + AnthropicTextBlock, + AnthropicUsage, +) +from app.schemas.openai import ( + OpenAIChatCompletionChoice, + OpenAIChatCompletionRequest, + OpenAIChatCompletionResponse, + OpenAIChatMessage, + OpenAIUsage, +) + + +class RequestTransformer: + """Transform requests between different API formats.""" + + @staticmethod + def openai_to_anthropic( + request: OpenAIChatCompletionRequest, + ) -> AnthropicMessagesRequest: + """Convert OpenAI Chat Completions request to Anthropic Messages request.""" + messages = [] + system = None + + for msg in request.messages: + if msg.role == "system": + # Extract system message + system = msg.content if isinstance(msg.content, str) else None + elif msg.role in ("user", "assistant"): + # Convert content to Anthropic format + if 
isinstance(msg.content, str): + content = msg.content + elif isinstance(msg.content, list): + # Handle multi-modal content + content = [] + for part in msg.content: + if part.get("type") == "text": + content.append( + AnthropicContentBlock(type="text", text=part.get("text", "")) + ) + elif part.get("type") == "image_url": + # Convert OpenAI image format to Anthropic + content.append( + AnthropicContentBlock( + type="image", + source={ + "type": "url", + "url": part.get("image_url", {}).get("url", ""), + }, + ) + ) + else: + content.append( + AnthropicContentBlock(type="text", text=str(part)) + ) + else: + content = str(msg.content) if msg.content else "" + + messages.append( + AnthropicMessage( + role=msg.role, # type: ignore + content=content, + ) + ) + + # Handle max_tokens + max_tokens = request.max_tokens or 4096 + + return AnthropicMessagesRequest( + model=request.model, + messages=messages, + system=system, + max_tokens=max_tokens, + temperature=request.temperature, + top_p=request.top_p, + stop_sequences=[request.stop] if isinstance(request.stop, str) else request.stop, + stream=request.stream, + tools=request.tools, + tool_choice=request.tool_choice, + ) + + @staticmethod + def anthropic_to_openai( + request: AnthropicMessagesRequest, + model: str | None = None, + ) -> OpenAIChatCompletionRequest: + """Convert Anthropic Messages request to OpenAI Chat Completions request.""" + messages = [] + + # Add system message if present + if request.system: + messages.append( + OpenAIChatMessage( + role="system", + content=request.system, + ) + ) + + # Convert messages + for msg in request.messages: + if isinstance(msg.content, str): + content = msg.content + elif isinstance(msg.content, list): + # Convert Anthropic content blocks to OpenAI format + content = [] + for block in msg.content: + if hasattr(block, "type") and block.type == "text": + content.append({"type": "text", "text": block.text or ""}) + elif isinstance(block, dict) and block.get("type") == "text": + 
content.append({"type": "text", "text": block.get("text", "")}) + else: + content.append({"type": "text", "text": str(block)}) + else: + content = str(msg.content) + + messages.append( + OpenAIChatMessage( + role=msg.role, # type: ignore + content=content, + ) + ) + + return OpenAIChatCompletionRequest( + model=model or request.model, + messages=messages, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + stream=request.stream, + tools=request.tools, + tool_choice=request.tool_choice, + ) + + @staticmethod + def anthropic_response_to_openai( + response: AnthropicMessagesResponse, + model: str, + ) -> OpenAIChatCompletionResponse: + """Convert Anthropic Messages response to OpenAI Chat Completions response.""" + # Extract text content + content = "" + for block in response.content: + if hasattr(block, "type") and block.type == "text": + content += block.text or "" + elif isinstance(block, dict) and block.get("type") == "text": + content += block.get("text", "") + + # Map stop reasons + finish_reason_map = { + "end_turn": "stop", + "max_tokens": "length", + "stop_sequence": "stop", + "tool_use": "tool_calls", + } + finish_reason = finish_reason_map.get(response.stop_reason or "", "stop") + + return OpenAIChatCompletionResponse( + id=response.id, + object="chat.completion", + created=int(time.time()), + model=model, + choices=[ + OpenAIChatCompletionChoice( + index=0, + message=OpenAIChatMessage( + role="assistant", + content=content, + ), + finish_reason=finish_reason, + ) + ], + usage=OpenAIUsage( + prompt_tokens=response.usage.input_tokens, + completion_tokens=response.usage.output_tokens, + total_tokens=response.usage.input_tokens + response.usage.output_tokens, + ) + if response.usage + else None, + ) + + @staticmethod + def openai_response_to_anthropic( + response: OpenAIChatCompletionResponse, + ) -> AnthropicMessagesResponse: + """Convert OpenAI Chat Completions response to Anthropic Messages response.""" + # Extract 
content from choices + content = [] + for choice in response.choices: + if choice.message.content: + content.append( + AnthropicTextBlock( + type="text", + text=choice.message.content, + ) + ) + + # Map finish reasons + stop_reason_map = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + "content_filter": "end_turn", + } + stop_reason = stop_reason_map.get( + response.choices[0].finish_reason or "", "end_turn" + ) + + return AnthropicMessagesResponse( + id=response.id or f"msg_{uuid.uuid4().hex[:24]}", + type="message", + role="assistant", + content=content, + model=response.model, + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=response.usage.prompt_tokens if response.usage else 0, + output_tokens=response.usage.completion_tokens if response.usage else 0, + ), + ) diff --git a/llm-gateway/app/db/__init__.py b/llm-gateway/app/db/__init__.py new file mode 100644 index 0000000..fc4742a --- /dev/null +++ b/llm-gateway/app/db/__init__.py @@ -0,0 +1 @@ +# db module diff --git a/llm-gateway/app/db/database.py b/llm-gateway/app/db/database.py new file mode 100644 index 0000000..5ed1ffc --- /dev/null +++ b/llm-gateway/app/db/database.py @@ -0,0 +1,70 @@ +"""Database connection and session management.""" +import os +from pathlib import Path + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from app.config import get_settings + + +class Base(DeclarativeBase): + """Base class for all models.""" + + pass + + +# Create async engine +def get_engine(): + """Create database engine based on settings.""" + settings = get_settings() + db_url = settings.database_url + + # Ensure data directory exists for SQLite + if db_url.startswith("sqlite:///"): + db_path = db_url.replace("sqlite:///", "") + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + # Use aiosqlite for async SQLite + db_url = db_url.replace("sqlite:///", "sqlite+aiosqlite:///") + + 
return create_async_engine(db_url, echo=settings.debug) + + +engine = get_engine() + +# Session factory +AsyncSessionLocal = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, +) + + +async def init_db(): + """Initialize database tables.""" + # Import all models to ensure they are registered + from app.models import ( # noqa: F401 + provider, + api_key, + project, + model_alias, + usage, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + +async def get_db(): + """Dependency to get database session.""" + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() diff --git a/llm-gateway/app/main.py b/llm-gateway/app/main.py new file mode 100644 index 0000000..1b59b4b --- /dev/null +++ b/llm-gateway/app/main.py @@ -0,0 +1,61 @@ +"""FastAPI application entry point.""" +import os +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.api.admin import health, keys, models, projects, providers, usage +from app.api.v1 import chat, messages, responses +from app.config import get_settings +from app.db.database import init_db + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + # Startup - skip if testing + if not os.environ.get("TESTING"): + await init_db() + yield + # Shutdown + + +def create_app() -> FastAPI: + """Create and configure the FastAPI application.""" + settings = get_settings() + + app = FastAPI( + title=settings.app_name, + description="A unified LLM Gateway supporting multiple providers and API formats", + version="0.1.0", + debug=settings.debug, + lifespan=lifespan, + ) + + # CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + 
allow_methods=["*"], + allow_headers=["*"], + ) + + # Admin API routers + app.include_router(providers.router, prefix="/admin") + app.include_router(projects.router, prefix="/admin") + app.include_router(keys.router, prefix="/admin") + app.include_router(models.router, prefix="/admin") + app.include_router(usage.router, prefix="/admin") + app.include_router(health.router) + + # LLM API routers + app.include_router(chat.router) + app.include_router(messages.router) + app.include_router(responses.router) + + return app + + +app = create_app() diff --git a/llm-gateway/app/middleware/__init__.py b/llm-gateway/app/middleware/__init__.py new file mode 100644 index 0000000..9b50d18 --- /dev/null +++ b/llm-gateway/app/middleware/__init__.py @@ -0,0 +1 @@ +# middleware module diff --git a/llm-gateway/app/models/__init__.py b/llm-gateway/app/models/__init__.py new file mode 100644 index 0000000..c381246 --- /dev/null +++ b/llm-gateway/app/models/__init__.py @@ -0,0 +1,17 @@ +"""Models package.""" +from app.models.provider import Provider +from app.models.project import Project +from app.models.api_key import APIKey +from app.models.model_alias import ModelAlias +from app.models.usage import RequestLog, UsageStatsHourly, AuditLog, RateLimitCounter + +__all__ = [ + "Provider", + "Project", + "APIKey", + "ModelAlias", + "RequestLog", + "UsageStatsHourly", + "AuditLog", + "RateLimitCounter", +] diff --git a/llm-gateway/app/models/api_key.py b/llm-gateway/app/models/api_key.py new file mode 100644 index 0000000..de81659 --- /dev/null +++ b/llm-gateway/app/models/api_key.py @@ -0,0 +1,48 @@ +"""API Key model.""" +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.database import Base + + +class APIKey(Base): + """Virtual API Key for authentication and usage tracking.""" + + 
__tablename__ = "api_keys" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4())) + key_hash: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True) + key_prefix: Mapped[str] = mapped_column(String(20), nullable=False) # e.g., "sk-proj_abc..." + name: Mapped[str] = mapped_column(String(100), nullable=False) + + # Project relationship + project_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("projects.id")) + project: Mapped["Project"] = relationship("Project", backref="api_keys") + + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + expires_at: Mapped[datetime | None] = mapped_column(DateTime) + + # Rate Limits + rpm_limit: Mapped[int | None] = mapped_column(Integer) + tpm_limit: Mapped[int | None] = mapped_column(Integer) + + # Budget + budget_limit: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + budget_period: Mapped[str | None] = mapped_column(String(20)) # daily, weekly, monthly + + # Permissions + allowed_models: Mapped[str | None] = mapped_column(Text) # JSON array + + # Usage Stats + current_usage: Mapped[Decimal] = mapped_column(Numeric(10, 2), default=Decimal("0")) + total_requests: Mapped[int] = mapped_column(Integer, default=0) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) diff --git a/llm-gateway/app/models/model_alias.py b/llm-gateway/app/models/model_alias.py new file mode 100644 index 0000000..8e139d7 --- /dev/null +++ b/llm-gateway/app/models/model_alias.py @@ -0,0 +1,36 @@ +"""Model alias model.""" +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import Boolean, DateTime, Numeric, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.database import Base + + +class ModelAlias(Base): + """Model alias for 
routing and cost tracking.""" + + __tablename__ = "model_aliases" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4())) + alias: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True) + provider: Mapped[str] = mapped_column(String(50), nullable=False) + model: Mapped[str] = mapped_column(String(100), nullable=False) + + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + + # Routing: simple, load_balance, fallback + routing_type: Mapped[str] = mapped_column(String(20), default="simple") + routing_config: Mapped[str | None] = mapped_column(Text) # JSON + + # Pricing (per 1K tokens) + input_price_per_1k: Mapped[Decimal | None] = mapped_column(Numeric(10, 6)) + output_price_per_1k: Mapped[Decimal | None] = mapped_column(Numeric(10, 6)) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) diff --git a/llm-gateway/app/models/project.py b/llm-gateway/app/models/project.py new file mode 100644 index 0000000..7075f5a --- /dev/null +++ b/llm-gateway/app/models/project.py @@ -0,0 +1,31 @@ +"""Project model.""" +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import Boolean, DateTime, Numeric, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.database import Base + + +class Project(Base): + """Project for organizing API keys and tracking usage.""" + + __tablename__ = "projects" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4())) + name: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + description: Mapped[str | None] = mapped_column(Text) + + # Budget + budget_limit: Mapped[Decimal | None] = mapped_column(Numeric(10, 2)) + budget_period: Mapped[str | None] = mapped_column(String(20)) # daily, 
weekly, monthly + + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) diff --git a/llm-gateway/app/models/provider.py b/llm-gateway/app/models/provider.py new file mode 100644 index 0000000..19fe336 --- /dev/null +++ b/llm-gateway/app/models/provider.py @@ -0,0 +1,40 @@ +"""Provider model.""" +from datetime import datetime +from uuid import uuid4 + +from sqlalchemy import Boolean, DateTime, Integer, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.database import Base + + +class Provider(Base): + """LLM Provider configuration.""" + + __tablename__ = "providers" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4())) + name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True) + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + + # API Configuration + api_base: Mapped[str] = mapped_column(String(500), nullable=False) + api_key_encrypted: Mapped[str] = mapped_column(Text, nullable=False) + api_version: Mapped[str | None] = mapped_column(String(50)) + + # Provider-specific config (JSON) + config: Mapped[str | None] = mapped_column(Text) + + # Rate Limits + rpm_limit: Mapped[int | None] = mapped_column(Integer) + tpm_limit: Mapped[int | None] = mapped_column(Integer) + + # Health Status + health_status: Mapped[str] = mapped_column(String(20), default="healthy") + last_health_check: Mapped[datetime | None] = mapped_column(DateTime) + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) diff --git a/llm-gateway/app/models/usage.py b/llm-gateway/app/models/usage.py new file mode 100644 
index 0000000..1f9ab80 --- /dev/null +++ b/llm-gateway/app/models/usage.py @@ -0,0 +1,110 @@ +"""Usage tracking models.""" +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, Numeric, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from app.db.database import Base + + +class RequestLog(Base): + """Request log for tracking individual API calls.""" + + __tablename__ = "request_logs" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4())) + timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True) + + # Keys and Projects + virtual_key_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("api_keys.id"), index=True + ) + project_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("projects.id"), index=True + ) + + # Request details + provider: Mapped[str] = mapped_column(String(50), nullable=False) + model: Mapped[str] = mapped_column(String(100), nullable=False) + model_alias: Mapped[str | None] = mapped_column(String(100)) + request_type: Mapped[str] = mapped_column(String(20)) # chat, completion, embedding + + # Tokens + input_tokens: Mapped[int] = mapped_column(Integer) + output_tokens: Mapped[int] = mapped_column(Integer) + total_tokens: Mapped[int] = mapped_column(Integer) + + # Response + status_code: Mapped[int] = mapped_column(Integer) + latency_ms: Mapped[int] = mapped_column(Integer) + finish_reason: Mapped[str | None] = mapped_column(String(50)) + + # Cost + cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6)) + + # Request metadata (JSON) + request_metadata: Mapped[str | None] = mapped_column(Text) + + +class UsageStatsHourly(Base): + """Hourly aggregated usage statistics.""" + + __tablename__ = "usage_stats_hourly" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + timestamp: Mapped[datetime] = 
mapped_column(DateTime, nullable=False, index=True) + + virtual_key_id: Mapped[str | None] = mapped_column(String(36), index=True) + project_id: Mapped[str | None] = mapped_column(String(36), index=True) + provider: Mapped[str] = mapped_column(String(50)) + model: Mapped[str] = mapped_column(String(100)) + + # Aggregates + request_count: Mapped[int] = mapped_column(Integer, default=0) + input_tokens: Mapped[int] = mapped_column(BigInteger, default=0) + output_tokens: Mapped[int] = mapped_column(BigInteger, default=0) + total_tokens: Mapped[int] = mapped_column(BigInteger, default=0) + cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6), default=Decimal("0")) + + avg_latency_ms: Mapped[int | None] = mapped_column(Integer) + error_count: Mapped[int] = mapped_column(Integer, default=0) + + +class AuditLog(Base): + """Audit log for tracking admin operations.""" + + __tablename__ = "audit_logs" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4())) + timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True) + + # Who + actor: Mapped[str] = mapped_column(String(100), nullable=False) + + # What + action: Mapped[str] = mapped_column(String(50), nullable=False) # create, update, delete + + # Which + resource: Mapped[str] = mapped_column(String(50), nullable=False) # provider, key, project + resource_id: Mapped[str | None] = mapped_column(String(36)) + + # Changes + changes: Mapped[str | None] = mapped_column(Text) # JSON with before/after + + # Request info + ip_address: Mapped[str | None] = mapped_column(String(50)) + user_agent: Mapped[str | None] = mapped_column(String(500)) + + +class RateLimitCounter(Base): + """Rate limit counter for tracking usage windows.""" + + __tablename__ = "rate_limit_counters" + + key: Mapped[str] = mapped_column(String(200), primary_key=True) + request_count: Mapped[int] = mapped_column(Integer, default=0) + token_count: Mapped[int] = mapped_column(BigInteger, 
default=0) + window_start: Mapped[datetime] = mapped_column(DateTime, nullable=False) + window_duration: Mapped[int] = mapped_column(Integer, nullable=False) # seconds diff --git a/llm-gateway/app/schemas/__init__.py b/llm-gateway/app/schemas/__init__.py new file mode 100644 index 0000000..ce1d435 --- /dev/null +++ b/llm-gateway/app/schemas/__init__.py @@ -0,0 +1 @@ +# schemas module diff --git a/llm-gateway/app/schemas/anthropic.py b/llm-gateway/app/schemas/anthropic.py new file mode 100644 index 0000000..1a5ff51 --- /dev/null +++ b/llm-gateway/app/schemas/anthropic.py @@ -0,0 +1,107 @@ +"""Anthropic Messages API request and response schemas.""" +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +# === Request Models === + + +class AnthropicContentBlock(BaseModel): + """Anthropic content block.""" + + type: Literal["text", "image", "tool_use", "tool_result"] + text: str | None = None + source: dict[str, Any] | None = None # For images + tool_use_id: str | None = None # For tool_result + content: str | list[dict[str, Any]] | None = None # For tool_result + name: str | None = None # For tool_use + input: dict[str, Any] | None = None # For tool_use + + +class AnthropicMessage(BaseModel): + """Anthropic message.""" + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicMessagesRequest(BaseModel): + """Anthropic Messages API request.""" + + model: str + messages: list[AnthropicMessage] + system: str | None = None + max_tokens: int = Field(..., ge=1) + temperature: float | None = Field(None, ge=0, le=1) + top_p: float | None = Field(None, ge=0, le=1) + top_k: int | None = Field(None, ge=0) + stop_sequences: list[str] | None = None + stream: bool = False + tools: list[dict[str, Any]] | None = None + tool_choice: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None + + +# === Response Models === + + +class AnthropicTextBlock(BaseModel): + """Anthropic text content block.""" + + 
class APIKeyCreate(BaseModel):
    """Schema for creating an API key.

    Only ``name`` is required. Limits and budget fields are optional and,
    when given, must be positive (limits) or non-negative (budget).
    """

    name: str = Field(..., min_length=1, max_length=100)
    project_id: str | None = Field(None, description="Associated project ID")
    # The prefix is embedded verbatim at the front of the generated key, so it
    # must be non-empty and restricted to characters that are safe inside an
    # HTTP header value: letters, digits, "-" and "_", starting with a letter
    # or digit. An empty or whitespace-containing prefix would otherwise
    # produce a malformed key.
    prefix: str = Field(
        default="sk",
        min_length=1,
        max_length=10,
        pattern=r"^[A-Za-z0-9][A-Za-z0-9_-]*$",
        description="Key prefix",
    )
    expires_at: datetime | None = Field(None, description="Expiration date")
    rpm_limit: int | None = Field(None, ge=1)
    tpm_limit: int | None = Field(None, ge=1)
    budget_limit: Decimal | None = Field(None, ge=0)
    # One of "daily", "weekly" or "monthly" when provided.
    budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
    allowed_models: list[str] | None = Field(None, description="List of allowed model aliases")
    enabled: bool = Field(True)
class PaginationParams(BaseModel):
    """Page number and page size, with a derived row offset."""

    # 1-based page index.
    page: int = Field(1, ge=1)
    # Rows per page, capped at 100.
    page_size: int = Field(20, ge=1, le=100)

    @property
    def offset(self) -> int:
        """Return the number of rows that precede the requested page."""
        pages_to_skip = self.page - 1
        return pages_to_skip * self.page_size
class ModelAliasUpdate(BaseModel):
    """Partial-update payload for a model alias.

    Every field is optional and defaults to ``None``.
    """

    # Identity fields share the length limits used when creating an alias.
    alias: str | None = Field(default=None, min_length=1, max_length=100)
    provider: str | None = Field(default=None, min_length=1, max_length=50)
    model: str | None = Field(default=None, min_length=1, max_length=100)
    # One of "simple", "load_balance" or "fallback" when provided.
    routing_type: str | None = Field(default=None, pattern="^(simple|load_balance|fallback)$")
    routing_config: RoutingConfig | None = None
    # Prices must be non-negative.
    input_price_per_1k: Decimal | None = Field(default=None, ge=0)
    output_price_per_1k: Decimal | None = Field(default=None, ge=0)
    enabled: bool | None = None
class OpenAIResponseRequest(BaseModel):
    """Request body for the OpenAI Responses API format.

    ``model`` and ``input`` are required; ``input`` is either a plain string
    or a list of untyped item dictionaries.
    """

    model: str
    input: str | list[dict[str, Any]]
    instructions: str | None = None
    # Temperature is limited to [0, 2]; the output token cap must be at least 1.
    temperature: float | None = Field(default=None, ge=0, le=2)
    max_output_tokens: int | None = Field(default=None, ge=1)
    tools: list[dict[str, Any]] | None = None
    tool_choice: str | dict[str, Any] | None = None
    metadata: dict[str, Any] | None = None
class ProjectCreate(BaseModel):
    """Payload for creating a project; only ``name`` is required."""

    name: str = Field(min_length=1, max_length=100)
    description: str | None = Field(default=None, max_length=1000)
    # Budget must be non-negative; the period is "daily", "weekly" or "monthly".
    budget_limit: Decimal | None = Field(default=None, ge=0)
    budget_period: str | None = Field(default=None, pattern="^(daily|weekly|monthly)$")
    # New projects are enabled unless stated otherwise.
    enabled: bool = True
class ProviderCreate(BaseModel):
    """Payload for registering a provider.

    ``name``, ``api_base`` and ``api_key`` are required and must be
    non-empty; the remaining fields are optional.
    """

    name: str = Field(min_length=1, max_length=50, description="Provider name")
    api_base: str = Field(min_length=1, max_length=500, description="API base URL")
    api_key: str = Field(min_length=1, description="API key (will be encrypted)")
    api_version: str | None = Field(default=None, max_length=50, description="API version")
    config: dict[str, Any] | None = Field(
        default=None, description="Provider-specific configuration"
    )
    # Per-provider rate limits; each must be at least 1 when set.
    rpm_limit: int | None = Field(default=None, ge=1, description="Requests per minute limit")
    tpm_limit: int | None = Field(default=None, ge=1, description="Tokens per minute limit")
    enabled: bool = Field(default=True, description="Whether provider is enabled")
def verify_api_key(key: str, hashed: str) -> bool:
    """
    Verify an API key against its bcrypt hash.

    Args:
        key: The plaintext API key to verify.
        hashed: The stored bcrypt hash to compare against.

    Returns:
        True if the key matches the hash. False if it does not match, if
        either value is empty, or if ``hashed`` is not a well-formed bcrypt
        hash.
    """
    # An empty key or hash can never be a valid credential; fail fast
    # instead of relying on an exception from bcrypt.
    if not key or not hashed:
        return False

    try:
        return bcrypt.checkpw(key.encode("utf-8"), hashed.encode("utf-8"))
    except (ValueError, TypeError):
        # bcrypt signals a malformed hash with ValueError (encoding errors are
        # ValueError subclasses too). Anything else is unexpected and should
        # surface rather than be reported as a failed verification.
        return False
def setup_logging() -> None:
    """Configure structured logging for the application.

    structlog is set up to hand its events to the standard library, and the
    root logger gets a single stdout handler whose formatter renders both
    structlog events and plain ``logging`` records. Debug mode uses the
    coloured console renderer; otherwise each record is emitted as JSON.
    Any handlers already attached to the root logger are removed.
    """
    settings = get_settings()
    # Normalise once so the debug check and the level lookup agree on case.
    log_level = settings.log_level.upper()

    # Shared processors for both stdlib and structlog
    shared_processors: list[Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
    ]

    # Configure structlog
    structlog.configure(
        processors=shared_processors
        + [
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    if settings.debug or log_level == "DEBUG":
        render_chain: list[Processor] = [
            structlog.dev.ConsoleRenderer(colors=True),
        ]
    else:
        render_chain = [
            structlog.processors.format_exc_info,
            structlog.processors.JSONRenderer(),
        ]

    formatter = structlog.stdlib.ProcessorFormatter(
        foreign_pre_chain=shared_processors,
        processors=[
            # Drop the internal `_record` / `_from_structlog` keys that
            # ProcessorFormatter adds; without this they are rendered into
            # every log line.
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            *render_chain,
        ],
    )

    # Configure stdlib logging
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    # Clear existing handlers and add our handler
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    # Unknown level names fall back to INFO.
    root_logger.setLevel(getattr(logging, log_level, logging.INFO))
+ ports: + - "8000:8000" + volumes: + - ./data:/app/data + environment: + - DATABASE_URL=sqlite:///data/gateway.db + - MASTER_KEY=${MASTER_KEY} + - DEBUG=${DEBUG:-false} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + + # Optional: Redis for distributed rate limiting (future) + # redis: + # image: redis:7-alpine + # ports: + # - "6379:6379" + # volumes: + # - redis_data:/data + +volumes: + gateway_data: + # redis_data: diff --git a/llm-gateway/requirements.txt b/llm-gateway/requirements.txt new file mode 100644 index 0000000..8e5ad9d --- /dev/null +++ b/llm-gateway/requirements.txt @@ -0,0 +1,37 @@ +# Web Framework +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +pydantic>=2.5.0 +pydantic-settings>=2.1.0 + +# HTTP Client +httpx>=0.26.0 + +# Database +sqlalchemy>=2.0.0 +aiosqlite>=0.19.0 + +# Security +bcrypt>=4.1.0 +cryptography>=42.0.0 + +# Logging +structlog>=24.1.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 +pytest-cov>=4.1.0 +httpx>=0.26.0 + +# Type Checking +mypy>=1.8.0 + +# Code Quality +ruff>=0.2.0 + +# Provider SDKs +openai>=1.12.0 +anthropic>=0.18.0 +google-generativeai>=0.3.0 +boto3>=1.34.0 diff --git a/llm-gateway/tests/__init__.py b/llm-gateway/tests/__init__.py new file mode 100644 index 0000000..2ab8767 --- /dev/null +++ b/llm-gateway/tests/__init__.py @@ -0,0 +1 @@ +# tests module diff --git a/llm-gateway/tests/conftest.py b/llm-gateway/tests/conftest.py new file mode 100644 index 0000000..389c229 --- /dev/null +++ b/llm-gateway/tests/conftest.py @@ -0,0 +1,106 @@ +"""Pytest configuration and fixtures.""" +import asyncio +import os +import tempfile +from typing import AsyncGenerator, Generator + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +# Set test environment 
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Create one event loop shared by the whole test session.

    Yields:
        A new event loop, closed when the session ends.
    """
    # asyncio.new_event_loop() creates the same loop as going through
    # get_event_loop_policy(), without touching the policy API that is
    # deprecated from Python 3.14.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@pytest_asyncio.fixture
async def client(setup_test_db: None) -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx client bound in-process to a newly created app."""
    from app.main import create_app

    app_under_test = create_app()
    # Point the app's database dependency at the test session provider.
    app_under_test.dependency_overrides[get_db] = _get_test_db

    asgi_transport = ASGITransport(app=app_under_test)
    http_client = AsyncClient(transport=asgi_transport, base_url="http://test")
    async with http_client:
        yield http_client
json={ + "name": "anthropic", + "api_base": "https://api.anthropic.com", + "api_key": "test-key", + }, + ) + + response = await client.get("/admin/providers") + assert response.status_code == 200 + data = response.json() + assert data["total"] >= 1 + assert len(data["providers"]) >= 1 + + @pytest.mark.asyncio + async def test_get_provider(self, client: AsyncClient, setup_test_db): + """Test getting a specific provider.""" + create_response = await client.post( + "/admin/providers", + json={ + "name": "gemini", + "api_base": "https://generativelanguage.googleapis.com", + "api_key": "test-key", + }, + ) + provider_id = create_response.json()["id"] + + response = await client.get(f"/admin/providers/{provider_id}") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "gemini" + + @pytest.mark.asyncio + async def test_update_provider(self, client: AsyncClient, setup_test_db): + """Test updating a provider.""" + create_response = await client.post( + "/admin/providers", + json={ + "name": "azure", + "api_base": "https://example.openai.azure.com", + "api_key": "test-key", + }, + ) + provider_id = create_response.json()["id"] + + response = await client.put( + f"/admin/providers/{provider_id}", + json={"rpm_limit": 500, "enabled": False}, + ) + assert response.status_code == 200 + data = response.json() + assert data["rpm_limit"] == 500 + assert data["enabled"] is False + + @pytest.mark.asyncio + async def test_delete_provider(self, client: AsyncClient, setup_test_db): + """Test deleting a provider.""" + create_response = await client.post( + "/admin/providers", + json={ + "name": "bedrock", + "api_base": "https://bedrock.us-east-1.amazonaws.com", + "api_key": "test-key", + }, + ) + provider_id = create_response.json()["id"] + + response = await client.delete(f"/admin/providers/{provider_id}") + assert response.status_code == 204 + + # Verify it's deleted + get_response = await client.get(f"/admin/providers/{provider_id}") + assert 
class TestAPIKeyAPI:
    """Exercise the API key admin endpoints."""

    @pytest.mark.asyncio
    async def test_create_api_key(self, client: AsyncClient, setup_test_db):
        """Creating a key returns 201 and the full key with its prefix."""
        payload = {
            "name": "Test Key",
            "prefix": "sk-test",
            "rpm_limit": 100,
            "tpm_limit": 10000,
        }
        resp = await client.post("/admin/keys", json=payload)
        assert resp.status_code == 201

        body = resp.json()
        assert body["name"] == "Test Key"
        # The full key is only returned by the creation call.
        assert "key" in body
        assert body["key"].startswith("sk-test_")
        assert body["key_prefix"].startswith("sk-test_")

    @pytest.mark.asyncio
    async def test_list_api_keys(self, client: AsyncClient, setup_test_db):
        """Listing keys reports at least the key created here."""
        await client.post("/admin/keys", json={"name": "Key 1"})

        resp = await client.get("/admin/keys")
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] >= 1
class TestHealthAPI:
    """Exercise the liveness and readiness endpoints."""

    @pytest.mark.asyncio
    async def test_health_check(self, client: AsyncClient, setup_test_db):
        """GET /health answers 200 with a healthy status."""
        resp = await client.get("/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "healthy"

    @pytest.mark.asyncio
    async def test_readiness_check(self, client: AsyncClient, setup_test_db):
        """GET /ready answers 200 and reports a connected database."""
        resp = await client.get("/ready")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ready"
        assert body["database"] == "connected"
class TestSettings:
    """Test configuration settings."""

    def test_default_settings(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test default configuration values."""
        # Change to temp directory for database path
        monkeypatch.chdir(tmp_path)

        settings = Settings()

        assert settings.app_name == "LLM Gateway"
        assert settings.debug is False
        assert settings.database_url.startswith("sqlite:///")
        assert settings.master_key is not None

    def test_custom_settings_from_env(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test loading settings from environment variables."""
        monkeypatch.chdir(tmp_path)
        monkeypatch.setenv("APP_NAME", "Custom Gateway")
        monkeypatch.setenv("DEBUG", "true")
        monkeypatch.setenv("DATABASE_URL", "sqlite:///custom.db")
        monkeypatch.setenv("MASTER_KEY", "test-master-key-12345678901234567890")  # 36 chars

        # Values set through the environment should be reflected on a new instance.
        settings = Settings()

        assert settings.app_name == "Custom Gateway"
        assert settings.debug is True
        assert settings.database_url == "sqlite:///custom.db"
        assert settings.master_key == "test-master-key-12345678901234567890"

    def test_generate_master_key(self) -> None:
        """Test master key generation."""
        key = Settings.generate_master_key()

        assert len(key) == 64  # 32 bytes = 64 hex chars
        assert isinstance(key, str)
@pytest.mark.asyncio
async def test_get_db_returns_session(self, setup_test_db):
    """Test that get_db returns a valid database session.

    Pulls exactly one session from the ``get_db`` async generator and then
    closes the generator explicitly. If ``get_db`` yields nothing, ``anext``
    raises and the test fails instead of passing without checking anything.
    """
    db_gen = get_db()
    try:
        session = await anext(db_gen)
        assert session is not None
    finally:
        # Breaking out of `async for` leaves the generator suspended at its
        # yield. aclose() finalises it now, so whatever cleanup get_db does
        # around the yield runs inside this test rather than at garbage
        # collection.
        await db_gen.aclose()