feat: implement LLM Gateway with multi-provider support
Implement a unified LLM Gateway supporting multiple API formats and providers: Features: - OpenAI Chat Completions, Responses API, and Anthropic Messages API - Provider adapters for OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock - Model aliasing with weighted round-robin load balancing - Virtual API keys with RPM/TPM rate limiting - Budget control at key and project levels - Request logging, usage statistics, and audit logs - Fallback/retry with circuit breaker pattern - Admin CRUD APIs for providers, projects, keys, models, usage - Provider health checks Tech stack: - FastAPI with async SQLAlchemy 2.0 - SQLite with aiosqlite - bcrypt for API key hashing, AES-256 for provider key encryption - Docker containerization Tests: 18 passing integration tests for admin API endpoints Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
8f550a2100
commit
8348520bdf
29
llm-gateway/.env.example
Normal file
29
llm-gateway/.env.example
Normal file
@ -0,0 +1,29 @@
|
||||
# Application
|
||||
APP_NAME=LLM Gateway
|
||||
DEBUG=false
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Server
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
|
||||
# Database
|
||||
DATABASE_URL=sqlite:///data/gateway.db
|
||||
|
||||
# Security
|
||||
# Generate with: python -c "import secrets; print(secrets.token_hex(32))"
|
||||
MASTER_KEY=your-master-key-here-at-least-32-characters
|
||||
|
||||
# Rate Limiting
|
||||
RATE_LIMIT_WINDOW_SECONDS=60
|
||||
# GLOBAL_RPM_LIMIT=1000
|
||||
# GLOBAL_TPM_LIMIT=1000000
|
||||
|
||||
# Retry
|
||||
MAX_RETRIES=3
|
||||
RETRY_INITIAL_DELAY=1.0
|
||||
RETRY_MAX_DELAY=30.0
|
||||
|
||||
# Health Check
|
||||
HEALTH_CHECK_INTERVAL=30
|
||||
HEALTH_CHECK_TIMEOUT=10
|
||||
65
llm-gateway/.gitignore
vendored
Normal file
65
llm-gateway/.gitignore
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.venv/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
# Type checking
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
data/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
28
llm-gateway/Dockerfile
Normal file
28
llm-gateway/Dockerfile
Normal file
@ -0,0 +1,28 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for caching
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Run the application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
41
llm-gateway/README.md
Normal file
41
llm-gateway/README.md
Normal file
@ -0,0 +1,41 @@
|
||||
# LLM Gateway
|
||||
|
||||
A unified LLM Gateway supporting multiple providers and API formats.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multi-format API**: OpenAI Chat Completions, OpenAI Responses API, Anthropic Messages API
|
||||
- **Multi-provider support**: OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock
|
||||
- **Model aliasing and routing**: Flexible model mapping with load balancing
|
||||
- **Rate limiting**: RPM/TPM limits at multiple levels
|
||||
- **Budget control**: Key and Project level spending limits
|
||||
- **High availability**: Fallback, retry, and circuit breaker
|
||||
- **Observability**: Request logging and usage statistics
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Run the server
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### LLM APIs
|
||||
- `POST /v1/chat/completions` - OpenAI-compatible chat completions
|
||||
- `POST /v1/responses` - OpenAI Responses API
|
||||
- `POST /v1/messages` - Anthropic Messages API
|
||||
|
||||
### Admin APIs
|
||||
- `GET|POST|PUT|DELETE /admin/providers` - Provider management
|
||||
- `GET|POST|PUT|DELETE /admin/keys` - API Key management
|
||||
- `GET|POST|PUT|DELETE /admin/projects` - Project management
|
||||
- `GET|POST|PUT|DELETE /admin/models/aliases` - Model alias management
|
||||
- `GET /admin/usage/stats` - Usage statistics
|
||||
|
||||
## Configuration
|
||||
|
||||
See `.env.example` for configuration options.
|
||||
1
llm-gateway/app/__init__.py
Normal file
1
llm-gateway/app/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# app module
|
||||
69
llm-gateway/app/adapters/__init__.py
Normal file
69
llm-gateway/app/adapters/__init__.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Provider adapters package."""
|
||||
from typing import Type
|
||||
|
||||
from app.adapters.anthropic import AnthropicAdapter
|
||||
from app.adapters.azure import AzureOpenAIAdapter
|
||||
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
|
||||
from app.adapters.bedrock import BedrockAdapter
|
||||
from app.adapters.gemini import GeminiAdapter
|
||||
from app.adapters.openai import OpenAIAdapter
|
||||
|
||||
|
||||
# Registry of provider adapters
|
||||
ADAPTER_REGISTRY: dict[str, Type[BaseAdapter]] = {
|
||||
"openai": OpenAIAdapter,
|
||||
"anthropic": AnthropicAdapter,
|
||||
"azure": AzureOpenAIAdapter,
|
||||
"gemini": GeminiAdapter,
|
||||
"bedrock": BedrockAdapter,
|
||||
"google": GeminiAdapter, # Alias
|
||||
"aws": BedrockAdapter, # Alias
|
||||
}
|
||||
|
||||
|
||||
def get_adapter_class(provider: str) -> Type[BaseAdapter]:
|
||||
"""
|
||||
Get the adapter class for a provider.
|
||||
|
||||
Args:
|
||||
provider: Provider name.
|
||||
|
||||
Returns:
|
||||
Adapter class.
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
if provider_lower not in ADAPTER_REGISTRY:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
return ADAPTER_REGISTRY[provider_lower]
|
||||
|
||||
|
||||
def create_adapter(config: ProviderConfig) -> BaseAdapter:
|
||||
"""
|
||||
Create an adapter instance for a provider.
|
||||
|
||||
Args:
|
||||
config: Provider configuration.
|
||||
|
||||
Returns:
|
||||
Adapter instance.
|
||||
"""
|
||||
adapter_class = get_adapter_class(config.name)
|
||||
return adapter_class(config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseAdapter",
|
||||
"HealthStatus",
|
||||
"ProviderConfig",
|
||||
"OpenAIAdapter",
|
||||
"AnthropicAdapter",
|
||||
"AzureOpenAIAdapter",
|
||||
"GeminiAdapter",
|
||||
"BedrockAdapter",
|
||||
"get_adapter_class",
|
||||
"create_adapter",
|
||||
"ADAPTER_REGISTRY",
|
||||
]
|
||||
206
llm-gateway/app/adapters/anthropic.py
Normal file
206
llm-gateway/app/adapters/anthropic.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""Anthropic provider adapter."""
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
|
||||
from app.core.fallback import RetryableError, PermanentError, classify_error
|
||||
from app.core.transformer import RequestTransformer
|
||||
from app.schemas.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicTextBlock,
|
||||
AnthropicUsage,
|
||||
)
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicAdapter(BaseAdapter):
|
||||
"""Adapter for Anthropic Claude API."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(config)
|
||||
self.transformer = RequestTransformer()
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for Anthropic API requests."""
|
||||
headers = {
|
||||
**super().get_headers(),
|
||||
"x-api-key": self.config.api_key,
|
||||
"anthropic-version": self.config.api_version or "2023-06-01",
|
||||
}
|
||||
return headers
|
||||
|
||||
async def chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""
|
||||
Execute an OpenAI-format chat completion request via Anthropic.
|
||||
Converts request to Anthropic format, calls Anthropic, converts response back.
|
||||
"""
|
||||
# Convert OpenAI request to Anthropic format
|
||||
anthropic_request = self.transformer.openai_to_anthropic(request)
|
||||
|
||||
# Call Anthropic
|
||||
anthropic_response = await self.messages(anthropic_request)
|
||||
|
||||
# Convert response back to OpenAI format
|
||||
return self.transformer.anthropic_response_to_openai(
|
||||
anthropic_response, request.model
|
||||
)
|
||||
|
||||
async def stream_chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Execute a streaming chat completion request via Anthropic."""
|
||||
# Convert to Anthropic format
|
||||
anthropic_request = self.transformer.openai_to_anthropic(request)
|
||||
anthropic_request.stream = True
|
||||
|
||||
url = f"{self.config.api_base}/v1/messages"
|
||||
payload = anthropic_request.model_dump(exclude_none=True)
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
error = classify_error(
|
||||
Exception(f"{response.status_code}: {await response.aread()}")
|
||||
)
|
||||
raise error
|
||||
|
||||
message_id = None
|
||||
model = request.model
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data = line[6:]
|
||||
try:
|
||||
event = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
event_type = event.get("type")
|
||||
|
||||
if event_type == "message_start":
|
||||
message_id = event.get("message", {}).get("id")
|
||||
|
||||
elif event_type == "content_block_delta":
|
||||
delta = event.get("delta", {})
|
||||
text = delta.get("text", "")
|
||||
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=message_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": text, "role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
elif event_type == "message_delta":
|
||||
delta = event.get("delta", {})
|
||||
stop_reason = delta.get("stop_reason")
|
||||
|
||||
if stop_reason:
|
||||
finish_reason = "stop" if stop_reason == "end_turn" else "length"
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=message_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise RetryableError(str(e), error_type="timeout")
|
||||
except httpx.ConnectError as e:
|
||||
raise RetryableError(str(e), error_type="connection_error")
|
||||
|
||||
async def messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""Execute an Anthropic Messages API request."""
|
||||
url = f"{self.config.api_base}/v1/messages"
|
||||
|
||||
payload = request.model_dump(exclude_none=True)
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
error = classify_error(Exception(f"{response.status_code}: {response.text}"))
|
||||
raise error
|
||||
|
||||
data = response.json()
|
||||
return AnthropicMessagesResponse(**data)
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise RetryableError(str(e), error_type="timeout")
|
||||
except httpx.ConnectError as e:
|
||||
raise RetryableError(str(e), error_type="connection_error")
|
||||
|
||||
async def check_health(self) -> HealthStatus:
|
||||
"""Check Anthropic API health by making a minimal request."""
|
||||
# Anthropic doesn't have a health endpoint, so we make a minimal request
|
||||
url = f"{self.config.api_base}/v1/messages"
|
||||
|
||||
# Minimal request
|
||||
payload = {
|
||||
"model": "claude-3-haiku-20240307",
|
||||
"max_tokens": 1,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return HealthStatus.HEALTHY
|
||||
elif response.status_code >= 500:
|
||||
return HealthStatus.UNHEALTHY
|
||||
else:
|
||||
return HealthStatus.DEGRADED
|
||||
|
||||
except Exception:
|
||||
return HealthStatus.UNHEALTHY
|
||||
79
llm-gateway/app/adapters/azure.py
Normal file
79
llm-gateway/app/adapters/azure.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Azure OpenAI provider adapter."""
|
||||
from typing import AsyncIterator
|
||||
|
||||
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
|
||||
from app.adapters.openai import OpenAIAdapter
|
||||
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIAdapter(BaseAdapter):
|
||||
"""Adapter for Azure OpenAI API."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(config)
|
||||
# Azure uses OpenAI-compatible API
|
||||
self._openai_adapter: OpenAIAdapter | None = None
|
||||
|
||||
def _get_openai_adapter(self) -> OpenAIAdapter:
|
||||
"""Get or create OpenAI adapter for Azure."""
|
||||
if self._openai_adapter is None:
|
||||
# Azure config should include deployment_name
|
||||
azure_config = self.config.config or {}
|
||||
deployment_name = azure_config.get("deployment_name", "")
|
||||
|
||||
# Build Azure-specific API base
|
||||
api_base = self.config.api_base
|
||||
if deployment_name:
|
||||
api_base = f"{api_base}/openai/deployments/{deployment_name}"
|
||||
|
||||
openai_config = ProviderConfig(
|
||||
name=self.config.name,
|
||||
api_base=api_base,
|
||||
api_key=self.config.api_key,
|
||||
api_version=self.config.api_version,
|
||||
)
|
||||
self._openai_adapter = OpenAIAdapter(openai_config)
|
||||
return self._openai_adapter
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for Azure OpenAI API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.config.api_key, # Azure uses api-key header
|
||||
}
|
||||
|
||||
async def chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Execute a chat completion request to Azure OpenAI."""
|
||||
adapter = self._get_openai_adapter()
|
||||
return await adapter.chat_completions(request)
|
||||
|
||||
async def stream_chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Execute a streaming chat completion request to Azure OpenAI."""
|
||||
adapter = self._get_openai_adapter()
|
||||
async for chunk in adapter.stream_chat_completions(request):
|
||||
yield chunk
|
||||
|
||||
async def messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""Execute an Anthropic Messages API request via Azure OpenAI."""
|
||||
adapter = self._get_openai_adapter()
|
||||
return await adapter.messages(request)
|
||||
|
||||
async def check_health(self) -> HealthStatus:
|
||||
"""Check Azure OpenAI API health."""
|
||||
# Use OpenAI-style health check
|
||||
adapter = self._get_openai_adapter()
|
||||
return await adapter.check_health()
|
||||
135
llm-gateway/app/adapters/base.py
Normal file
135
llm-gateway/app/adapters/base.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""Base adapter interface for LLM providers."""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(Enum):
|
||||
"""Provider health status."""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
UNHEALTHY = "unhealthy"
|
||||
DEGRADED = "degraded"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""Configuration for a provider."""
|
||||
|
||||
name: str
|
||||
api_base: str
|
||||
api_key: str
|
||||
api_version: str | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
rpm_limit: int | None = None
|
||||
tpm_limit: int | None = None
|
||||
|
||||
|
||||
class BaseAdapter(ABC):
|
||||
"""Abstract base class for LLM provider adapters."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""
|
||||
Execute a chat completion request.
|
||||
|
||||
Args:
|
||||
request: OpenAI-format chat completion request.
|
||||
|
||||
Returns:
|
||||
OpenAI-format chat completion response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream_chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""
|
||||
Execute a streaming chat completion request.
|
||||
|
||||
Args:
|
||||
request: OpenAI-format chat completion request.
|
||||
|
||||
Yields:
|
||||
OpenAI-format chat completion chunks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""
|
||||
Execute an Anthropic Messages API request.
|
||||
|
||||
Args:
|
||||
request: Anthropic-format messages request.
|
||||
|
||||
Returns:
|
||||
Anthropic-format messages response.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest | AnthropicMessagesRequest,
|
||||
) -> int:
|
||||
"""
|
||||
Estimate token count for a request.
|
||||
|
||||
Args:
|
||||
request: The request to count tokens for.
|
||||
|
||||
Returns:
|
||||
Estimated token count.
|
||||
"""
|
||||
# Default implementation: rough estimation
|
||||
if isinstance(request, OpenAIChatCompletionRequest):
|
||||
total = 0
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
total += len(msg.content.split())
|
||||
return int(total * 1.3) # Rough word-to-token ratio
|
||||
elif isinstance(request, AnthropicMessagesRequest):
|
||||
total = 0
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
total += len(msg.content.split())
|
||||
if request.system:
|
||||
total += len(request.system.split())
|
||||
return int(total * 1.3)
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
async def check_health(self) -> HealthStatus:
|
||||
"""
|
||||
Check the health of the provider.
|
||||
|
||||
Returns:
|
||||
Health status of the provider.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Get common headers for requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
240
llm-gateway/app/adapters/bedrock.py
Normal file
240
llm-gateway/app/adapters/bedrock.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""AWS Bedrock provider adapter."""
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator, Any
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import ClientError
|
||||
HAS_BOTO3 = True
|
||||
except ImportError:
|
||||
HAS_BOTO3 = False
|
||||
|
||||
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
|
||||
from app.core.fallback import RetryableError, classify_error
|
||||
from app.core.transformer import RequestTransformer
|
||||
from app.schemas.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletionChoice,
|
||||
OpenAIChatMessage,
|
||||
OpenAIUsage,
|
||||
)
|
||||
|
||||
|
||||
class BedrockAdapter(BaseAdapter):
|
||||
"""Adapter for AWS Bedrock API."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(config)
|
||||
self.transformer = RequestTransformer()
|
||||
self._client = None
|
||||
|
||||
def _get_client(self):
|
||||
"""Get or create Bedrock runtime client."""
|
||||
if self._client is None and HAS_BOTO3:
|
||||
aws_config = self.config.config or {}
|
||||
region = aws_config.get("region", "us-east-1")
|
||||
|
||||
self._client = boto3.client(
|
||||
"bedrock-runtime",
|
||||
region_name=region,
|
||||
config=Config(
|
||||
retries={"max_attempts": 3, "mode": "adaptive"},
|
||||
connect_timeout=10,
|
||||
read_timeout=120,
|
||||
),
|
||||
)
|
||||
return self._client
|
||||
|
||||
def _openai_to_bedrock_anthropic(
|
||||
self, request: OpenAIChatCompletionRequest
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
"""Convert OpenAI request to Bedrock Anthropic format."""
|
||||
messages = []
|
||||
system = None
|
||||
|
||||
for msg in request.messages:
|
||||
if msg.role == "system":
|
||||
system = msg.content if isinstance(msg.content, str) else None
|
||||
else:
|
||||
messages.append({
|
||||
"role": msg.role,
|
||||
"content": [{"text": msg.content}] if isinstance(msg.content, str) else msg.content,
|
||||
})
|
||||
|
||||
bedrock_request = {
|
||||
"messages": messages,
|
||||
"max_tokens": request.max_tokens or 4096,
|
||||
}
|
||||
|
||||
if system:
|
||||
bedrock_request["system"] = system
|
||||
if request.temperature is not None:
|
||||
bedrock_request["temperature"] = request.temperature
|
||||
if request.top_p is not None:
|
||||
bedrock_request["top_p"] = request.top_p
|
||||
|
||||
# Return model ID and request body
|
||||
model_id = request.model
|
||||
return model_id, bedrock_request
|
||||
|
||||
def _bedrock_to_openai(
|
||||
self,
|
||||
response: dict[str, Any],
|
||||
model: str,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Convert Bedrock Anthropic response to OpenAI format."""
|
||||
content = ""
|
||||
output = response.get("output", {})
|
||||
message = output.get("message", {})
|
||||
|
||||
for block in message.get("content", []):
|
||||
if "text" in block:
|
||||
content += block["text"]
|
||||
|
||||
stop_reason = response.get("stopReason", "end_turn")
|
||||
finish_reason_map = {
|
||||
"end_turn": "stop",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
}
|
||||
finish_reason = finish_reason_map.get(stop_reason, "stop")
|
||||
|
||||
usage = response.get("usage", {})
|
||||
|
||||
return OpenAIChatCompletionResponse(
|
||||
id=f"bedrock-{uuid.uuid4().hex[:24]}",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
OpenAIChatCompletionChoice(
|
||||
index=0,
|
||||
message=OpenAIChatMessage(role="assistant", content=content),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=OpenAIUsage(
|
||||
prompt_tokens=usage.get("inputTokens", 0),
|
||||
completion_tokens=usage.get("outputTokens", 0),
|
||||
total_tokens=usage.get("inputTokens", 0) + usage.get("outputTokens", 0),
|
||||
),
|
||||
)
|
||||
|
||||
async def chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Execute a chat completion request to Bedrock."""
|
||||
client = self._get_client()
|
||||
model_id, bedrock_request = self._openai_to_bedrock_anthropic(request)
|
||||
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
modelId=model_id,
|
||||
contentType="application/json",
|
||||
accept="application/json",
|
||||
body=json.dumps(bedrock_request),
|
||||
)
|
||||
|
||||
response_body = json.loads(response["body"].read())
|
||||
return self._bedrock_to_openai(response_body, request.model)
|
||||
|
||||
except ClientError as e:
|
||||
error = classify_error(Exception(str(e)))
|
||||
raise error
|
||||
|
||||
async def stream_chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Execute a streaming chat completion request to Bedrock."""
|
||||
client = self._get_client()
|
||||
model_id, bedrock_request = self._openai_to_bedrock_anthropic(request)
|
||||
|
||||
try:
|
||||
response = client.invoke_model_with_response_stream(
|
||||
modelId=model_id,
|
||||
contentType="application/json",
|
||||
accept="application/json",
|
||||
body=json.dumps(bedrock_request),
|
||||
)
|
||||
|
||||
chunk_id = f"bedrock-{uuid.uuid4().hex[:24]}"
|
||||
|
||||
for event in response["body"]:
|
||||
chunk_data = json.loads(event["chunk"]["bytes"])
|
||||
|
||||
if chunk_data.get("type") == "content_block_delta":
|
||||
delta = chunk_data.get("delta", {})
|
||||
text = delta.get("text", "")
|
||||
|
||||
if text:
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=chunk_id,
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": text},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
elif chunk_data.get("type") == "message_delta":
|
||||
stop_reason = chunk_data.get("delta", {}).get("stop_reason")
|
||||
if stop_reason:
|
||||
finish_reason = "stop" if stop_reason == "end_turn" else "length"
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=chunk_id,
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
|
||||
except ClientError as e:
|
||||
error = classify_error(Exception(str(e)))
|
||||
raise error
|
||||
|
||||
async def messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""Execute an Anthropic Messages API request via Bedrock."""
|
||||
openai_request = self.transformer.anthropic_to_openai(request)
|
||||
openai_response = await self.chat_completions(openai_request)
|
||||
return self.transformer.openai_response_to_anthropic(openai_response)
|
||||
|
||||
async def check_health(self) -> HealthStatus:
|
||||
"""Check Bedrock API health."""
|
||||
if not HAS_BOTO3:
|
||||
return HealthStatus.UNHEALTHY
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
# List available models to check health
|
||||
client.list_foundation_models()
|
||||
return HealthStatus.HEALTHY
|
||||
except Exception:
|
||||
return HealthStatus.UNHEALTHY
|
||||
236
llm-gateway/app/adapters/gemini.py
Normal file
236
llm-gateway/app/adapters/gemini.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""Google Gemini provider adapter."""
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncIterator, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
|
||||
from app.core.fallback import RetryableError, classify_error
|
||||
from app.core.transformer import RequestTransformer
|
||||
from app.schemas.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicTextBlock,
|
||||
AnthropicUsage,
|
||||
)
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletionChoice,
|
||||
OpenAIChatMessage,
|
||||
OpenAIUsage,
|
||||
)
|
||||
|
||||
|
||||
class GeminiAdapter(BaseAdapter):
|
||||
"""Adapter for Google Gemini API."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(config)
|
||||
self.transformer = RequestTransformer()
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for Gemini API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _openai_to_gemini(self, request: OpenAIChatCompletionRequest) -> dict[str, Any]:
|
||||
"""Convert OpenAI request to Gemini format."""
|
||||
contents = []
|
||||
system_instruction = None
|
||||
|
||||
for msg in request.messages:
|
||||
if msg.role == "system":
|
||||
system_instruction = {"parts": [{"text": msg.content}]}
|
||||
else:
|
||||
role = "user" if msg.role == "user" else "model"
|
||||
if isinstance(msg.content, str):
|
||||
parts = [{"text": msg.content}]
|
||||
else:
|
||||
parts = [{"text": str(part)} for part in msg.content]
|
||||
contents.append({"role": role, "parts": parts})
|
||||
|
||||
gemini_request = {
|
||||
"contents": contents,
|
||||
"generationConfig": {},
|
||||
}
|
||||
|
||||
if system_instruction:
|
||||
gemini_request["systemInstruction"] = system_instruction
|
||||
|
||||
if request.temperature is not None:
|
||||
gemini_request["generationConfig"]["temperature"] = request.temperature
|
||||
if request.max_tokens is not None:
|
||||
gemini_request["generationConfig"]["maxOutputTokens"] = request.max_tokens
|
||||
if request.top_p is not None:
|
||||
gemini_request["generationConfig"]["topP"] = request.top_p
|
||||
if request.stop:
|
||||
if isinstance(request.stop, str):
|
||||
gemini_request["generationConfig"]["stopSequences"] = [request.stop]
|
||||
else:
|
||||
gemini_request["generationConfig"]["stopSequences"] = request.stop
|
||||
|
||||
return gemini_request
|
||||
|
||||
def _gemini_to_openai(
|
||||
self,
|
||||
response: dict[str, Any],
|
||||
model: str,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Convert Gemini response to OpenAI format."""
|
||||
candidates = response.get("candidates", [])
|
||||
content = ""
|
||||
finish_reason = "stop"
|
||||
|
||||
if candidates:
|
||||
candidate = candidates[0]
|
||||
content_parts = candidate.get("content", {}).get("parts", [])
|
||||
content = "".join(p.get("text", "") for p in content_parts)
|
||||
|
||||
finish_reason_map = {
|
||||
"STOP": "stop",
|
||||
"MAX_TOKENS": "length",
|
||||
"SAFETY": "content_filter",
|
||||
}
|
||||
finish_reason = finish_reason_map.get(
|
||||
candidate.get("finishReason", "STOP"), "stop"
|
||||
)
|
||||
|
||||
usage = response.get("usageMetadata", {})
|
||||
|
||||
return OpenAIChatCompletionResponse(
|
||||
id=f"gemini-{uuid.uuid4().hex[:24]}",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
OpenAIChatCompletionChoice(
|
||||
index=0,
|
||||
message=OpenAIChatMessage(role="assistant", content=content),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=OpenAIUsage(
|
||||
prompt_tokens=usage.get("promptTokenCount", 0),
|
||||
completion_tokens=usage.get("candidatesTokenCount", 0),
|
||||
total_tokens=usage.get("totalTokenCount", 0),
|
||||
),
|
||||
)
|
||||
|
||||
async def chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Execute a chat completion request to Gemini."""
|
||||
# Extract model name
|
||||
model = request.model
|
||||
url = f"{self.config.api_base}/v1beta/models/{model}:generateContent?key={self.config.api_key}"
|
||||
|
||||
gemini_request = self._openai_to_gemini(request)
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=gemini_request,
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
error = classify_error(Exception(f"{response.status_code}: {response.text}"))
|
||||
raise error
|
||||
|
||||
data = response.json()
|
||||
return self._gemini_to_openai(data, model)
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise RetryableError(str(e), error_type="timeout")
|
||||
except httpx.ConnectError as e:
|
||||
raise RetryableError(str(e), error_type="connection_error")
|
||||
|
||||
async def stream_chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Execute a streaming chat completion request to Gemini."""
|
||||
model = request.model
|
||||
url = f"{self.config.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.config.api_key}&alt=sse"
|
||||
|
||||
gemini_request = self._openai_to_gemini(request)
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=gemini_request,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
error = classify_error(
|
||||
Exception(f"{response.status_code}: {await response.aread()}")
|
||||
)
|
||||
raise error
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
try:
|
||||
gemini_response = json.loads(data)
|
||||
openai_response = self._gemini_to_openai(gemini_response, model)
|
||||
|
||||
if openai_response.choices:
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=openai_response.id,
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": openai_response.choices[0].message.content
|
||||
},
|
||||
"finish_reason": openai_response.choices[0].finish_reason,
|
||||
}
|
||||
],
|
||||
)
|
||||
yield chunk
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise RetryableError(str(e), error_type="timeout")
|
||||
except httpx.ConnectError as e:
|
||||
raise RetryableError(str(e), error_type="connection_error")
|
||||
|
||||
async def messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""Execute an Anthropic Messages API request via Gemini."""
|
||||
openai_request = self.transformer.anthropic_to_openai(request)
|
||||
openai_response = await self.chat_completions(openai_request)
|
||||
return self.transformer.openai_response_to_anthropic(openai_response)
|
||||
|
||||
async def check_health(self) -> HealthStatus:
|
||||
"""Check Gemini API health."""
|
||||
url = f"{self.config.api_base}/v1beta/models?key={self.config.api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
try:
|
||||
response = await client.get(url, headers=self.get_headers())
|
||||
|
||||
if response.status_code == 200:
|
||||
return HealthStatus.HEALTHY
|
||||
elif response.status_code >= 500:
|
||||
return HealthStatus.UNHEALTHY
|
||||
else:
|
||||
return HealthStatus.DEGRADED
|
||||
|
||||
except Exception:
|
||||
return HealthStatus.UNHEALTHY
|
||||
134
llm-gateway/app/adapters/openai.py
Normal file
134
llm-gateway/app/adapters/openai.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""OpenAI provider adapter."""
|
||||
import json
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
|
||||
from app.core.fallback import RetryableError, PermanentError, classify_error
|
||||
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
)
|
||||
from app.core.transformer import RequestTransformer
|
||||
|
||||
|
||||
class OpenAIAdapter(BaseAdapter):
|
||||
"""Adapter for OpenAI API."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(config)
|
||||
self.transformer = RequestTransformer()
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for OpenAI API requests."""
|
||||
return {
|
||||
**super().get_headers(),
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
}
|
||||
|
||||
async def chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Execute a chat completion request to OpenAI."""
|
||||
url = f"{self.config.api_base}/chat/completions"
|
||||
|
||||
payload = request.model_dump(exclude_none=True)
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
error = classify_error(Exception(f"{response.status_code}: {response.text}"))
|
||||
raise error
|
||||
|
||||
data = response.json()
|
||||
return OpenAIChatCompletionResponse(**data)
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise RetryableError(str(e), error_type="timeout")
|
||||
except httpx.ConnectError as e:
|
||||
raise RetryableError(str(e), error_type="connection_error")
|
||||
|
||||
async def stream_chat_completions(
|
||||
self,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Execute a streaming chat completion request to OpenAI."""
|
||||
url = f"{self.config.api_base}/chat/completions"
|
||||
|
||||
payload = request.model_dump(exclude_none=True)
|
||||
payload["stream"] = True
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
error = classify_error(
|
||||
Exception(f"{response.status_code}: {await response.aread()}")
|
||||
)
|
||||
raise error
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
yield OpenAIChatCompletionChunk(**chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise RetryableError(str(e), error_type="timeout")
|
||||
except httpx.ConnectError as e:
|
||||
raise RetryableError(str(e), error_type="connection_error")
|
||||
|
||||
async def messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""
|
||||
Execute an Anthropic Messages API request via OpenAI.
|
||||
Converts request to OpenAI format, calls OpenAI, converts response back.
|
||||
"""
|
||||
# Convert Anthropic request to OpenAI format
|
||||
openai_request = self.transformer.anthropic_to_openai(request)
|
||||
|
||||
# Call OpenAI
|
||||
openai_response = await self.chat_completions(openai_request)
|
||||
|
||||
# Convert response back to Anthropic format
|
||||
return self.transformer.openai_response_to_anthropic(openai_response)
|
||||
|
||||
async def check_health(self) -> HealthStatus:
|
||||
"""Check OpenAI API health by listing models."""
|
||||
url = f"{self.config.api_base}/models"
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
try:
|
||||
response = await client.get(url, headers=self.get_headers())
|
||||
|
||||
if response.status_code == 200:
|
||||
return HealthStatus.HEALTHY
|
||||
elif response.status_code >= 500:
|
||||
return HealthStatus.UNHEALTHY
|
||||
else:
|
||||
return HealthStatus.DEGRADED
|
||||
|
||||
except Exception:
|
||||
return HealthStatus.UNHEALTHY
|
||||
1
llm-gateway/app/api/__init__.py
Normal file
1
llm-gateway/app/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# api module
|
||||
1
llm-gateway/app/api/admin/__init__.py
Normal file
1
llm-gateway/app/api/admin/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# admin module
|
||||
68
llm-gateway/app/api/admin/health.py
Normal file
68
llm-gateway/app/api/admin/health.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""Health check API endpoints."""
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import get_settings
|
||||
from app.db.database import get_db
|
||||
from app.models.provider import Provider
|
||||
|
||||
router = APIRouter(tags=["Health"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check() -> dict:
|
||||
"""Basic health check endpoint."""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"app": settings.app_name,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/admin/providers/{provider_id}/health")
|
||||
async def provider_health_check(
|
||||
provider_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict:
|
||||
"""Check health status of a specific provider."""
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"status": provider.health_status,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"last_check": provider.last_health_check.isoformat() if provider.last_health_check else None,
|
||||
"enabled": provider.enabled,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/ready")
|
||||
async def readiness_check(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict:
|
||||
"""Readiness check - verify database connection."""
|
||||
try:
|
||||
# Simple query to verify database connection
|
||||
await db.execute(select(1))
|
||||
return {
|
||||
"status": "ready",
|
||||
"database": "connected",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "not_ready",
|
||||
"database": "disconnected",
|
||||
"error": str(e),
|
||||
}
|
||||
233
llm-gateway/app/api/admin/keys.py
Normal file
233
llm-gateway/app/api/admin/keys.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""API Key management API endpoints."""
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.project import Project
|
||||
from app.schemas.api_key import (
|
||||
APIKeyCreate,
|
||||
APIKeyCreateResponse,
|
||||
APIKeyListResponse,
|
||||
APIKeyResponse,
|
||||
APIKeyUpdate,
|
||||
)
|
||||
from app.utils.crypto import generate_api_key
|
||||
|
||||
router = APIRouter(prefix="/keys", tags=["Admin - Keys"])
|
||||
|
||||
|
||||
@router.post("", response_model=APIKeyCreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_api_key(
|
||||
data: APIKeyCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict:
|
||||
"""Create a new API key."""
|
||||
# Verify project exists if specified
|
||||
if data.project_id:
|
||||
result = await db.execute(select(Project).where(Project.id == data.project_id))
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{data.project_id}' not found",
|
||||
)
|
||||
|
||||
# Generate API key
|
||||
full_key, key_hash, key_prefix = generate_api_key(data.prefix)
|
||||
|
||||
# Create API key record
|
||||
api_key = APIKey(
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
name=data.name,
|
||||
project_id=data.project_id,
|
||||
enabled=data.enabled,
|
||||
expires_at=data.expires_at,
|
||||
rpm_limit=data.rpm_limit,
|
||||
tpm_limit=data.tpm_limit,
|
||||
budget_limit=data.budget_limit,
|
||||
budget_period=data.budget_period,
|
||||
allowed_models=json.dumps(data.allowed_models) if data.allowed_models else None,
|
||||
)
|
||||
|
||||
db.add(api_key)
|
||||
await db.flush()
|
||||
await db.refresh(api_key)
|
||||
|
||||
# Return response with full key (only time it's shown)
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"key": full_key, # Full key - shown only once!
|
||||
"key_prefix": api_key.key_prefix,
|
||||
"name": api_key.name,
|
||||
"project_id": api_key.project_id,
|
||||
"enabled": api_key.enabled,
|
||||
"expires_at": api_key.expires_at,
|
||||
"rpm_limit": api_key.rpm_limit,
|
||||
"tpm_limit": api_key.tpm_limit,
|
||||
"budget_limit": api_key.budget_limit,
|
||||
"budget_period": api_key.budget_period,
|
||||
"allowed_models": json.loads(api_key.allowed_models) if api_key.allowed_models else None,
|
||||
"current_usage": api_key.current_usage,
|
||||
"total_requests": api_key.total_requests,
|
||||
"created_at": api_key.created_at,
|
||||
"updated_at": api_key.updated_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=APIKeyListResponse)
|
||||
async def list_api_keys(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(ge=1)] = 1,
|
||||
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
project_id: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> APIKeyListResponse:
|
||||
"""List all API keys."""
|
||||
query = select(APIKey)
|
||||
count_query = select(func.count()).select_from(APIKey)
|
||||
|
||||
if project_id:
|
||||
query = query.where(APIKey.project_id == project_id)
|
||||
count_query = count_query.where(APIKey.project_id == project_id)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.where(APIKey.enabled == enabled)
|
||||
count_query = count_query.where(APIKey.enabled == enabled)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size).order_by(APIKey.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
keys = result.scalars().all()
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return APIKeyListResponse(
|
||||
keys=[
|
||||
APIKeyResponse(
|
||||
id=k.id,
|
||||
key_prefix=k.key_prefix,
|
||||
name=k.name,
|
||||
project_id=k.project_id,
|
||||
enabled=k.enabled,
|
||||
expires_at=k.expires_at,
|
||||
rpm_limit=k.rpm_limit,
|
||||
tpm_limit=k.tpm_limit,
|
||||
budget_limit=k.budget_limit,
|
||||
budget_period=k.budget_period,
|
||||
allowed_models=json.loads(k.allowed_models) if k.allowed_models else None,
|
||||
current_usage=k.current_usage,
|
||||
total_requests=k.total_requests,
|
||||
created_at=k.created_at,
|
||||
updated_at=k.updated_at,
|
||||
)
|
||||
for k in keys
|
||||
],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{key_id}", response_model=APIKeyResponse)
|
||||
async def get_api_key(
|
||||
key_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict:
|
||||
"""Get an API key by ID."""
|
||||
result = await db.execute(select(APIKey).where(APIKey.id == key_id))
|
||||
key = result.scalar_one_or_none()
|
||||
|
||||
if not key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"API key '{key_id}' not found",
|
||||
)
|
||||
|
||||
return APIKeyResponse(
|
||||
id=key.id,
|
||||
key_prefix=key.key_prefix,
|
||||
name=key.name,
|
||||
project_id=key.project_id,
|
||||
enabled=key.enabled,
|
||||
expires_at=key.expires_at,
|
||||
rpm_limit=key.rpm_limit,
|
||||
tpm_limit=key.tpm_limit,
|
||||
budget_limit=key.budget_limit,
|
||||
budget_period=key.budget_period,
|
||||
allowed_models=json.loads(key.allowed_models) if key.allowed_models else None,
|
||||
current_usage=key.current_usage,
|
||||
total_requests=key.total_requests,
|
||||
created_at=key.created_at,
|
||||
updated_at=key.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{key_id}", response_model=APIKeyResponse)
|
||||
async def update_api_key(
|
||||
key_id: str,
|
||||
data: APIKeyUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> dict:
|
||||
"""Update an API key."""
|
||||
result = await db.execute(select(APIKey).where(APIKey.id == key_id))
|
||||
key = result.scalar_one_or_none()
|
||||
|
||||
if not key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"API key '{key_id}' not found",
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
if "allowed_models" in update_data:
|
||||
key.allowed_models = json.dumps(update_data.pop("allowed_models"))
|
||||
|
||||
for k, v in update_data.items():
|
||||
setattr(key, k, v)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(key)
|
||||
|
||||
return APIKeyResponse(
|
||||
id=key.id,
|
||||
key_prefix=key.key_prefix,
|
||||
name=key.name,
|
||||
project_id=key.project_id,
|
||||
enabled=key.enabled,
|
||||
expires_at=key.expires_at,
|
||||
rpm_limit=key.rpm_limit,
|
||||
tpm_limit=key.tpm_limit,
|
||||
budget_limit=key.budget_limit,
|
||||
budget_period=key.budget_period,
|
||||
allowed_models=json.loads(key.allowed_models) if key.allowed_models else None,
|
||||
current_usage=key.current_usage,
|
||||
total_requests=key.total_requests,
|
||||
created_at=key.created_at,
|
||||
updated_at=key.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_api_key(
|
||||
key_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> None:
|
||||
"""Delete an API key."""
|
||||
result = await db.execute(select(APIKey).where(APIKey.id == key_id))
|
||||
key = result.scalar_one_or_none()
|
||||
|
||||
if not key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"API key '{key_id}' not found",
|
||||
)
|
||||
|
||||
await db.delete(key)
|
||||
169
llm-gateway/app/api/admin/models.py
Normal file
169
llm-gateway/app/api/admin/models.py
Normal file
@ -0,0 +1,169 @@
|
||||
"""Model alias management API endpoints."""
|
||||
import json
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.models.model_alias import ModelAlias
|
||||
from app.schemas.model_alias import (
|
||||
ModelAliasCreate,
|
||||
ModelAliasListResponse,
|
||||
ModelAliasResponse,
|
||||
ModelAliasUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/models/aliases", tags=["Admin - Models"])
|
||||
|
||||
|
||||
@router.post("", response_model=ModelAliasResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_alias(
|
||||
data: ModelAliasCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ModelAlias:
|
||||
"""Create a new model alias."""
|
||||
# Check if alias already exists
|
||||
result = await db.execute(select(ModelAlias).where(ModelAlias.alias == data.alias))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Model alias '{data.alias}' already exists",
|
||||
)
|
||||
|
||||
# Create model alias
|
||||
alias = ModelAlias(
|
||||
alias=data.alias,
|
||||
provider=data.provider,
|
||||
model=data.model,
|
||||
enabled=data.enabled,
|
||||
routing_type=data.routing_type,
|
||||
routing_config=json.dumps(data.routing_config.model_dump()) if data.routing_config else None,
|
||||
input_price_per_1k=data.input_price_per_1k,
|
||||
output_price_per_1k=data.output_price_per_1k,
|
||||
)
|
||||
|
||||
db.add(alias)
|
||||
await db.flush()
|
||||
await db.refresh(alias)
|
||||
|
||||
return alias
|
||||
|
||||
|
||||
@router.get("", response_model=ModelAliasListResponse)
|
||||
async def list_model_aliases(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(ge=1)] = 1,
|
||||
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
provider: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> ModelAliasListResponse:
|
||||
"""List all model aliases."""
|
||||
query = select(ModelAlias)
|
||||
count_query = select(func.count()).select_from(ModelAlias)
|
||||
|
||||
if provider:
|
||||
query = query.where(ModelAlias.provider == provider)
|
||||
count_query = count_query.where(ModelAlias.provider == provider)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.where(ModelAlias.enabled == enabled)
|
||||
count_query = count_query.where(ModelAlias.enabled == enabled)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size).order_by(ModelAlias.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
aliases = result.scalars().all()
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return ModelAliasListResponse(
|
||||
aliases=[
|
||||
ModelAliasResponse(
|
||||
id=a.id,
|
||||
alias=a.alias,
|
||||
provider=a.provider,
|
||||
model=a.model,
|
||||
enabled=a.enabled,
|
||||
routing_type=a.routing_type,
|
||||
routing_config=json.loads(a.routing_config) if a.routing_config else None,
|
||||
input_price_per_1k=a.input_price_per_1k,
|
||||
output_price_per_1k=a.output_price_per_1k,
|
||||
created_at=a.created_at,
|
||||
updated_at=a.updated_at,
|
||||
)
|
||||
for a in aliases
|
||||
],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{alias_id}", response_model=ModelAliasResponse)
|
||||
async def get_model_alias(
|
||||
alias_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ModelAlias:
|
||||
"""Get a model alias by ID."""
|
||||
result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
|
||||
alias = result.scalar_one_or_none()
|
||||
|
||||
if not alias:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model alias '{alias_id}' not found",
|
||||
)
|
||||
|
||||
return alias
|
||||
|
||||
|
||||
@router.put("/{alias_id}", response_model=ModelAliasResponse)
|
||||
async def update_model_alias(
|
||||
alias_id: str,
|
||||
data: ModelAliasUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ModelAlias:
|
||||
"""Update a model alias."""
|
||||
result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
|
||||
alias = result.scalar_one_or_none()
|
||||
|
||||
if not alias:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model alias '{alias_id}' not found",
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
if "routing_config" in update_data and update_data["routing_config"]:
|
||||
alias.routing_config = json.dumps(update_data.pop("routing_config"))
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(alias, key, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(alias)
|
||||
|
||||
return alias
|
||||
|
||||
|
||||
@router.delete("/{alias_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_model_alias(
|
||||
alias_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> None:
|
||||
"""Delete a model alias."""
|
||||
result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
|
||||
alias = result.scalar_one_or_none()
|
||||
|
||||
if not alias:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model alias '{alias_id}' not found",
|
||||
)
|
||||
|
||||
await db.delete(alias)
|
||||
132
llm-gateway/app/api/admin/projects.py
Normal file
132
llm-gateway/app/api/admin/projects.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""Project management API endpoints."""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.models.project import Project
|
||||
from app.schemas.project import (
|
||||
ProjectCreate,
|
||||
ProjectListResponse,
|
||||
ProjectResponse,
|
||||
ProjectUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["Admin - Projects"])
|
||||
|
||||
|
||||
@router.post("", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_project(
|
||||
data: ProjectCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Project:
|
||||
"""Create a new project."""
|
||||
project = Project(
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
budget_limit=data.budget_limit,
|
||||
budget_period=data.budget_period,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
|
||||
db.add(project)
|
||||
await db.flush()
|
||||
await db.refresh(project)
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("", response_model=ProjectListResponse)
|
||||
async def list_projects(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(ge=1)] = 1,
|
||||
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
enabled: bool | None = None,
|
||||
) -> ProjectListResponse:
|
||||
"""List all projects."""
|
||||
query = select(Project)
|
||||
count_query = select(func.count()).select_from(Project)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.where(Project.enabled == enabled)
|
||||
count_query = count_query.where(Project.enabled == enabled)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size).order_by(Project.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
projects = result.scalars().all()
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return ProjectListResponse(
|
||||
projects=[ProjectResponse.model_validate(p) for p in projects],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ProjectResponse)
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Project:
|
||||
"""Get a project by ID."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project_id}' not found",
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=ProjectResponse)
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
data: ProjectUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Project:
|
||||
"""Update a project."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project_id}' not found",
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(project, key, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(project)
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> None:
|
||||
"""Delete a project."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Project '{project_id}' not found",
|
||||
)
|
||||
|
||||
await db.delete(project)
|
||||
192
llm-gateway/app/api/admin/providers.py
Normal file
192
llm-gateway/app/api/admin/providers.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Provider management API endpoints."""
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.models.provider import Provider
|
||||
from app.schemas.common import PaginationParams
|
||||
from app.schemas.provider import (
|
||||
ProviderCreate,
|
||||
ProviderListResponse,
|
||||
ProviderResponse,
|
||||
ProviderUpdate,
|
||||
)
|
||||
from app.utils.crypto import encrypt_value
|
||||
|
||||
router = APIRouter(prefix="/providers", tags=["Admin - Providers"])
|
||||
|
||||
|
||||
@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_provider(
|
||||
data: ProviderCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Provider:
|
||||
"""Create a new provider."""
|
||||
# Check if provider already exists
|
||||
result = await db.execute(select(Provider).where(Provider.name == data.name))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Provider '{data.name}' already exists",
|
||||
)
|
||||
|
||||
# Encrypt API key
|
||||
encrypted_key = encrypt_value(data.api_key)
|
||||
|
||||
# Create provider
|
||||
provider = Provider(
|
||||
name=data.name,
|
||||
api_base=data.api_base,
|
||||
api_key_encrypted=encrypted_key,
|
||||
api_version=data.api_version,
|
||||
config=json.dumps(data.config) if data.config else None,
|
||||
rpm_limit=data.rpm_limit,
|
||||
tpm_limit=data.tpm_limit,
|
||||
enabled=data.enabled,
|
||||
)
|
||||
|
||||
db.add(provider)
|
||||
await db.flush()
|
||||
await db.refresh(provider)
|
||||
|
||||
# Decrypt config for response
|
||||
if provider.config:
|
||||
provider.config = json.loads(provider.config)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@router.get("", response_model=ProviderListResponse)
|
||||
async def list_providers(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(ge=1)] = 1,
|
||||
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
enabled: bool | None = None,
|
||||
) -> ProviderListResponse:
|
||||
"""List all providers."""
|
||||
# Build query
|
||||
query = select(Provider)
|
||||
count_query = select(func.count()).select_from(Provider)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.where(Provider.enabled == enabled)
|
||||
count_query = count_query.where(Provider.enabled == enabled)
|
||||
|
||||
# Apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size).order_by(Provider.created_at.desc())
|
||||
|
||||
# Execute queries
|
||||
result = await db.execute(query)
|
||||
providers = result.scalars().all()
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Decrypt config for response
|
||||
for p in providers:
|
||||
if p.config:
|
||||
p.config = json.loads(p.config)
|
||||
|
||||
return ProviderListResponse(
|
||||
providers=[
|
||||
ProviderResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
enabled=p.enabled,
|
||||
api_base=p.api_base,
|
||||
api_version=p.api_version,
|
||||
config=json.loads(p.config) if p.config else None,
|
||||
rpm_limit=p.rpm_limit,
|
||||
tpm_limit=p.tpm_limit,
|
||||
health_status=p.health_status,
|
||||
last_health_check=p.last_health_check,
|
||||
created_at=p.created_at,
|
||||
updated_at=p.updated_at,
|
||||
)
|
||||
for p in providers
|
||||
],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{provider_id}", response_model=ProviderResponse)
|
||||
async def get_provider(
|
||||
provider_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Provider:
|
||||
"""Get a provider by ID."""
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_id}' not found",
|
||||
)
|
||||
|
||||
if provider.config:
|
||||
provider.config = json.loads(provider.config)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@router.put("/{provider_id}", response_model=ProviderResponse)
|
||||
async def update_provider(
|
||||
provider_id: str,
|
||||
data: ProviderUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Provider:
|
||||
"""Update a provider."""
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_id}' not found",
|
||||
)
|
||||
|
||||
# Update fields
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
if "api_key" in update_data:
|
||||
provider.api_key_encrypted = encrypt_value(update_data.pop("api_key"))
|
||||
|
||||
if "config" in update_data:
|
||||
provider.config = json.dumps(update_data.pop("config"))
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(provider, key, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(provider)
|
||||
|
||||
if provider.config:
|
||||
provider.config = json.loads(provider.config)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> None:
|
||||
"""Delete a provider."""
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_id}' not found",
|
||||
)
|
||||
|
||||
await db.delete(provider)
|
||||
221
llm-gateway/app/api/admin/usage.py
Normal file
221
llm-gateway/app/api/admin/usage.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""Usage statistics API endpoints."""
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.models.usage import RequestLog, UsageStatsHourly
|
||||
|
||||
router = APIRouter(prefix="/usage", tags=["Admin - Usage"])
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_usage_stats(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
start_date: Annotated[datetime | None, Query()] = None,
|
||||
end_date: Annotated[datetime | None, Query()] = None,
|
||||
group_by: Annotated[str, Query(pattern="^(model|provider|project|key)$")] = "model",
|
||||
project_id: str | None = None,
|
||||
key_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get usage statistics."""
|
||||
# Default to last 7 days
|
||||
if not start_date:
|
||||
start_date = datetime.utcnow() - timedelta(days=7)
|
||||
if not end_date:
|
||||
end_date = datetime.utcnow()
|
||||
|
||||
# Build base query
|
||||
query = select(
|
||||
getattr(UsageStatsHourly, group_by).label("group"),
|
||||
func.sum(UsageStatsHourly.request_count).label("requests"),
|
||||
func.sum(UsageStatsHourly.input_tokens).label("input_tokens"),
|
||||
func.sum(UsageStatsHourly.output_tokens).label("output_tokens"),
|
||||
func.sum(UsageStatsHourly.total_tokens).label("total_tokens"),
|
||||
func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
|
||||
func.avg(UsageStatsHourly.avg_latency_ms).label("avg_latency_ms"),
|
||||
func.sum(UsageStatsHourly.error_count).label("errors"),
|
||||
).where(
|
||||
UsageStatsHourly.timestamp >= start_date,
|
||||
UsageStatsHourly.timestamp <= end_date,
|
||||
)
|
||||
|
||||
if project_id:
|
||||
query = query.where(UsageStatsHourly.project_id == project_id)
|
||||
if key_id:
|
||||
query = query.where(UsageStatsHourly.virtual_key_id == key_id)
|
||||
|
||||
query = query.group_by(getattr(UsageStatsHourly, group_by))
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Calculate totals
|
||||
total_query = select(
|
||||
func.sum(UsageStatsHourly.request_count).label("requests"),
|
||||
func.sum(UsageStatsHourly.input_tokens).label("input_tokens"),
|
||||
func.sum(UsageStatsHourly.output_tokens).label("output_tokens"),
|
||||
func.sum(UsageStatsHourly.total_tokens).label("total_tokens"),
|
||||
func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
|
||||
func.sum(UsageStatsHourly.error_count).label("errors"),
|
||||
).where(
|
||||
UsageStatsHourly.timestamp >= start_date,
|
||||
UsageStatsHourly.timestamp <= end_date,
|
||||
)
|
||||
|
||||
if project_id:
|
||||
total_query = total_query.where(UsageStatsHourly.project_id == project_id)
|
||||
if key_id:
|
||||
total_query = total_query.where(UsageStatsHourly.virtual_key_id == key_id)
|
||||
|
||||
total_result = await db.execute(total_query)
|
||||
total_row = total_result.one()
|
||||
|
||||
return {
|
||||
"period": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat(),
|
||||
},
|
||||
"totals": {
|
||||
"requests": total_row.requests or 0,
|
||||
"input_tokens": total_row.input_tokens or 0,
|
||||
"output_tokens": total_row.output_tokens or 0,
|
||||
"total_tokens": total_row.total_tokens or 0,
|
||||
"cost_usd": float(total_row.cost_usd or 0),
|
||||
"errors": total_row.errors or 0,
|
||||
},
|
||||
"by_" + group_by: [
|
||||
{
|
||||
group_by: row.group,
|
||||
"requests": row.requests or 0,
|
||||
"input_tokens": row.input_tokens or 0,
|
||||
"output_tokens": row.output_tokens or 0,
|
||||
"total_tokens": row.total_tokens or 0,
|
||||
"cost_usd": float(row.cost_usd or 0),
|
||||
"avg_latency_ms": int(row.avg_latency_ms or 0),
|
||||
"errors": row.errors or 0,
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
async def get_usage_logs(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(ge=1)] = 1,
|
||||
page_size: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
project_id: str | None = None,
|
||||
key_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
status_code: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get request logs."""
|
||||
query = select(RequestLog)
|
||||
count_query = select(func.count()).select_from(RequestLog)
|
||||
|
||||
# Apply filters
|
||||
if project_id:
|
||||
query = query.where(RequestLog.project_id == project_id)
|
||||
count_query = count_query.where(RequestLog.project_id == project_id)
|
||||
if key_id:
|
||||
query = query.where(RequestLog.virtual_key_id == key_id)
|
||||
count_query = count_query.where(RequestLog.virtual_key_id == key_id)
|
||||
if provider:
|
||||
query = query.where(RequestLog.provider == provider)
|
||||
count_query = count_query.where(RequestLog.provider == provider)
|
||||
if model:
|
||||
query = query.where(RequestLog.model == model)
|
||||
count_query = count_query.where(RequestLog.model == model)
|
||||
if status_code:
|
||||
query = query.where(RequestLog.status_code == status_code)
|
||||
count_query = count_query.where(RequestLog.status_code == status_code)
|
||||
|
||||
# Apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size).order_by(RequestLog.timestamp.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
logs = result.scalars().all()
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"logs": [
|
||||
{
|
||||
"id": log.id,
|
||||
"timestamp": log.timestamp.isoformat(),
|
||||
"key_id": log.virtual_key_id,
|
||||
"project_id": log.project_id,
|
||||
"provider": log.provider,
|
||||
"model": log.model,
|
||||
"model_alias": log.model_alias,
|
||||
"request_type": log.request_type,
|
||||
"input_tokens": log.input_tokens,
|
||||
"output_tokens": log.output_tokens,
|
||||
"total_tokens": log.total_tokens,
|
||||
"status_code": log.status_code,
|
||||
"latency_ms": log.latency_ms,
|
||||
"finish_reason": log.finish_reason,
|
||||
"cost_usd": float(log.cost_usd),
|
||||
}
|
||||
for log in logs
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/costs")
|
||||
async def get_cost_breakdown(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
start_date: Annotated[datetime | None, Query()] = None,
|
||||
end_date: Annotated[datetime | None, Query()] = None,
|
||||
group_by: Annotated[str, Query(pattern="^(model|provider|project|key)$")] = "provider",
|
||||
) -> dict[str, Any]:
|
||||
"""Get cost breakdown."""
|
||||
# Default to current month
|
||||
if not start_date:
|
||||
start_date = datetime.utcnow().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
if not end_date:
|
||||
end_date = datetime.utcnow()
|
||||
|
||||
query = select(
|
||||
getattr(UsageStatsHourly, group_by).label("group"),
|
||||
func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
|
||||
func.sum(UsageStatsHourly.request_count).label("requests"),
|
||||
func.sum(UsageStatsHourly.total_tokens).label("tokens"),
|
||||
).where(
|
||||
UsageStatsHourly.timestamp >= start_date,
|
||||
UsageStatsHourly.timestamp <= end_date,
|
||||
).group_by(getattr(UsageStatsHourly, group_by))
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
total_cost = sum(float(row.cost_usd or 0) for row in rows)
|
||||
|
||||
return {
|
||||
"period": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat(),
|
||||
},
|
||||
"total_cost_usd": total_cost,
|
||||
"breakdown": [
|
||||
{
|
||||
group_by: row.group,
|
||||
"cost_usd": float(row.cost_usd or 0),
|
||||
"percentage": float(row.cost_usd or 0) / total_cost * 100 if total_cost > 0 else 0,
|
||||
"requests": row.requests or 0,
|
||||
"tokens": row.tokens or 0,
|
||||
}
|
||||
for row in sorted(rows, key=lambda r: r.cost_usd or 0, reverse=True)
|
||||
],
|
||||
}
|
||||
1
llm-gateway/app/api/v1/__init__.py
Normal file
1
llm-gateway/app/api/v1/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# v1 module
|
||||
327
llm-gateway/app/api/v1/chat.py
Normal file
327
llm-gateway/app/api/v1/chat.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Chat Completions API endpoint (OpenAI-compatible)."""
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.adapters import ProviderConfig, create_adapter
|
||||
from app.core.budget import BudgetController, BudgetExceeded
|
||||
from app.core.circuit_breaker import get_circuit_breaker
|
||||
from app.core.fallback import RetryExecutor, RetryConfig, RetryableError, classify_error
|
||||
from app.core.load_balancer import LoadBalancer
|
||||
from app.core.rate_limiter import RateLimiter, RateLimitExceeded
|
||||
from app.core.router import Router, RoutingResult
|
||||
from app.db.database import get_db
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.provider import Provider
|
||||
from app.models.usage import RequestLog
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
)
|
||||
from app.utils.crypto import decrypt_value, verify_api_key
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/v1", tags=["Chat"])
|
||||
|
||||
|
||||
async def authenticate(
|
||||
authorization: str | None = Header(None),
|
||||
x_api_key: str | None = Header(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> APIKey:
|
||||
"""Authenticate request using virtual API key."""
|
||||
# Extract key from header
|
||||
key = None
|
||||
if authorization:
|
||||
if authorization.startswith("Bearer "):
|
||||
key = authorization[7:]
|
||||
elif x_api_key:
|
||||
key = x_api_key
|
||||
|
||||
if not key:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={"error": {"type": "authentication_error", "message": "Missing API key"}},
|
||||
)
|
||||
|
||||
# Find and verify key
|
||||
result = await db.execute(select(APIKey))
|
||||
api_keys = result.scalars().all()
|
||||
|
||||
for api_key in api_keys:
|
||||
if verify_api_key(key, api_key.key_hash):
|
||||
if not api_key.enabled:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": {"type": "permission_error", "message": "API key is disabled"}},
|
||||
)
|
||||
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": {"type": "permission_error", "message": "API key has expired"}},
|
||||
)
|
||||
return api_key
|
||||
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={"error": {"type": "authentication_error", "message": "Invalid API key"}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def chat_completions(
|
||||
request: OpenAIChatCompletionRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
api_key: Annotated[APIKey, Depends(authenticate)],
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Execute a chat completion request."""
|
||||
start_time = time.time()
|
||||
routing_result: RoutingResult | None = None
|
||||
response: OpenAIChatCompletionResponse | None = None
|
||||
error: Exception | None = None
|
||||
|
||||
try:
|
||||
# Initialize services
|
||||
router_service = Router(db)
|
||||
rate_limiter = RateLimiter(db)
|
||||
budget_controller = BudgetController(db)
|
||||
load_balancer = LoadBalancer(db)
|
||||
|
||||
# Check rate limits
|
||||
try:
|
||||
await rate_limiter.check_and_increment(
|
||||
f"key:{api_key.id}",
|
||||
rpm_limit=api_key.rpm_limit,
|
||||
tpm_limit=api_key.tpm_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "rate_limit_error",
|
||||
"message": str(e),
|
||||
"details": {
|
||||
"limit": e.limit,
|
||||
"remaining": e.remaining,
|
||||
"reset_at": e.reset_at.isoformat(),
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Check budget
|
||||
try:
|
||||
await budget_controller.check_budget(api_key.id)
|
||||
except BudgetExceeded as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "budget_exceeded_error",
|
||||
"message": str(e),
|
||||
"details": {
|
||||
"limit": float(e.limit),
|
||||
"current_usage": float(e.current_usage),
|
||||
"scope": e.scope,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Resolve model routing
|
||||
routing_result = await router_service.resolve_model(request.model)
|
||||
|
||||
# Check if model is allowed
|
||||
if api_key.allowed_models:
|
||||
import json
|
||||
allowed = json.loads(api_key.allowed_models)
|
||||
if request.model not in allowed:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "permission_error",
|
||||
"message": f"Model '{request.model}' is not allowed for this API key",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
result = await db.execute(
|
||||
select(Provider).where(Provider.name == routing_result.provider, Provider.enabled == True)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "provider_error",
|
||||
"message": f"Provider '{routing_result.provider}' not found or disabled",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Check circuit breaker
|
||||
circuit_breaker = get_circuit_breaker()
|
||||
if not circuit_breaker.is_available(provider.name):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "service_unavailable",
|
||||
"message": f"Provider '{provider.name}' is currently unavailable",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Create adapter
|
||||
provider_config = ProviderConfig(
|
||||
name=provider.name,
|
||||
api_base=provider.api_base,
|
||||
api_key=decrypt_value(provider.api_key_encrypted),
|
||||
api_version=provider.api_version,
|
||||
config=json.loads(provider.config) if provider.config else None,
|
||||
)
|
||||
|
||||
adapter = create_adapter(provider_config)
|
||||
|
||||
# Update request model to actual model
|
||||
request.model = routing_result.model
|
||||
|
||||
# Execute request with retry
|
||||
retry_executor = RetryExecutor(RetryConfig())
|
||||
|
||||
async def execute():
|
||||
return await adapter.chat_completions(request)
|
||||
|
||||
response = await retry_executor.execute(execute, provider.name, "chat_completions")
|
||||
circuit_breaker.record_success(provider.name)
|
||||
|
||||
# Record usage
|
||||
if response.usage:
|
||||
cost = await _calculate_cost(db, routing_result, response.usage)
|
||||
await budget_controller.record_usage(
|
||||
api_key.id,
|
||||
cost,
|
||||
response.usage.prompt_tokens,
|
||||
response.usage.completion_tokens,
|
||||
)
|
||||
|
||||
# Record tokens for rate limiting
|
||||
await rate_limiter.record_tokens(
|
||||
f"key:{api_key.id}",
|
||||
response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except RetryableError as e:
|
||||
error = e
|
||||
logger.error(
|
||||
"Chat completion failed",
|
||||
model=request.model,
|
||||
provider=routing_result.provider if routing_result else None,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "provider_error",
|
||||
"message": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
error = e
|
||||
logger.exception("Unexpected error in chat completions")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "internal_error",
|
||||
"message": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
finally:
|
||||
# Log request
|
||||
await _log_request(
|
||||
db,
|
||||
api_key.id,
|
||||
routing_result,
|
||||
request,
|
||||
response,
|
||||
error,
|
||||
start_time,
|
||||
)
|
||||
|
||||
|
||||
async def _calculate_cost(
|
||||
db: AsyncSession,
|
||||
routing_result: RoutingResult,
|
||||
usage: Any,
|
||||
) -> Decimal:
|
||||
"""Calculate cost for a request."""
|
||||
from app.models.model_alias import ModelAlias
|
||||
import json
|
||||
|
||||
# Try to find pricing from model alias
|
||||
result = await db.execute(
|
||||
select(ModelAlias).where(ModelAlias.alias == routing_result.model)
|
||||
)
|
||||
alias = result.scalar_one_or_none()
|
||||
|
||||
if alias and alias.input_price_per_1k and alias.output_price_per_1k:
|
||||
input_cost = (usage.prompt_tokens / 1000) * alias.input_price_per_1k
|
||||
output_cost = (usage.completion_tokens / 1000) * alias.output_price_per_1k
|
||||
return Decimal(str(input_cost + output_cost))
|
||||
|
||||
# Default pricing (rough estimate)
|
||||
return Decimal("0.001")
|
||||
|
||||
|
||||
async def _log_request(
|
||||
db: AsyncSession,
|
||||
key_id: str,
|
||||
routing_result: RoutingResult | None,
|
||||
request: OpenAIChatCompletionRequest,
|
||||
response: OpenAIChatCompletionResponse | None,
|
||||
error: Exception | None,
|
||||
start_time: float,
|
||||
) -> None:
|
||||
"""Log request details."""
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
log = RequestLog(
|
||||
virtual_key_id=key_id,
|
||||
project_id=None, # Will be filled from API key
|
||||
provider=routing_result.provider if routing_result else "unknown",
|
||||
model=request.model,
|
||||
model_alias=request.model,
|
||||
request_type="chat",
|
||||
input_tokens=response.usage.prompt_tokens if response and response.usage else 0,
|
||||
output_tokens=response.usage.completion_tokens if response and response.usage else 0,
|
||||
total_tokens=response.usage.total_tokens if response and response.usage else 0,
|
||||
status_code=500 if error else (response and 200 or 500),
|
||||
latency_ms=latency_ms,
|
||||
finish_reason=response.choices[0].finish_reason if response and response.choices else None,
|
||||
cost_usd=Decimal("0.001") if response and response.usage else Decimal("0"),
|
||||
)
|
||||
|
||||
db.add(log)
|
||||
await db.commit()
|
||||
53
llm-gateway/app/api/v1/messages.py
Normal file
53
llm-gateway/app/api/v1/messages.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Anthropic Messages API endpoint."""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.api.v1.chat import authenticate, _calculate_cost, _log_request
|
||||
from app.core.transformer import RequestTransformer
|
||||
from app.db.database import get_db
|
||||
from app.models.api_key import APIKey
|
||||
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
|
||||
from app.schemas.openai import OpenAIChatCompletionRequest
|
||||
|
||||
router = APIRouter(prefix="/v1", tags=["Messages"])
|
||||
|
||||
|
||||
@router.post("/messages")
|
||||
async def messages(
|
||||
request: AnthropicMessagesRequest,
|
||||
db: Annotated[None, Depends(get_db)],
|
||||
api_key: Annotated[APIKey, Depends(authenticate)],
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""
|
||||
Execute an Anthropic Messages API request.
|
||||
|
||||
This endpoint accepts Anthropic-format requests and forwards them
|
||||
to the appropriate provider (Anthropic, OpenAI, etc.).
|
||||
"""
|
||||
from app.api.v1.chat import chat_completions
|
||||
from app.schemas.openai import OpenAIChatCompletionRequest
|
||||
|
||||
# Convert Anthropic request to OpenAI format
|
||||
transformer = RequestTransformer()
|
||||
openai_request = transformer.anthropic_to_openai(request)
|
||||
|
||||
# Use the chat completions endpoint
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Re-import with correct type
|
||||
from app.db.database import get_db as _get_db
|
||||
from typing import AsyncGenerator
|
||||
|
||||
# Get database session
|
||||
async for session in _get_db():
|
||||
# Call chat completions with converted request
|
||||
response = await chat_completions(openai_request, session, api_key)
|
||||
|
||||
# Convert response back to Anthropic format
|
||||
anthropic_response = transformer.openai_response_to_anthropic(response)
|
||||
|
||||
return anthropic_response
|
||||
|
||||
raise HTTPException(status_code=500, detail="Database session error")
|
||||
72
llm-gateway/app/api/v1/responses.py
Normal file
72
llm-gateway/app/api/v1/responses.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""OpenAI Responses API endpoint."""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.api.v1.chat import authenticate
|
||||
from app.db.database import get_db
|
||||
from app.models.api_key import APIKey
|
||||
from app.schemas.openai import OpenAIResponseRequest
|
||||
|
||||
router = APIRouter(prefix="/v1", tags=["Responses"])
|
||||
|
||||
|
||||
@router.post("/responses")
|
||||
async def responses(
|
||||
request: OpenAIResponseRequest,
|
||||
db: Annotated[None, Depends(get_db)],
|
||||
api_key: Annotated[APIKey, Depends(authenticate)],
|
||||
) -> dict:
|
||||
"""
|
||||
Execute an OpenAI Responses API request.
|
||||
|
||||
This is the new OpenAI Responses API format.
|
||||
Converts to chat completions internally.
|
||||
"""
|
||||
from app.schemas.openai import OpenAIChatCompletionRequest, OpenAIChatMessage
|
||||
from app.api.v1.chat import chat_completions
|
||||
from app.db.database import get_db as _get_db
|
||||
|
||||
# Convert Responses API format to Chat Completions format
|
||||
messages = []
|
||||
|
||||
# Add instructions as system message if present
|
||||
if request.instructions:
|
||||
messages.append(OpenAIChatMessage(role="system", content=request.instructions))
|
||||
|
||||
# Add input as user message
|
||||
if isinstance(request.input, str):
|
||||
messages.append(OpenAIChatMessage(role="user", content=request.input))
|
||||
else:
|
||||
# Handle structured input
|
||||
messages.append(OpenAIChatMessage(role="user", content=str(request.input)))
|
||||
|
||||
chat_request = OpenAIChatCompletionRequest(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_output_tokens,
|
||||
temperature=request.temperature,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
|
||||
# Get database session and call chat completions
|
||||
async for session in _get_db():
|
||||
response = await chat_completions(chat_request, session, api_key)
|
||||
|
||||
# Convert to Responses API format
|
||||
return {
|
||||
"id": response.id,
|
||||
"object": "response",
|
||||
"created": response.created,
|
||||
"model": response.model,
|
||||
"output": response.choices[0].message.content if response.choices else "",
|
||||
"usage": {
|
||||
"input_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"output_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0,
|
||||
},
|
||||
}
|
||||
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=500, detail="Database session error")
|
||||
73
llm-gateway/app/config.py
Normal file
73
llm-gateway/app/config.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Application configuration using Pydantic Settings."""
|
||||
import secrets
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Application
|
||||
app_name: str = "LLM Gateway"
|
||||
debug: bool = False
|
||||
log_level: str = "INFO"
|
||||
|
||||
# Server
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
# Database
|
||||
database_url: str = Field(
|
||||
default="sqlite:///data/gateway.db",
|
||||
description="SQLite database URL",
|
||||
)
|
||||
|
||||
# Security
|
||||
master_key: str = Field(
|
||||
default_factory=lambda: secrets.token_hex(32),
|
||||
description="Master key for encrypting provider API keys",
|
||||
)
|
||||
|
||||
# Rate Limiting
|
||||
rate_limit_window_seconds: int = 60
|
||||
global_rpm_limit: int | None = None
|
||||
global_tpm_limit: int | None = None
|
||||
|
||||
# Retry
|
||||
max_retries: int = 3
|
||||
retry_initial_delay: float = 1.0
|
||||
retry_max_delay: float = 30.0
|
||||
retry_exponential_base: float = 2.0
|
||||
|
||||
# Health Check
|
||||
health_check_interval: int = 30
|
||||
health_check_timeout: int = 10
|
||||
|
||||
@field_validator("master_key")
|
||||
@classmethod
|
||||
def validate_master_key(cls, v: str) -> str:
|
||||
"""Ensure master key is at least 32 characters."""
|
||||
if len(v) < 32:
|
||||
raise ValueError("Master key must be at least 32 characters")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def generate_master_key() -> str:
|
||||
"""Generate a secure random master key."""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings()
|
||||
1
llm-gateway/app/core/__init__.py
Normal file
1
llm-gateway/app/core/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# core module
|
||||
160
llm-gateway/app/core/budget.py
Normal file
160
llm-gateway/app/core/budget.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""Budget controller for spending limits."""
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.project import Project
|
||||
from app.models.usage import RequestLog
|
||||
|
||||
|
||||
class BudgetExceeded(Exception):
|
||||
"""Exception raised when budget is exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
budget_type: str,
|
||||
limit: Decimal,
|
||||
current_usage: Decimal,
|
||||
scope: str,
|
||||
):
|
||||
self.budget_type = budget_type
|
||||
self.limit = limit
|
||||
self.current_usage = current_usage
|
||||
self.scope = scope
|
||||
super().__init__(f"Budget exceeded: {scope} {budget_type}")
|
||||
|
||||
|
||||
class BudgetController:
|
||||
"""Controller for budget limits at key and project levels."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def check_budget(
|
||||
self,
|
||||
key_id: str,
|
||||
estimated_cost: Decimal = Decimal("0"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Check budget limits before making a request.
|
||||
|
||||
Args:
|
||||
key_id: The API key ID.
|
||||
estimated_cost: Estimated cost for this request.
|
||||
|
||||
Returns:
|
||||
Dict with budget status.
|
||||
|
||||
Raises:
|
||||
BudgetExceeded: If budget is exceeded.
|
||||
"""
|
||||
# Get API key
|
||||
result = await self.db.execute(select(APIKey).where(APIKey.id == key_id))
|
||||
api_key = result.scalar_one_or_none()
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f"API key '{key_id}' not found")
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Check key-level budget
|
||||
if api_key.budget_limit:
|
||||
period_start = self._get_period_start(api_key.budget_period, now)
|
||||
current_usage = await self._get_usage_for_period(key_id, None, period_start, now)
|
||||
|
||||
if current_usage + estimated_cost > api_key.budget_limit:
|
||||
raise BudgetExceeded(
|
||||
budget_type=api_key.budget_period or "total",
|
||||
limit=api_key.budget_limit,
|
||||
current_usage=current_usage,
|
||||
scope=f"key:{api_key.name}",
|
||||
)
|
||||
|
||||
# Check soft limit (90% of hard limit)
|
||||
soft_limit = api_key.budget_limit * Decimal("0.9")
|
||||
if current_usage + estimated_cost > soft_limit:
|
||||
# Log warning but don't block
|
||||
pass
|
||||
|
||||
# Check project-level budget
|
||||
if api_key.project_id:
|
||||
result = await self.db.execute(
|
||||
select(Project).where(Project.id == api_key.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if project and project.budget_limit:
|
||||
period_start = self._get_period_start(project.budget_period, now)
|
||||
current_usage = await self._get_usage_for_period(
|
||||
None, project.id, period_start, now
|
||||
)
|
||||
|
||||
if current_usage + estimated_cost > project.budget_limit:
|
||||
raise BudgetExceeded(
|
||||
budget_type=project.budget_period or "total",
|
||||
limit=project.budget_limit,
|
||||
current_usage=current_usage,
|
||||
scope=f"project:{project.name}",
|
||||
)
|
||||
|
||||
return {
|
||||
"key_budget_limit": float(api_key.budget_limit) if api_key.budget_limit else None,
|
||||
"key_current_usage": float(api_key.current_usage),
|
||||
"project_budget_limit": None,
|
||||
"project_current_usage": None,
|
||||
}
|
||||
|
||||
async def record_usage(
|
||||
self,
|
||||
key_id: str,
|
||||
cost: Decimal,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
) -> None:
|
||||
"""Record actual usage after request completes."""
|
||||
result = await self.db.execute(select(APIKey).where(APIKey.id == key_id))
|
||||
api_key = result.scalar_one_or_none()
|
||||
|
||||
if api_key:
|
||||
api_key.current_usage += cost
|
||||
api_key.total_requests += 1
|
||||
|
||||
def _get_period_start(self, period: str | None, now: datetime) -> datetime:
|
||||
"""Get the start of the budget period."""
|
||||
if period == "daily":
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
elif period == "weekly":
|
||||
start = now - timedelta(days=now.weekday())
|
||||
return start.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
elif period == "monthly":
|
||||
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
else:
|
||||
# No period means all-time budget
|
||||
return datetime.min
|
||||
|
||||
async def _get_usage_for_period(
|
||||
self,
|
||||
key_id: str | None,
|
||||
project_id: str | None,
|
||||
period_start: datetime,
|
||||
now: datetime,
|
||||
) -> Decimal:
|
||||
"""Get total usage for a period."""
|
||||
query = select(func.sum(RequestLog.cost_usd)).where(
|
||||
RequestLog.timestamp >= period_start,
|
||||
RequestLog.timestamp <= now,
|
||||
)
|
||||
|
||||
if key_id:
|
||||
query = query.where(RequestLog.virtual_key_id == key_id)
|
||||
if project_id:
|
||||
query = query.where(RequestLog.project_id == project_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
total = result.scalar()
|
||||
|
||||
return Decimal(str(total or 0))
|
||||
168
llm-gateway/app/core/circuit_breaker.py
Normal file
168
llm-gateway/app/core/circuit_breaker.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""Circuit breaker for provider fault tolerance."""
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Failing, reject all requests
|
||||
HALF_OPEN = "half_open" # Testing if recovered
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitStats:
|
||||
"""Statistics for circuit breaker."""
|
||||
|
||||
failure_count: int = 0
|
||||
success_count: int = 0
|
||||
last_failure_time: float = 0
|
||||
last_success_time: float = 0
|
||||
state: CircuitState = CircuitState.CLOSED
|
||||
last_state_change: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker for provider fault tolerance.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, requests pass through
|
||||
- OPEN: Too many failures, reject all requests
|
||||
- HALF_OPEN: Testing if provider has recovered
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: float = 30.0,
|
||||
success_threshold: int = 2,
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
failure_threshold: Number of failures before opening circuit.
|
||||
recovery_timeout: Seconds to wait before trying half-open state.
|
||||
success_threshold: Number of successes in half-open to close circuit.
|
||||
"""
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
self.success_threshold = success_threshold
|
||||
self._circuits: dict[str, CircuitStats] = {}
|
||||
|
||||
def get_stats(self, provider: str) -> CircuitStats:
|
||||
"""Get or create circuit stats for a provider."""
|
||||
if provider not in self._circuits:
|
||||
self._circuits[provider] = CircuitStats()
|
||||
return self._circuits[provider]
|
||||
|
||||
def is_available(self, provider: str) -> bool:
|
||||
"""
|
||||
Check if a provider is available (circuit is not open).
|
||||
|
||||
Args:
|
||||
provider: Provider name.
|
||||
|
||||
Returns:
|
||||
True if requests should be allowed.
|
||||
"""
|
||||
stats = self.get_stats(provider)
|
||||
now = time.time()
|
||||
|
||||
if stats.state == CircuitState.CLOSED:
|
||||
return True
|
||||
|
||||
elif stats.state == CircuitState.OPEN:
|
||||
# Check if recovery timeout has passed
|
||||
if now - stats.last_state_change >= self.recovery_timeout:
|
||||
# Transition to half-open
|
||||
stats.state = CircuitState.HALF_OPEN
|
||||
stats.success_count = 0
|
||||
stats.last_state_change = now
|
||||
return True
|
||||
return False
|
||||
|
||||
elif stats.state == CircuitState.HALF_OPEN:
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
def record_success(self, provider: str) -> None:
|
||||
"""
|
||||
Record a successful request.
|
||||
|
||||
Args:
|
||||
provider: Provider name.
|
||||
"""
|
||||
stats = self.get_stats(provider)
|
||||
now = time.time()
|
||||
|
||||
stats.success_count += 1
|
||||
stats.last_success_time = now
|
||||
stats.failure_count = 0 # Reset failure count on success
|
||||
|
||||
if stats.state == CircuitState.HALF_OPEN:
|
||||
if stats.success_count >= self.success_threshold:
|
||||
# Transition to closed
|
||||
stats.state = CircuitState.CLOSED
|
||||
stats.last_state_change = now
|
||||
|
||||
def record_failure(self, provider: str) -> None:
|
||||
"""
|
||||
Record a failed request.
|
||||
|
||||
Args:
|
||||
provider: Provider name.
|
||||
"""
|
||||
stats = self.get_stats(provider)
|
||||
now = time.time()
|
||||
|
||||
stats.failure_count += 1
|
||||
stats.last_failure_time = now
|
||||
|
||||
if stats.state == CircuitState.HALF_OPEN:
|
||||
# Immediately open on failure in half-open
|
||||
stats.state = CircuitState.OPEN
|
||||
stats.last_state_change = now
|
||||
|
||||
elif stats.state == CircuitState.CLOSED:
|
||||
if stats.failure_count >= self.failure_threshold:
|
||||
# Transition to open
|
||||
stats.state = CircuitState.OPEN
|
||||
stats.last_state_change = now
|
||||
|
||||
def get_state(self, provider: str) -> CircuitState:
|
||||
"""Get current circuit state for a provider."""
|
||||
return self.get_stats(provider).state
|
||||
|
||||
def reset(self, provider: str) -> None:
|
||||
"""Reset circuit breaker for a provider."""
|
||||
if provider in self._circuits:
|
||||
del self._circuits[provider]
|
||||
|
||||
def reset_all(self) -> None:
|
||||
"""Reset all circuit breakers."""
|
||||
self._circuits.clear()
|
||||
|
||||
|
||||
# Global circuit breaker instance
|
||||
_circuit_breaker: CircuitBreaker | None = None
|
||||
|
||||
|
||||
def get_circuit_breaker() -> CircuitBreaker:
|
||||
"""Get the global circuit breaker instance."""
|
||||
global _circuit_breaker
|
||||
if _circuit_breaker is None:
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
_circuit_breaker = CircuitBreaker(
|
||||
failure_threshold=5,
|
||||
recovery_timeout=30.0,
|
||||
)
|
||||
return _circuit_breaker
|
||||
223
llm-gateway/app/core/fallback.py
Normal file
223
llm-gateway/app/core/fallback.py
Normal file
@ -0,0 +1,223 @@
|
||||
"""Fallback and retry logic for provider failures."""
|
||||
import asyncio
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.circuit_breaker import CircuitBreaker, get_circuit_breaker
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RetryableError(Exception):
|
||||
"""Error that can be retried."""
|
||||
|
||||
def __init__(self, message: str, error_type: str = "unknown"):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class PermanentError(Exception):
|
||||
"""Error that should not be retried."""
|
||||
|
||||
def __init__(self, message: str, error_type: str = "unknown"):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""Retry configuration."""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 1.0
|
||||
max_delay: float = 30.0
|
||||
exponential_base: float = 2.0
|
||||
jitter: bool = True
|
||||
retryable_errors: set[str] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.retryable_errors is None:
|
||||
self.retryable_errors = {
|
||||
"rate_limit_exceeded",
|
||||
"timeout",
|
||||
"service_unavailable",
|
||||
"internal_error",
|
||||
"connection_error",
|
||||
"overloaded",
|
||||
}
|
||||
|
||||
|
||||
class RetryExecutor:
|
||||
"""Executor for retrying failed operations."""
|
||||
|
||||
def __init__(self, config: RetryConfig | None = None):
|
||||
self.config = config or RetryConfig()
|
||||
self.settings = get_settings()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
operation: Callable[[], T],
|
||||
provider: str,
|
||||
operation_name: str = "operation",
|
||||
) -> T:
|
||||
"""
|
||||
Execute an operation with retry logic.
|
||||
|
||||
Args:
|
||||
operation: Async function to execute.
|
||||
provider: Provider name for circuit breaker.
|
||||
operation_name: Name for logging.
|
||||
|
||||
Returns:
|
||||
Result of the operation.
|
||||
|
||||
Raises:
|
||||
PermanentError: If operation fails permanently.
|
||||
RetryableError: If all retries are exhausted.
|
||||
"""
|
||||
circuit_breaker = get_circuit_breaker()
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(self.config.max_retries + 1):
|
||||
# Check circuit breaker
|
||||
if not circuit_breaker.is_available(provider):
|
||||
raise RetryableError(
|
||||
f"Circuit breaker open for provider '{provider}'",
|
||||
error_type="circuit_open",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await operation()
|
||||
circuit_breaker.record_success(provider)
|
||||
return result
|
||||
|
||||
except PermanentError as e:
|
||||
logger.warning(
|
||||
"Permanent error in %s",
|
||||
operation_name,
|
||||
provider=provider,
|
||||
attempt=attempt,
|
||||
error=e.message,
|
||||
)
|
||||
circuit_breaker.record_failure(provider)
|
||||
raise
|
||||
|
||||
except RetryableError as e:
|
||||
last_error = e
|
||||
circuit_breaker.record_failure(provider)
|
||||
|
||||
if attempt < self.config.max_retries:
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"Retryable error in %s, retrying in %.2fs",
|
||||
operation_name,
|
||||
delay,
|
||||
provider=provider,
|
||||
attempt=attempt,
|
||||
error=e.message,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.error(
|
||||
"All retries exhausted for %s",
|
||||
operation_name,
|
||||
provider=provider,
|
||||
attempts=attempt + 1,
|
||||
error=e.message,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Unknown error - treat as retryable
|
||||
last_error = RetryableError(str(e), error_type="unknown")
|
||||
circuit_breaker.record_failure(provider)
|
||||
|
||||
if attempt < self.config.max_retries:
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"Unknown error in %s, retrying in %.2fs",
|
||||
operation_name,
|
||||
delay,
|
||||
provider=provider,
|
||||
attempt=attempt,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All retries exhausted
|
||||
if last_error:
|
||||
raise last_error
|
||||
raise RetryableError("All retries exhausted", error_type="max_retries")
|
||||
|
||||
def _calculate_delay(self, attempt: int) -> float:
|
||||
"""Calculate delay for a retry attempt with exponential backoff and jitter."""
|
||||
delay = self.config.initial_delay * (self.config.exponential_base**attempt)
|
||||
delay = min(delay, self.config.max_delay)
|
||||
|
||||
if self.config.jitter:
|
||||
# Add jitter (0.5 to 1.5 times the delay)
|
||||
jitter = random.uniform(0.5, 1.5)
|
||||
delay *= jitter
|
||||
|
||||
return delay
|
||||
|
||||
|
||||
def classify_error(error: Exception) -> Exception:
|
||||
"""
|
||||
Classify an error as retryable or permanent.
|
||||
|
||||
Args:
|
||||
error: The error to classify.
|
||||
|
||||
Returns:
|
||||
RetryableError or PermanentError.
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
|
||||
# Retryable errors
|
||||
retryable_patterns = [
|
||||
"rate limit",
|
||||
"429",
|
||||
"timeout",
|
||||
"503",
|
||||
"502",
|
||||
"service unavailable",
|
||||
"internal error",
|
||||
"500",
|
||||
"connection",
|
||||
"overloaded",
|
||||
"capacity",
|
||||
]
|
||||
|
||||
for pattern in retryable_patterns:
|
||||
if pattern in error_str:
|
||||
return RetryableError(str(error), error_type=pattern.replace(" ", "_"))
|
||||
|
||||
# Permanent errors
|
||||
permanent_patterns = [
|
||||
"401",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"403",
|
||||
"forbidden",
|
||||
"400",
|
||||
"invalid request",
|
||||
"404",
|
||||
"not found",
|
||||
"model not found",
|
||||
]
|
||||
|
||||
for pattern in permanent_patterns:
|
||||
if pattern in error_str:
|
||||
return PermanentError(str(error), error_type=pattern.replace(" ", "_"))
|
||||
|
||||
# Default to retryable
|
||||
return RetryableError(str(error), error_type="unknown")
|
||||
137
llm-gateway/app/core/load_balancer.py
Normal file
137
llm-gateway/app/core/load_balancer.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""Load balancer for distributing requests across providers."""
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderInfo:
|
||||
"""Information about a provider for load balancing."""
|
||||
|
||||
name: str
|
||||
weight: int
|
||||
health_status: str
|
||||
current_requests: int = 0
|
||||
|
||||
|
||||
class LoadBalancer:
|
||||
"""Load balancer for distributing requests across providers."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def select_provider(
|
||||
self,
|
||||
providers: list[dict[str, Any]],
|
||||
strategy: str = "weighted_round_robin",
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Select a provider based on the specified strategy.
|
||||
|
||||
Args:
|
||||
providers: List of provider configs with weights.
|
||||
strategy: Load balancing strategy.
|
||||
|
||||
Returns:
|
||||
Selected provider config, or None if none available.
|
||||
"""
|
||||
if not providers:
|
||||
return None
|
||||
|
||||
if len(providers) == 1:
|
||||
return providers[0]
|
||||
|
||||
# Filter healthy providers
|
||||
healthy_providers = await self._filter_healthy_providers(providers)
|
||||
|
||||
if not healthy_providers:
|
||||
return providers[0] # Fallback to first if all unhealthy
|
||||
|
||||
if strategy == "weighted_round_robin":
|
||||
return self._weighted_round_robin(healthy_providers)
|
||||
elif strategy == "random":
|
||||
return self._random_select(healthy_providers)
|
||||
elif strategy == "least_connections":
|
||||
return await self._least_connections(healthy_providers)
|
||||
else:
|
||||
return self._weighted_round_robin(healthy_providers)
|
||||
|
||||
def _weighted_round_robin(
|
||||
self, providers: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Select provider using weighted round robin."""
|
||||
total_weight = sum(p.get("weight", 1) for p in providers)
|
||||
r = random.uniform(0, total_weight)
|
||||
|
||||
cumulative = 0
|
||||
for provider in providers:
|
||||
cumulative += provider.get("weight", 1)
|
||||
if r <= cumulative:
|
||||
return provider
|
||||
|
||||
return providers[0]
|
||||
|
||||
def _random_select(self, providers: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Select provider randomly."""
|
||||
return random.choice(providers)
|
||||
|
||||
async def _least_connections(
|
||||
self, providers: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Select provider with least connections (placeholder)."""
|
||||
# In a real implementation, this would track active connections
|
||||
# For now, use random as a fallback
|
||||
return self._random_select(providers)
|
||||
|
||||
async def _filter_healthy_providers(
|
||||
self, providers: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Filter out unhealthy providers."""
|
||||
healthy = []
|
||||
|
||||
for p in providers:
|
||||
provider_name = p.get("provider")
|
||||
if not provider_name:
|
||||
continue
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Provider).where(
|
||||
Provider.name == provider_name,
|
||||
Provider.enabled == True,
|
||||
Provider.health_status == "healthy",
|
||||
)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if provider:
|
||||
healthy.append(p)
|
||||
|
||||
return healthy
|
||||
|
||||
async def get_provider_health(
|
||||
self, provider_name: str
|
||||
) -> dict[str, Any]:
|
||||
"""Get health status for a provider."""
|
||||
result = await self.db.execute(
|
||||
select(Provider).where(Provider.name == provider_name)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
return {
|
||||
"name": provider_name,
|
||||
"status": "not_found",
|
||||
"available": False,
|
||||
}
|
||||
|
||||
return {
|
||||
"name": provider.name,
|
||||
"status": provider.health_status,
|
||||
"available": provider.enabled and provider.health_status == "healthy",
|
||||
"last_check": provider.last_health_check.isoformat() if provider.last_health_check else None,
|
||||
}
|
||||
210
llm-gateway/app/core/rate_limiter.py
Normal file
210
llm-gateway/app/core/rate_limiter.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Rate limiter for RPM/TPM limits."""
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import get_settings
|
||||
from app.models.usage import RateLimitCounter
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Exception raised when rate limit is exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit_type: str,
|
||||
limit: int,
|
||||
remaining: int,
|
||||
reset_at: datetime,
|
||||
):
|
||||
self.limit_type = limit_type
|
||||
self.limit = limit
|
||||
self.remaining = remaining
|
||||
self.reset_at = reset_at
|
||||
super().__init__(f"Rate limit exceeded: {limit_type}")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter for RPM and TPM limits."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.settings = get_settings()
|
||||
self.window_seconds = self.settings.rate_limit_window_seconds
|
||||
|
||||
async def check_and_increment(
|
||||
self,
|
||||
key: str,
|
||||
rpm_limit: int | None = None,
|
||||
tpm_limit: int | None = None,
|
||||
estimated_tokens: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Check rate limits and increment counters.
|
||||
|
||||
Args:
|
||||
key: Unique key for rate limiting (e.g., "key:{key_id}:rpm").
|
||||
rpm_limit: Requests per minute limit.
|
||||
tpm_limit: Tokens per minute limit.
|
||||
estimated_tokens: Estimated tokens for this request.
|
||||
|
||||
Returns:
|
||||
Dict with remaining limits.
|
||||
|
||||
Raises:
|
||||
RateLimitExceeded: If rate limit is exceeded.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
window_start = now - timedelta(seconds=self.window_seconds)
|
||||
|
||||
# Check global limits
|
||||
if self.settings.global_rpm_limit:
|
||||
await self._check_limit(
|
||||
"global:rpm",
|
||||
self.settings.global_rpm_limit,
|
||||
window_start,
|
||||
now,
|
||||
1,
|
||||
"request",
|
||||
)
|
||||
|
||||
if self.settings.global_tpm_limit and estimated_tokens > 0:
|
||||
await self._check_limit(
|
||||
"global:tpm",
|
||||
self.settings.global_tpm_limit,
|
||||
window_start,
|
||||
now,
|
||||
estimated_tokens,
|
||||
"token",
|
||||
)
|
||||
|
||||
# Check specific limits
|
||||
if rpm_limit:
|
||||
await self._check_limit(
|
||||
f"{key}:rpm",
|
||||
rpm_limit,
|
||||
window_start,
|
||||
now,
|
||||
1,
|
||||
"request",
|
||||
)
|
||||
|
||||
if tpm_limit and estimated_tokens > 0:
|
||||
await self._check_limit(
|
||||
f"{key}:tpm",
|
||||
tpm_limit,
|
||||
window_start,
|
||||
now,
|
||||
estimated_tokens,
|
||||
"token",
|
||||
)
|
||||
|
||||
# Get remaining limits
|
||||
remaining = await self._get_remaining(key, rpm_limit, tpm_limit, window_start, now)
|
||||
|
||||
return remaining
|
||||
|
||||
async def _check_limit(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_start: datetime,
|
||||
now: datetime,
|
||||
increment: int,
|
||||
counter_type: str,
|
||||
) -> None:
|
||||
"""Check if a limit would be exceeded and increment counter."""
|
||||
# Get current counter
|
||||
result = await self.db.execute(
|
||||
select(RateLimitCounter).where(RateLimitCounter.key == key)
|
||||
)
|
||||
counter = result.scalar_one_or_none()
|
||||
|
||||
if counter:
|
||||
# Check if window has expired
|
||||
if counter.window_start < window_start:
|
||||
# Reset counter for new window
|
||||
counter.request_count = 0
|
||||
counter.token_count = 0
|
||||
counter.window_start = now
|
||||
|
||||
# Check limit
|
||||
current = (
|
||||
counter.request_count if counter_type == "request" else counter.token_count
|
||||
)
|
||||
if current >= limit:
|
||||
reset_at = counter.window_start + timedelta(seconds=counter.window_duration)
|
||||
raise RateLimitExceeded(
|
||||
limit_type=f"{counter_type}_per_{self.window_seconds}s",
|
||||
limit=limit,
|
||||
remaining=0,
|
||||
reset_at=reset_at,
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
if counter_type == "request":
|
||||
counter.request_count += increment
|
||||
else:
|
||||
counter.token_count += increment
|
||||
else:
|
||||
# Create new counter
|
||||
counter = RateLimitCounter(
|
||||
key=key,
|
||||
request_count=1 if counter_type == "request" else 0,
|
||||
token_count=increment if counter_type == "token" else 0,
|
||||
window_start=now,
|
||||
window_duration=self.window_seconds,
|
||||
)
|
||||
self.db.add(counter)
|
||||
|
||||
async def _get_remaining(
|
||||
self,
|
||||
key: str,
|
||||
rpm_limit: int | None,
|
||||
tpm_limit: int | None,
|
||||
window_start: datetime,
|
||||
now: datetime,
|
||||
) -> dict[str, Any]:
|
||||
"""Get remaining limits for a key."""
|
||||
remaining = {}
|
||||
|
||||
if rpm_limit:
|
||||
result = await self.db.execute(
|
||||
select(RateLimitCounter).where(RateLimitCounter.key == f"{key}:rpm")
|
||||
)
|
||||
counter = result.scalar_one_or_none()
|
||||
remaining["rpm"] = rpm_limit - (counter.request_count if counter else 0)
|
||||
|
||||
if tpm_limit:
|
||||
result = await self.db.execute(
|
||||
select(RateLimitCounter).where(RateLimitCounter.key == f"{key}:tpm")
|
||||
)
|
||||
counter = result.scalar_one_or_none()
|
||||
remaining["tpm"] = tpm_limit - (counter.token_count if counter else 0)
|
||||
|
||||
remaining["reset_at"] = (now + timedelta(seconds=self.window_seconds)).isoformat()
|
||||
|
||||
return remaining
|
||||
|
||||
async def record_tokens(self, key: str, tokens: int) -> None:
|
||||
"""Record actual token usage after request completes."""
|
||||
tpm_key = f"{key}:tpm"
|
||||
result = await self.db.execute(
|
||||
select(RateLimitCounter).where(RateLimitCounter.key == tpm_key)
|
||||
)
|
||||
counter = result.scalar_one_or_none()
|
||||
|
||||
if counter:
|
||||
counter.token_count += tokens
|
||||
else:
|
||||
counter = RateLimitCounter(
|
||||
key=tpm_key,
|
||||
request_count=0,
|
||||
token_count=tokens,
|
||||
window_start=datetime.utcnow(),
|
||||
window_duration=self.window_seconds,
|
||||
)
|
||||
self.db.add(counter)
|
||||
173
llm-gateway/app/core/router.py
Normal file
173
llm-gateway/app/core/router.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Router for model alias resolution and routing."""
|
||||
import json
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.model_alias import ModelAlias
|
||||
from app.models.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutingResult:
|
||||
"""Result of routing decision."""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
provider_config: dict[str, Any] | None = None
|
||||
fallback_chain: list[dict[str, str]] | None = None
|
||||
|
||||
|
||||
class Router:
|
||||
"""Router for resolving model aliases and making routing decisions."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def resolve_model(self, model_alias: str) -> RoutingResult:
|
||||
"""
|
||||
Resolve a model alias to a provider and model.
|
||||
|
||||
Args:
|
||||
model_alias: The model alias to resolve.
|
||||
|
||||
Returns:
|
||||
RoutingResult with provider, model, and optional fallback chain.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model alias is not found.
|
||||
"""
|
||||
# Look up the model alias
|
||||
result = await self.db.execute(
|
||||
select(ModelAlias).where(ModelAlias.alias == model_alias, ModelAlias.enabled == True)
|
||||
)
|
||||
alias = result.scalar_one_or_none()
|
||||
|
||||
if not alias:
|
||||
# If no alias found, treat as direct provider/model reference
|
||||
# Format: "provider/model" (e.g., "openai/gpt-4")
|
||||
if "/" in model_alias:
|
||||
provider, model = model_alias.split("/", 1)
|
||||
return RoutingResult(provider=provider, model=model)
|
||||
raise ValueError(f"Model alias '{model_alias}' not found")
|
||||
|
||||
# Parse routing config
|
||||
routing_config = json.loads(alias.routing_config) if alias.routing_config else None
|
||||
|
||||
if alias.routing_type == "simple":
|
||||
return RoutingResult(
|
||||
provider=alias.provider,
|
||||
model=alias.model,
|
||||
)
|
||||
|
||||
elif alias.routing_type == "load_balance":
|
||||
# Weighted random selection
|
||||
providers = routing_config.get("providers", []) if routing_config else []
|
||||
if not providers:
|
||||
return RoutingResult(provider=alias.provider, model=alias.model)
|
||||
|
||||
# Filter healthy providers
|
||||
healthy_providers = await self._filter_healthy_providers(providers)
|
||||
|
||||
if not healthy_providers:
|
||||
# Fallback to default if all unhealthy
|
||||
return RoutingResult(provider=alias.provider, model=alias.model)
|
||||
|
||||
# Weighted random selection
|
||||
total_weight = sum(p.get("weight", 1) for p in healthy_providers)
|
||||
r = random.uniform(0, total_weight)
|
||||
|
||||
cumulative = 0
|
||||
for p in healthy_providers:
|
||||
cumulative += p.get("weight", 1)
|
||||
if r <= cumulative:
|
||||
return RoutingResult(
|
||||
provider=p["provider"],
|
||||
model=p.get("model", alias.model),
|
||||
)
|
||||
|
||||
# Fallback
|
||||
return RoutingResult(
|
||||
provider=healthy_providers[0]["provider"],
|
||||
model=healthy_providers[0].get("model", alias.model),
|
||||
)
|
||||
|
||||
elif alias.routing_type == "fallback":
|
||||
# Return primary with fallback chain
|
||||
primary = (
|
||||
routing_config.get("primary", {}) if routing_config else {}
|
||||
)
|
||||
fallback = routing_config.get("fallback", []) if routing_config else []
|
||||
|
||||
return RoutingResult(
|
||||
provider=primary.get("provider", alias.provider),
|
||||
model=primary.get("model", alias.model),
|
||||
fallback_chain=fallback,
|
||||
)
|
||||
|
||||
# Default to simple routing
|
||||
return RoutingResult(provider=alias.provider, model=alias.model)
|
||||
|
||||
async def _filter_healthy_providers(
|
||||
self, providers: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Filter out unhealthy providers."""
|
||||
healthy = []
|
||||
for p in providers:
|
||||
provider_name = p.get("provider")
|
||||
if not provider_name:
|
||||
continue
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Provider).where(
|
||||
Provider.name == provider_name,
|
||||
Provider.enabled == True,
|
||||
Provider.health_status == "healthy",
|
||||
)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if provider:
|
||||
healthy.append(p)
|
||||
|
||||
return healthy
|
||||
|
||||
async def get_fallback_provider(
|
||||
self, failed_provider: str, fallback_chain: list[dict[str, str]] | None
|
||||
) -> RoutingResult | None:
|
||||
"""
|
||||
Get the next fallback provider after a failure.
|
||||
|
||||
Args:
|
||||
failed_provider: The provider that failed.
|
||||
fallback_chain: List of fallback providers.
|
||||
|
||||
Returns:
|
||||
RoutingResult for the next provider, or None if no more fallbacks.
|
||||
"""
|
||||
if not fallback_chain:
|
||||
return None
|
||||
|
||||
# Find the failed provider in the chain and return the next one
|
||||
for i, fallback in enumerate(fallback_chain):
|
||||
if fallback.get("provider") == failed_provider and i + 1 < len(fallback_chain):
|
||||
next_fallback = fallback_chain[i + 1]
|
||||
return RoutingResult(
|
||||
provider=next_fallback.get("provider", ""),
|
||||
model=next_fallback.get("model", ""),
|
||||
fallback_chain=fallback_chain[i + 1 :],
|
||||
)
|
||||
|
||||
# If failed provider not in chain or at the end, try from the beginning
|
||||
for fallback in fallback_chain:
|
||||
if fallback.get("provider") != failed_provider:
|
||||
return RoutingResult(
|
||||
provider=fallback.get("provider", ""),
|
||||
model=fallback.get("model", ""),
|
||||
fallback_chain=fallback_chain,
|
||||
)
|
||||
|
||||
return None
|
||||
228
llm-gateway/app/core/transformer.py
Normal file
228
llm-gateway/app/core/transformer.py
Normal file
@ -0,0 +1,228 @@
|
||||
"""Request transformer for converting between API formats."""
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.anthropic import (
|
||||
AnthropicContentBlock,
|
||||
AnthropicMessage,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicTextBlock,
|
||||
AnthropicUsage,
|
||||
)
|
||||
from app.schemas.openai import (
|
||||
OpenAIChatCompletionChoice,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
OpenAIChatMessage,
|
||||
OpenAIUsage,
|
||||
)
|
||||
|
||||
|
||||
class RequestTransformer:
|
||||
"""Transform requests between different API formats."""
|
||||
|
||||
@staticmethod
|
||||
def openai_to_anthropic(
|
||||
request: OpenAIChatCompletionRequest,
|
||||
) -> AnthropicMessagesRequest:
|
||||
"""Convert OpenAI Chat Completions request to Anthropic Messages request."""
|
||||
messages = []
|
||||
system = None
|
||||
|
||||
for msg in request.messages:
|
||||
if msg.role == "system":
|
||||
# Extract system message
|
||||
system = msg.content if isinstance(msg.content, str) else None
|
||||
elif msg.role in ("user", "assistant"):
|
||||
# Convert content to Anthropic format
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
elif isinstance(msg.content, list):
|
||||
# Handle multi-modal content
|
||||
content = []
|
||||
for part in msg.content:
|
||||
if part.get("type") == "text":
|
||||
content.append(
|
||||
AnthropicContentBlock(type="text", text=part.get("text", ""))
|
||||
)
|
||||
elif part.get("type") == "image_url":
|
||||
# Convert OpenAI image format to Anthropic
|
||||
content.append(
|
||||
AnthropicContentBlock(
|
||||
type="image",
|
||||
source={
|
||||
"type": "url",
|
||||
"url": part.get("image_url", {}).get("url", ""),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
content.append(
|
||||
AnthropicContentBlock(type="text", text=str(part))
|
||||
)
|
||||
else:
|
||||
content = str(msg.content) if msg.content else ""
|
||||
|
||||
messages.append(
|
||||
AnthropicMessage(
|
||||
role=msg.role, # type: ignore
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
# Handle max_tokens
|
||||
max_tokens = request.max_tokens or 4096
|
||||
|
||||
return AnthropicMessagesRequest(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
system=system,
|
||||
max_tokens=max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop_sequences=[request.stop] if isinstance(request.stop, str) else request.stop,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def anthropic_to_openai(
|
||||
request: AnthropicMessagesRequest,
|
||||
model: str | None = None,
|
||||
) -> OpenAIChatCompletionRequest:
|
||||
"""Convert Anthropic Messages request to OpenAI Chat Completions request."""
|
||||
messages = []
|
||||
|
||||
# Add system message if present
|
||||
if request.system:
|
||||
messages.append(
|
||||
OpenAIChatMessage(
|
||||
role="system",
|
||||
content=request.system,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert messages
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
elif isinstance(msg.content, list):
|
||||
# Convert Anthropic content blocks to OpenAI format
|
||||
content = []
|
||||
for block in msg.content:
|
||||
if hasattr(block, "type") and block.type == "text":
|
||||
content.append({"type": "text", "text": block.text or ""})
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
content.append({"type": "text", "text": block.get("text", "")})
|
||||
else:
|
||||
content.append({"type": "text", "text": str(block)})
|
||||
else:
|
||||
content = str(msg.content)
|
||||
|
||||
messages.append(
|
||||
OpenAIChatMessage(
|
||||
role=msg.role, # type: ignore
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
return OpenAIChatCompletionRequest(
|
||||
model=model or request.model,
|
||||
messages=messages,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def anthropic_response_to_openai(
|
||||
response: AnthropicMessagesResponse,
|
||||
model: str,
|
||||
) -> OpenAIChatCompletionResponse:
|
||||
"""Convert Anthropic Messages response to OpenAI Chat Completions response."""
|
||||
# Extract text content
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if hasattr(block, "type") and block.type == "text":
|
||||
content += block.text or ""
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
content += block.get("text", "")
|
||||
|
||||
# Map stop reasons
|
||||
finish_reason_map = {
|
||||
"end_turn": "stop",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
}
|
||||
finish_reason = finish_reason_map.get(response.stop_reason or "", "stop")
|
||||
|
||||
return OpenAIChatCompletionResponse(
|
||||
id=response.id,
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
OpenAIChatCompletionChoice(
|
||||
index=0,
|
||||
message=OpenAIChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=OpenAIUsage(
|
||||
prompt_tokens=response.usage.input_tokens,
|
||||
completion_tokens=response.usage.output_tokens,
|
||||
total_tokens=response.usage.input_tokens + response.usage.output_tokens,
|
||||
)
|
||||
if response.usage
|
||||
else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def openai_response_to_anthropic(
|
||||
response: OpenAIChatCompletionResponse,
|
||||
) -> AnthropicMessagesResponse:
|
||||
"""Convert OpenAI Chat Completions response to Anthropic Messages response."""
|
||||
# Extract content from choices
|
||||
content = []
|
||||
for choice in response.choices:
|
||||
if choice.message.content:
|
||||
content.append(
|
||||
AnthropicTextBlock(
|
||||
type="text",
|
||||
text=choice.message.content,
|
||||
)
|
||||
)
|
||||
|
||||
# Map finish reasons
|
||||
stop_reason_map = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
}
|
||||
stop_reason = stop_reason_map.get(
|
||||
response.choices[0].finish_reason or "", "end_turn"
|
||||
)
|
||||
|
||||
return AnthropicMessagesResponse(
|
||||
id=response.id or f"msg_{uuid.uuid4().hex[:24]}",
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=content,
|
||||
model=response.model,
|
||||
stop_reason=stop_reason,
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
output_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
),
|
||||
)
|
||||
1
llm-gateway/app/db/__init__.py
Normal file
1
llm-gateway/app/db/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# db module
|
||||
70
llm-gateway/app/db/database.py
Normal file
70
llm-gateway/app/db/database.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Database connection and session management."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Create async engine
|
||||
def get_engine():
|
||||
"""Create database engine based on settings."""
|
||||
settings = get_settings()
|
||||
db_url = settings.database_url
|
||||
|
||||
# Ensure data directory exists for SQLite
|
||||
if db_url.startswith("sqlite:///"):
|
||||
db_path = db_url.replace("sqlite:///", "")
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
# Use aiosqlite for async SQLite
|
||||
db_url = db_url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
|
||||
return create_async_engine(db_url, echo=settings.debug)
|
||||
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
# Session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database tables."""
|
||||
# Import all models to ensure they are registered
|
||||
from app.models import ( # noqa: F401
|
||||
provider,
|
||||
api_key,
|
||||
project,
|
||||
model_alias,
|
||||
usage,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def get_db():
|
||||
"""Dependency to get database session."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
61
llm-gateway/app/main.py
Normal file
61
llm-gateway/app/main.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""FastAPI application entry point."""
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.admin import health, keys, models, projects, providers, usage
|
||||
from app.api.v1 import chat, messages, responses
|
||||
from app.config import get_settings
|
||||
from app.db.database import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan handler."""
|
||||
# Startup - skip if testing
|
||||
if not os.environ.get("TESTING"):
|
||||
await init_db()
|
||||
yield
|
||||
# Shutdown
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
description="A unified LLM Gateway supporting multiple providers and API formats",
|
||||
version="0.1.0",
|
||||
debug=settings.debug,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Admin API routers
|
||||
app.include_router(providers.router, prefix="/admin")
|
||||
app.include_router(projects.router, prefix="/admin")
|
||||
app.include_router(keys.router, prefix="/admin")
|
||||
app.include_router(models.router, prefix="/admin")
|
||||
app.include_router(usage.router, prefix="/admin")
|
||||
app.include_router(health.router)
|
||||
|
||||
# LLM API routers
|
||||
app.include_router(chat.router)
|
||||
app.include_router(messages.router)
|
||||
app.include_router(responses.router)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
1
llm-gateway/app/middleware/__init__.py
Normal file
1
llm-gateway/app/middleware/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# middleware module
|
||||
17
llm-gateway/app/models/__init__.py
Normal file
17
llm-gateway/app/models/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Models package."""
|
||||
from app.models.provider import Provider
|
||||
from app.models.project import Project
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.model_alias import ModelAlias
|
||||
from app.models.usage import RequestLog, UsageStatsHourly, AuditLog, RateLimitCounter
|
||||
|
||||
__all__ = [
|
||||
"Provider",
|
||||
"Project",
|
||||
"APIKey",
|
||||
"ModelAlias",
|
||||
"RequestLog",
|
||||
"UsageStatsHourly",
|
||||
"AuditLog",
|
||||
"RateLimitCounter",
|
||||
]
|
||||
48
llm-gateway/app/models/api_key.py
Normal file
48
llm-gateway/app/models/api_key.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""API Key model."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
"""Virtual API Key for authentication and usage tracking."""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
key_hash: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
|
||||
key_prefix: Mapped[str] = mapped_column(String(20), nullable=False) # e.g., "sk-proj_abc..."
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
|
||||
# Project relationship
|
||||
project_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("projects.id"))
|
||||
project: Mapped["Project"] = relationship("Project", backref="api_keys")
|
||||
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
||||
# Rate Limits
|
||||
rpm_limit: Mapped[int | None] = mapped_column(Integer)
|
||||
tpm_limit: Mapped[int | None] = mapped_column(Integer)
|
||||
|
||||
# Budget
|
||||
budget_limit: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
budget_period: Mapped[str | None] = mapped_column(String(20)) # daily, weekly, monthly
|
||||
|
||||
# Permissions
|
||||
allowed_models: Mapped[str | None] = mapped_column(Text) # JSON array
|
||||
|
||||
# Usage Stats
|
||||
current_usage: Mapped[Decimal] = mapped_column(Numeric(10, 2), default=Decimal("0"))
|
||||
total_requests: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
36
llm-gateway/app/models/model_alias.py
Normal file
36
llm-gateway/app/models/model_alias.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Model alias model."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Numeric, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class ModelAlias(Base):
|
||||
"""Model alias for routing and cost tracking."""
|
||||
|
||||
__tablename__ = "model_aliases"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
alias: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
model: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# Routing: simple, load_balance, fallback
|
||||
routing_type: Mapped[str] = mapped_column(String(20), default="simple")
|
||||
routing_config: Mapped[str | None] = mapped_column(Text) # JSON
|
||||
|
||||
# Pricing (per 1K tokens)
|
||||
input_price_per_1k: Mapped[Decimal | None] = mapped_column(Numeric(10, 6))
|
||||
output_price_per_1k: Mapped[Decimal | None] = mapped_column(Numeric(10, 6))
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
31
llm-gateway/app/models/project.py
Normal file
31
llm-gateway/app/models/project.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""Project model."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Numeric, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Project for organizing API keys and tracking usage."""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
description: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Budget
|
||||
budget_limit: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
|
||||
budget_period: Mapped[str | None] = mapped_column(String(20)) # daily, weekly, monthly
|
||||
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
40
llm-gateway/app/models/provider.py
Normal file
40
llm-gateway/app/models/provider.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""Provider model."""
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""LLM Provider configuration."""
|
||||
|
||||
__tablename__ = "providers"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# API Configuration
|
||||
api_base: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
api_key_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
api_version: Mapped[str | None] = mapped_column(String(50))
|
||||
|
||||
# Provider-specific config (JSON)
|
||||
config: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Rate Limits
|
||||
rpm_limit: Mapped[int | None] = mapped_column(Integer)
|
||||
tpm_limit: Mapped[int | None] = mapped_column(Integer)
|
||||
|
||||
# Health Status
|
||||
health_status: Mapped[str] = mapped_column(String(20), default="healthy")
|
||||
last_health_check: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
|
||||
)
|
||||
110
llm-gateway/app/models/usage.py
Normal file
110
llm-gateway/app/models/usage.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Usage tracking models."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, Numeric, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class RequestLog(Base):
|
||||
"""Request log for tracking individual API calls."""
|
||||
|
||||
__tablename__ = "request_logs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True)
|
||||
|
||||
# Keys and Projects
|
||||
virtual_key_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("api_keys.id"), index=True
|
||||
)
|
||||
project_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("projects.id"), index=True
|
||||
)
|
||||
|
||||
# Request details
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
model: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
model_alias: Mapped[str | None] = mapped_column(String(100))
|
||||
request_type: Mapped[str] = mapped_column(String(20)) # chat, completion, embedding
|
||||
|
||||
# Tokens
|
||||
input_tokens: Mapped[int] = mapped_column(Integer)
|
||||
output_tokens: Mapped[int] = mapped_column(Integer)
|
||||
total_tokens: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
# Response
|
||||
status_code: Mapped[int] = mapped_column(Integer)
|
||||
latency_ms: Mapped[int] = mapped_column(Integer)
|
||||
finish_reason: Mapped[str | None] = mapped_column(String(50))
|
||||
|
||||
# Cost
|
||||
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6))
|
||||
|
||||
# Request metadata (JSON)
|
||||
request_metadata: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
|
||||
class UsageStatsHourly(Base):
|
||||
"""Hourly aggregated usage statistics."""
|
||||
|
||||
__tablename__ = "usage_stats_hourly"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(DateTime, nullable=False, index=True)
|
||||
|
||||
virtual_key_id: Mapped[str | None] = mapped_column(String(36), index=True)
|
||||
project_id: Mapped[str | None] = mapped_column(String(36), index=True)
|
||||
provider: Mapped[str] = mapped_column(String(50))
|
||||
model: Mapped[str] = mapped_column(String(100))
|
||||
|
||||
# Aggregates
|
||||
request_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
input_tokens: Mapped[int] = mapped_column(BigInteger, default=0)
|
||||
output_tokens: Mapped[int] = mapped_column(BigInteger, default=0)
|
||||
total_tokens: Mapped[int] = mapped_column(BigInteger, default=0)
|
||||
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6), default=Decimal("0"))
|
||||
|
||||
avg_latency_ms: Mapped[int | None] = mapped_column(Integer)
|
||||
error_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""Audit log for tracking admin operations."""
|
||||
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True)
|
||||
|
||||
# Who
|
||||
actor: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
|
||||
# What
|
||||
action: Mapped[str] = mapped_column(String(50), nullable=False) # create, update, delete
|
||||
|
||||
# Which
|
||||
resource: Mapped[str] = mapped_column(String(50), nullable=False) # provider, key, project
|
||||
resource_id: Mapped[str | None] = mapped_column(String(36))
|
||||
|
||||
# Changes
|
||||
changes: Mapped[str | None] = mapped_column(Text) # JSON with before/after
|
||||
|
||||
# Request info
|
||||
ip_address: Mapped[str | None] = mapped_column(String(50))
|
||||
user_agent: Mapped[str | None] = mapped_column(String(500))
|
||||
|
||||
|
||||
class RateLimitCounter(Base):
|
||||
"""Rate limit counter for tracking usage windows."""
|
||||
|
||||
__tablename__ = "rate_limit_counters"
|
||||
|
||||
key: Mapped[str] = mapped_column(String(200), primary_key=True)
|
||||
request_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
token_count: Mapped[int] = mapped_column(BigInteger, default=0)
|
||||
window_start: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
window_duration: Mapped[int] = mapped_column(Integer, nullable=False) # seconds
|
||||
1
llm-gateway/app/schemas/__init__.py
Normal file
1
llm-gateway/app/schemas/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# schemas module
|
||||
107
llm-gateway/app/schemas/anthropic.py
Normal file
107
llm-gateway/app/schemas/anthropic.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Anthropic Messages API request and response schemas."""
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# === Request Models ===
|
||||
|
||||
|
||||
class AnthropicContentBlock(BaseModel):
|
||||
"""Anthropic content block."""
|
||||
|
||||
type: Literal["text", "image", "tool_use", "tool_result"]
|
||||
text: str | None = None
|
||||
source: dict[str, Any] | None = None # For images
|
||||
tool_use_id: str | None = None # For tool_result
|
||||
content: str | list[dict[str, Any]] | None = None # For tool_result
|
||||
name: str | None = None # For tool_use
|
||||
input: dict[str, Any] | None = None # For tool_use
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
"""Anthropic message."""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: str | list[AnthropicContentBlock]
|
||||
|
||||
|
||||
class AnthropicMessagesRequest(BaseModel):
|
||||
"""Anthropic Messages API request."""
|
||||
|
||||
model: str
|
||||
messages: list[AnthropicMessage]
|
||||
system: str | None = None
|
||||
max_tokens: int = Field(..., ge=1)
|
||||
temperature: float | None = Field(None, ge=0, le=1)
|
||||
top_p: float | None = Field(None, ge=0, le=1)
|
||||
top_k: int | None = Field(None, ge=0)
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool = False
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tool_choice: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# === Response Models ===
|
||||
|
||||
|
||||
class AnthropicTextBlock(BaseModel):
|
||||
"""Anthropic text content block."""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class AnthropicToolUseBlock(BaseModel):
|
||||
"""Anthropic tool use content block."""
|
||||
|
||||
type: Literal["tool_use"] = "tool_use"
|
||||
id: str
|
||||
name: str
|
||||
input: dict[str, Any]
|
||||
|
||||
|
||||
class AnthropicUsage(BaseModel):
|
||||
"""Anthropic token usage."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_creation_input_tokens: int | None = None
|
||||
cache_read_input_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicMessagesResponse(BaseModel):
|
||||
"""Anthropic Messages API response."""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[AnthropicTextBlock | AnthropicToolUseBlock]
|
||||
model: str
|
||||
stop_reason: Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None = None
|
||||
stop_sequence: str | None = None
|
||||
usage: AnthropicUsage
|
||||
|
||||
|
||||
class AnthropicStreamMessage(BaseModel):
|
||||
"""Anthropic streaming message start."""
|
||||
|
||||
type: Literal["message_start"] = "message_start"
|
||||
message: AnthropicMessagesResponse
|
||||
|
||||
|
||||
class AnthropicContentBlockDelta(BaseModel):
|
||||
"""Anthropic streaming content block delta."""
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: dict[str, Any]
|
||||
|
||||
|
||||
class AnthropicMessageDelta(BaseModel):
|
||||
"""Anthropic streaming message delta."""
|
||||
|
||||
type: Literal["message_delta"] = "message_delta"
|
||||
delta: dict[str, Any]
|
||||
usage: AnthropicUsage | None = None
|
||||
71
llm-gateway/app/schemas/api_key.py
Normal file
71
llm-gateway/app/schemas/api_key.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""API Key schemas for API requests and responses."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.common import BaseResponse
|
||||
|
||||
|
||||
class APIKeyCreate(BaseModel):
|
||||
"""Schema for creating an API key."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
project_id: str | None = Field(None, description="Associated project ID")
|
||||
prefix: str = Field(default="sk", max_length=10, description="Key prefix")
|
||||
expires_at: datetime | None = Field(None, description="Expiration date")
|
||||
rpm_limit: int | None = Field(None, ge=1)
|
||||
tpm_limit: int | None = Field(None, ge=1)
|
||||
budget_limit: Decimal | None = Field(None, ge=0)
|
||||
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
|
||||
allowed_models: list[str] | None = Field(None, description="List of allowed model aliases")
|
||||
enabled: bool = Field(True)
|
||||
|
||||
|
||||
class APIKeyUpdate(BaseModel):
|
||||
"""Schema for updating an API key."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=100)
|
||||
project_id: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
rpm_limit: int | None = Field(None, ge=1)
|
||||
tpm_limit: int | None = Field(None, ge=1)
|
||||
budget_limit: Decimal | None = Field(None, ge=0)
|
||||
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
|
||||
allowed_models: list[str] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class APIKeyResponse(BaseResponse):
|
||||
"""Schema for API key response."""
|
||||
|
||||
id: str
|
||||
key_prefix: str
|
||||
name: str
|
||||
project_id: str | None
|
||||
enabled: bool
|
||||
expires_at: datetime | None
|
||||
rpm_limit: int | None
|
||||
tpm_limit: int | None
|
||||
budget_limit: Decimal | None
|
||||
budget_period: str | None
|
||||
allowed_models: list[str] | None
|
||||
current_usage: Decimal
|
||||
total_requests: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class APIKeyCreateResponse(APIKeyResponse):
|
||||
"""Schema for API key creation response (includes the full key)."""
|
||||
|
||||
key: str = Field(..., description="Full API key (shown only once)")
|
||||
|
||||
|
||||
class APIKeyListResponse(BaseResponse):
|
||||
"""Schema for API key list response."""
|
||||
|
||||
keys: list[APIKeyResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
30
llm-gateway/app/schemas/common.py
Normal file
30
llm-gateway/app/schemas/common.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Common schemas for API requests and responses."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
"""Base response model."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response model."""
|
||||
|
||||
error: dict[str, Any]
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""Pagination parameters."""
|
||||
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
@property
|
||||
def offset(self) -> int:
|
||||
"""Calculate offset for database query."""
|
||||
return (self.page - 1) * self.page_size
|
||||
71
llm-gateway/app/schemas/model_alias.py
Normal file
71
llm-gateway/app/schemas/model_alias.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Model alias schemas for API requests and responses."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.common import BaseResponse
|
||||
|
||||
|
||||
class RoutingConfig(BaseModel):
|
||||
"""Routing configuration for load balancing and fallback."""
|
||||
|
||||
providers: list[dict[str, Any]] | None = Field(
|
||||
None, description="List of providers with weights for load balancing"
|
||||
)
|
||||
primary: dict[str, str] | None = Field(None, description="Primary provider for fallback")
|
||||
fallback: list[dict[str, str]] | None = Field(None, description="Fallback providers")
|
||||
|
||||
|
||||
class ModelAliasCreate(BaseModel):
|
||||
"""Schema for creating a model alias."""
|
||||
|
||||
alias: str = Field(..., min_length=1, max_length=100, description="User-facing model name")
|
||||
provider: str = Field(..., min_length=1, max_length=50, description="Provider name")
|
||||
model: str = Field(..., min_length=1, max_length=100, description="Actual model name")
|
||||
routing_type: str = Field(
|
||||
default="simple", pattern="^(simple|load_balance|fallback)$"
|
||||
)
|
||||
routing_config: RoutingConfig | None = None
|
||||
input_price_per_1k: Decimal | None = Field(None, ge=0)
|
||||
output_price_per_1k: Decimal | None = Field(None, ge=0)
|
||||
enabled: bool = Field(True)
|
||||
|
||||
|
||||
class ModelAliasUpdate(BaseModel):
|
||||
"""Schema for updating a model alias."""
|
||||
|
||||
alias: str | None = Field(None, min_length=1, max_length=100)
|
||||
provider: str | None = Field(None, min_length=1, max_length=50)
|
||||
model: str | None = Field(None, min_length=1, max_length=100)
|
||||
routing_type: str | None = Field(None, pattern="^(simple|load_balance|fallback)$")
|
||||
routing_config: RoutingConfig | None = None
|
||||
input_price_per_1k: Decimal | None = Field(None, ge=0)
|
||||
output_price_per_1k: Decimal | None = Field(None, ge=0)
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ModelAliasResponse(BaseResponse):
|
||||
"""Schema for model alias response."""
|
||||
|
||||
id: str
|
||||
alias: str
|
||||
provider: str
|
||||
model: str
|
||||
enabled: bool
|
||||
routing_type: str
|
||||
routing_config: dict[str, Any] | None
|
||||
input_price_per_1k: Decimal | None
|
||||
output_price_per_1k: Decimal | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ModelAliasListResponse(BaseResponse):
|
||||
"""Schema for model alias list response."""
|
||||
|
||||
aliases: list[ModelAliasResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
100
llm-gateway/app/schemas/openai.py
Normal file
100
llm-gateway/app/schemas/openai.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""OpenAI API request and response schemas."""
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# === Request Models ===
|
||||
|
||||
|
||||
class OpenAIChatMessage(BaseModel):
|
||||
"""OpenAI chat message."""
|
||||
|
||||
role: Literal["system", "user", "assistant", "tool", "function"]
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_call_id: str | None = None
|
||||
|
||||
|
||||
class OpenAIChatCompletionRequest(BaseModel):
|
||||
"""OpenAI Chat Completions API request."""
|
||||
|
||||
model: str
|
||||
messages: list[OpenAIChatMessage]
|
||||
temperature: float | None = Field(None, ge=0, le=2)
|
||||
top_p: float | None = Field(None, ge=0, le=1)
|
||||
n: int | None = Field(None, ge=1)
|
||||
stream: bool = False
|
||||
stop: str | list[str] | None = None
|
||||
max_tokens: int | None = Field(None, ge=1)
|
||||
presence_penalty: float | None = Field(None, ge=-2, le=2)
|
||||
frequency_penalty: float | None = Field(None, ge=-2, le=2)
|
||||
logit_bias: dict[str, float] | None = None
|
||||
user: str | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
response_format: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class OpenAIResponseRequest(BaseModel):
|
||||
"""OpenAI Responses API request (new format)."""
|
||||
|
||||
model: str
|
||||
input: str | list[dict[str, Any]]
|
||||
instructions: str | None = None
|
||||
temperature: float | None = Field(None, ge=0, le=2)
|
||||
max_output_tokens: int | None = Field(None, ge=1)
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# === Response Models ===
|
||||
|
||||
|
||||
class OpenAIChatCompletionChoice(BaseModel):
|
||||
"""OpenAI chat completion choice."""
|
||||
|
||||
index: int
|
||||
message: OpenAIChatMessage
|
||||
finish_reason: str | None
|
||||
|
||||
|
||||
class OpenAIUsage(BaseModel):
|
||||
"""Token usage information."""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class OpenAIChatCompletionResponse(BaseModel):
|
||||
"""OpenAI Chat Completions API response."""
|
||||
|
||||
id: str
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: list[OpenAIChatCompletionChoice]
|
||||
usage: OpenAIUsage | None = None
|
||||
system_fingerprint: str | None = None
|
||||
|
||||
|
||||
class OpenAIStreamChoice(BaseModel):
|
||||
"""OpenAI streaming choice."""
|
||||
|
||||
index: int
|
||||
delta: dict[str, Any]
|
||||
finish_reason: str | None
|
||||
|
||||
|
||||
class OpenAIChatCompletionChunk(BaseModel):
|
||||
"""OpenAI Chat Completions streaming chunk."""
|
||||
|
||||
id: str
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
choices: list[OpenAIStreamChoice]
|
||||
system_fingerprint: str | None = None
|
||||
49
llm-gateway/app/schemas/project.py
Normal file
49
llm-gateway/app/schemas/project.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Project schemas for API requests and responses."""
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.common import BaseResponse
|
||||
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
"""Schema for creating a project."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: str | None = Field(None, max_length=1000)
|
||||
budget_limit: Decimal | None = Field(None, ge=0)
|
||||
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
|
||||
enabled: bool = Field(True)
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""Schema for updating a project."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=100)
|
||||
description: str | None = Field(None, max_length=1000)
|
||||
budget_limit: Decimal | None = Field(None, ge=0)
|
||||
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ProjectResponse(BaseResponse):
|
||||
"""Schema for project response."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str | None
|
||||
budget_limit: Decimal | None
|
||||
budget_period: str | None
|
||||
enabled: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ProjectListResponse(BaseResponse):
|
||||
"""Schema for project list response."""
|
||||
|
||||
projects: list[ProjectResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
58
llm-gateway/app/schemas/provider.py
Normal file
58
llm-gateway/app/schemas/provider.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Provider schemas for API requests and responses."""
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.common import BaseResponse
|
||||
|
||||
|
||||
class ProviderCreate(BaseModel):
|
||||
"""Schema for creating a provider."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=50, description="Provider name")
|
||||
api_base: str = Field(..., min_length=1, max_length=500, description="API base URL")
|
||||
api_key: str = Field(..., min_length=1, description="API key (will be encrypted)")
|
||||
api_version: str | None = Field(None, max_length=50, description="API version")
|
||||
config: dict[str, Any] | None = Field(None, description="Provider-specific configuration")
|
||||
rpm_limit: int | None = Field(None, ge=1, description="Requests per minute limit")
|
||||
tpm_limit: int | None = Field(None, ge=1, description="Tokens per minute limit")
|
||||
enabled: bool = Field(True, description="Whether provider is enabled")
|
||||
|
||||
|
||||
class ProviderUpdate(BaseModel):
|
||||
"""Schema for updating a provider."""
|
||||
|
||||
api_base: str | None = Field(None, min_length=1, max_length=500)
|
||||
api_key: str | None = Field(None, min_length=1)
|
||||
api_version: str | None = Field(None, max_length=50)
|
||||
config: dict[str, Any] | None = None
|
||||
rpm_limit: int | None = Field(None, ge=1)
|
||||
tpm_limit: int | None = Field(None, ge=1)
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ProviderResponse(BaseResponse):
|
||||
"""Schema for provider response."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
enabled: bool
|
||||
api_base: str
|
||||
api_version: str | None
|
||||
config: dict[str, Any] | None
|
||||
rpm_limit: int | None
|
||||
tpm_limit: int | None
|
||||
health_status: str
|
||||
last_health_check: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ProviderListResponse(BaseResponse):
|
||||
"""Schema for provider list response."""
|
||||
|
||||
providers: list[ProviderResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
1
llm-gateway/app/utils/__init__.py
Normal file
1
llm-gateway/app/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# utils module
|
||||
123
llm-gateway/app/utils/crypto.py
Normal file
123
llm-gateway/app/utils/crypto.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""Cryptographic utilities for key management."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
def hash_api_key(key: str) -> str:
|
||||
"""
|
||||
Hash an API key using bcrypt.
|
||||
|
||||
Args:
|
||||
key: The plaintext API key to hash.
|
||||
|
||||
Returns:
|
||||
The hashed key string.
|
||||
"""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed = bcrypt.hashpw(key.encode("utf-8"), salt)
|
||||
return hashed.decode("utf-8")
|
||||
|
||||
|
||||
def verify_api_key(key: str, hashed: str) -> bool:
|
||||
"""
|
||||
Verify an API key against its hash.
|
||||
|
||||
Args:
|
||||
key: The plaintext API key to verify.
|
||||
hashed: The hashed key to compare against.
|
||||
|
||||
Returns:
|
||||
True if the key matches, False otherwise.
|
||||
"""
|
||||
try:
|
||||
return bcrypt.checkpw(key.encode("utf-8"), hashed.encode("utf-8"))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def generate_api_key(prefix: str = "sk") -> Tuple[str, str, str]:
|
||||
"""
|
||||
Generate a new API key.
|
||||
|
||||
Args:
|
||||
prefix: The prefix for the key (e.g., "sk", "sk-proj").
|
||||
|
||||
Returns:
|
||||
Tuple of (full_key, key_hash, key_prefix).
|
||||
"""
|
||||
# Generate random key part
|
||||
random_part = base64.urlsafe_b64encode(os.urandom(24)).decode("utf-8").rstrip("=")
|
||||
full_key = f"{prefix}_{random_part}"
|
||||
|
||||
# Hash the key
|
||||
key_hash = hash_api_key(full_key)
|
||||
|
||||
# Create prefix for display (first 12 chars)
|
||||
key_prefix = full_key[:12] + "..."
|
||||
|
||||
return full_key, key_hash, key_prefix
|
||||
|
||||
|
||||
def get_encryption_key() -> bytes:
|
||||
"""
|
||||
Derive an encryption key from the master key.
|
||||
|
||||
Returns:
|
||||
A URL-safe base64-encoded 32-byte key for Fernet.
|
||||
"""
|
||||
settings = get_settings()
|
||||
master_key = settings.master_key.encode("utf-8")
|
||||
|
||||
# Use a fixed salt for deterministic key derivation
|
||||
# In production, consider using a random salt stored securely
|
||||
salt = b"llm-gateway-encryption-salt-v1"
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
|
||||
key = base64.urlsafe_b64encode(kdf.derive(master_key))
|
||||
return key
|
||||
|
||||
|
||||
def encrypt_value(plaintext: str) -> str:
|
||||
"""
|
||||
Encrypt a value using Fernet symmetric encryption.
|
||||
|
||||
Args:
|
||||
plaintext: The value to encrypt.
|
||||
|
||||
Returns:
|
||||
The encrypted value as a base64 string.
|
||||
"""
|
||||
key = get_encryption_key()
|
||||
f = Fernet(key)
|
||||
encrypted = f.encrypt(plaintext.encode("utf-8"))
|
||||
return encrypted.decode("utf-8")
|
||||
|
||||
|
||||
def decrypt_value(encrypted: str) -> str:
|
||||
"""
|
||||
Decrypt a value using Fernet symmetric encryption.
|
||||
|
||||
Args:
|
||||
encrypted: The encrypted value as a base64 string.
|
||||
|
||||
Returns:
|
||||
The decrypted plaintext.
|
||||
"""
|
||||
key = get_encryption_key()
|
||||
f = Fernet(key)
|
||||
decrypted = f.decrypt(encrypted.encode("utf-8"))
|
||||
return decrypted.decode("utf-8")
|
||||
65
llm-gateway/app/utils/logging.py
Normal file
65
llm-gateway/app/utils/logging.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Structured logging configuration."""
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
from structlog.types import Processor
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Configure structured logging for the application."""
|
||||
settings = get_settings()
|
||||
|
||||
# Shared processors for both stdlib and structlog
|
||||
shared_processors: list[Processor] = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
]
|
||||
|
||||
# Configure structlog
|
||||
structlog.configure(
|
||||
processors=shared_processors
|
||||
+ [
|
||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||
],
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
# Configure stdlib logging
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
if settings.debug or settings.log_level == "DEBUG":
|
||||
formatter = structlog.stdlib.ProcessorFormatter(
|
||||
foreign_pre_chain=shared_processors,
|
||||
processors=[
|
||||
structlog.dev.ConsoleRenderer(colors=True),
|
||||
],
|
||||
)
|
||||
else:
|
||||
formatter = structlog.stdlib.ProcessorFormatter(
|
||||
foreign_pre_chain=shared_processors,
|
||||
processors=[
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.JSONRenderer(),
|
||||
],
|
||||
)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Clear existing handlers and add our handler
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(getattr(logging, settings.log_level.upper(), logging.INFO))
|
||||
|
||||
|
||||
def get_logger(name: str | None = None) -> structlog.stdlib.BoundLogger:
|
||||
"""Get a structured logger instance."""
|
||||
return structlog.get_logger(name)
|
||||
33
llm-gateway/docker-compose.yml
Normal file
33
llm-gateway/docker-compose.yml
Normal file
@ -0,0 +1,33 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
gateway:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
environment:
|
||||
- DATABASE_URL=sqlite:///data/gateway.db
|
||||
- MASTER_KEY=${MASTER_KEY}
|
||||
- DEBUG=${DEBUG:-false}
|
||||
- LOG_LEVEL=${LOG_LEVEL:-INFO}
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
|
||||
# Optional: Redis for distributed rate limiting (future)
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# ports:
|
||||
# - "6379:6379"
|
||||
# volumes:
|
||||
# - redis_data:/data
|
||||
|
||||
volumes:
|
||||
gateway_data:
|
||||
# redis_data:
|
||||
37
llm-gateway/requirements.txt
Normal file
37
llm-gateway/requirements.txt
Normal file
@ -0,0 +1,37 @@
|
||||
# Web Framework
|
||||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
|
||||
# HTTP Client
|
||||
httpx>=0.26.0
|
||||
|
||||
# Database
|
||||
sqlalchemy>=2.0.0
|
||||
aiosqlite>=0.19.0
|
||||
|
||||
# Security
|
||||
bcrypt>=4.1.0
|
||||
cryptography>=42.0.0
|
||||
|
||||
# Logging
|
||||
structlog>=24.1.0
|
||||
|
||||
# Testing
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.0
|
||||
pytest-cov>=4.1.0
|
||||
httpx>=0.26.0
|
||||
|
||||
# Type Checking
|
||||
mypy>=1.8.0
|
||||
|
||||
# Code Quality
|
||||
ruff>=0.2.0
|
||||
|
||||
# Provider SDKs
|
||||
openai>=1.12.0
|
||||
anthropic>=0.18.0
|
||||
google-generativeai>=0.3.0
|
||||
boto3>=1.34.0
|
||||
1
llm-gateway/tests/__init__.py
Normal file
1
llm-gateway/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# tests module
|
||||
106
llm-gateway/tests/conftest.py
Normal file
106
llm-gateway/tests/conftest.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Pytest configuration and fixtures."""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
# Set test environment before importing app
|
||||
os.environ["TESTING"] = "1"
|
||||
|
||||
# Create a temp file for the test database
|
||||
_test_db_fd, _test_db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{_test_db_path}"
|
||||
os.environ["MASTER_KEY"] = "test-master-key-for-testing-32cha"
|
||||
|
||||
# Now import the app modules
|
||||
from app.db import database
|
||||
from app.db.database import Base, get_db
|
||||
|
||||
# Import all models to register them with Base
|
||||
from app.models import ( # noqa: F401
|
||||
APIKey,
|
||||
AuditLog,
|
||||
ModelAlias,
|
||||
Project,
|
||||
Provider,
|
||||
RateLimitCounter,
|
||||
RequestLog,
|
||||
UsageStatsHourly,
|
||||
)
|
||||
|
||||
# Create test engine
|
||||
_test_db_url = f"sqlite+aiosqlite:///{_test_db_path}"
|
||||
_test_engine = create_async_engine(_test_db_url, echo=False)
|
||||
_test_session_factory = async_sessionmaker(
|
||||
_test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Override the global engine and session factory
|
||||
database.engine = _test_engine
|
||||
database.AsyncSessionLocal = _test_session_factory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""Create an event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def setup_test_db() -> AsyncGenerator[None, None]:
|
||||
"""Set up test database with all tables."""
|
||||
# Create all tables
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield
|
||||
|
||||
# Clean up
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await _test_engine.dispose()
|
||||
os.close(_test_db_fd)
|
||||
os.unlink(_test_db_path)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(setup_test_db: None) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Provide a database session for tests."""
|
||||
async with _test_session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def _get_test_db():
|
||||
"""Override dependency for test database."""
|
||||
async with _test_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(setup_test_db: None) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Provide an async HTTP client for API tests."""
|
||||
from app.main import create_app
|
||||
|
||||
test_app = create_app()
|
||||
|
||||
# Override the database dependency
|
||||
test_app.dependency_overrides[get_db] = _get_test_db
|
||||
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
1
llm-gateway/tests/integration/__init__.py
Normal file
1
llm-gateway/tests/integration/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# integration module
|
||||
237
llm-gateway/tests/integration/test_admin_api.py
Normal file
237
llm-gateway/tests/integration/test_admin_api.py
Normal file
@ -0,0 +1,237 @@
|
||||
"""Tests for Admin API endpoints."""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestProviderAPI:
|
||||
"""Test Provider management API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_provider(self, client: AsyncClient, setup_test_db):
|
||||
"""Test creating a provider."""
|
||||
response = await client.post(
|
||||
"/admin/providers",
|
||||
json={
|
||||
"name": "openai",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"api_key": "test-api-key-1234567890",
|
||||
"api_version": "v1",
|
||||
"rpm_limit": 1000,
|
||||
"tpm_limit": 100000,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "openai"
|
||||
assert data["api_base"] == "https://api.openai.com/v1"
|
||||
assert data["health_status"] == "healthy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_providers(self, client: AsyncClient, setup_test_db):
|
||||
"""Test listing providers."""
|
||||
# Create a provider first
|
||||
await client.post(
|
||||
"/admin/providers",
|
||||
json={
|
||||
"name": "anthropic",
|
||||
"api_base": "https://api.anthropic.com",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
|
||||
response = await client.get("/admin/providers")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
assert len(data["providers"]) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider(self, client: AsyncClient, setup_test_db):
|
||||
"""Test getting a specific provider."""
|
||||
create_response = await client.post(
|
||||
"/admin/providers",
|
||||
json={
|
||||
"name": "gemini",
|
||||
"api_base": "https://generativelanguage.googleapis.com",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
provider_id = create_response.json()["id"]
|
||||
|
||||
response = await client.get(f"/admin/providers/{provider_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "gemini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_provider(self, client: AsyncClient, setup_test_db):
|
||||
"""Test updating a provider."""
|
||||
create_response = await client.post(
|
||||
"/admin/providers",
|
||||
json={
|
||||
"name": "azure",
|
||||
"api_base": "https://example.openai.azure.com",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
provider_id = create_response.json()["id"]
|
||||
|
||||
response = await client.put(
|
||||
f"/admin/providers/{provider_id}",
|
||||
json={"rpm_limit": 500, "enabled": False},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rpm_limit"] == 500
|
||||
assert data["enabled"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_provider(self, client: AsyncClient, setup_test_db):
|
||||
"""Test deleting a provider."""
|
||||
create_response = await client.post(
|
||||
"/admin/providers",
|
||||
json={
|
||||
"name": "bedrock",
|
||||
"api_base": "https://bedrock.us-east-1.amazonaws.com",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
provider_id = create_response.json()["id"]
|
||||
|
||||
response = await client.delete(f"/admin/providers/{provider_id}")
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify it's deleted
|
||||
get_response = await client.get(f"/admin/providers/{provider_id}")
|
||||
assert get_response.status_code == 404
|
||||
|
||||
|
||||
class TestProjectAPI:
|
||||
"""Test Project management API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project(self, client: AsyncClient, setup_test_db):
|
||||
"""Test creating a project."""
|
||||
response = await client.post(
|
||||
"/admin/projects",
|
||||
json={
|
||||
"name": "Test Project",
|
||||
"description": "A test project",
|
||||
"budget_limit": 100.00,
|
||||
"budget_period": "monthly",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Project"
|
||||
assert data["budget_limit"] == "100.00"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_projects(self, client: AsyncClient, setup_test_db):
|
||||
"""Test listing projects."""
|
||||
await client.post(
|
||||
"/admin/projects",
|
||||
json={"name": "Project 1"},
|
||||
)
|
||||
|
||||
response = await client.get("/admin/projects")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
class TestAPIKeyAPI:
|
||||
"""Test API Key management API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_api_key(self, client: AsyncClient, setup_test_db):
|
||||
"""Test creating an API key."""
|
||||
response = await client.post(
|
||||
"/admin/keys",
|
||||
json={
|
||||
"name": "Test Key",
|
||||
"prefix": "sk-test",
|
||||
"rpm_limit": 100,
|
||||
"tpm_limit": 10000,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Key"
|
||||
assert "key" in data # Full key shown on creation
|
||||
assert data["key"].startswith("sk-test_")
|
||||
assert data["key_prefix"].startswith("sk-test_")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_api_keys(self, client: AsyncClient, setup_test_db):
|
||||
"""Test listing API keys."""
|
||||
await client.post(
|
||||
"/admin/keys",
|
||||
json={"name": "Key 1"},
|
||||
)
|
||||
|
||||
response = await client.get("/admin/keys")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
class TestModelAliasAPI:
|
||||
"""Test Model Alias management API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_model_alias(self, client: AsyncClient, setup_test_db):
|
||||
"""Test creating a model alias."""
|
||||
response = await client.post(
|
||||
"/admin/models/aliases",
|
||||
json={
|
||||
"alias": "gpt-4",
|
||||
"provider": "openai",
|
||||
"model": "gpt-4-turbo",
|
||||
"input_price_per_1k": 0.01,
|
||||
"output_price_per_1k": 0.03,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["alias"] == "gpt-4"
|
||||
assert data["provider"] == "openai"
|
||||
assert data["model"] == "gpt-4-turbo"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_model_aliases(self, client: AsyncClient, setup_test_db):
|
||||
"""Test listing model aliases."""
|
||||
await client.post(
|
||||
"/admin/models/aliases",
|
||||
json={
|
||||
"alias": "claude",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-sonnet",
|
||||
},
|
||||
)
|
||||
|
||||
response = await client.get("/admin/models/aliases")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
class TestHealthAPI:
|
||||
"""Test Health check API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, client: AsyncClient, setup_test_db):
|
||||
"""Test health check endpoint."""
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check(self, client: AsyncClient, setup_test_db):
|
||||
"""Test readiness check endpoint."""
|
||||
response = await client.get("/ready")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ready"
|
||||
assert data["database"] == "connected"
|
||||
1
llm-gateway/tests/unit/__init__.py
Normal file
1
llm-gateway/tests/unit/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# unit module
|
||||
44
llm-gateway/tests/unit/test_config.py
Normal file
44
llm-gateway/tests/unit/test_config.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Tests for configuration module."""
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
class TestSettings:
|
||||
"""Test configuration settings."""
|
||||
|
||||
def test_default_settings(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test default configuration values."""
|
||||
# Change to temp directory for database path
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.app_name == "LLM Gateway"
|
||||
assert settings.debug is False
|
||||
assert settings.database_url.startswith("sqlite:///")
|
||||
assert settings.master_key is not None
|
||||
|
||||
def test_custom_settings_from_env(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test loading settings from environment variables."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.setenv("APP_NAME", "Custom Gateway")
|
||||
monkeypatch.setenv("DEBUG", "true")
|
||||
monkeypatch.setenv("DATABASE_URL", "sqlite:///custom.db")
|
||||
monkeypatch.setenv("MASTER_KEY", "test-master-key-12345678901234567890") # 40 chars
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.app_name == "Custom Gateway"
|
||||
assert settings.debug is True
|
||||
assert settings.database_url == "sqlite:///custom.db"
|
||||
assert settings.master_key == "test-master-key-12345678901234567890"
|
||||
|
||||
def test_generate_master_key(self):
|
||||
"""Test master key generation."""
|
||||
key = Settings.generate_master_key()
|
||||
|
||||
assert len(key) == 64 # 32 bytes = 64 hex chars
|
||||
assert isinstance(key, str)
|
||||
34
llm-gateway/tests/unit/test_database.py
Normal file
34
llm-gateway/tests/unit/test_database.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""Tests for database module."""
|
||||
import pytest
|
||||
|
||||
from app.db.database import init_db, get_db
|
||||
from app.models import Provider, APIKey, Project
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""Test database initialization and connections."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_creates_tables(self, db_session):
|
||||
"""Test that init_db creates all required tables."""
|
||||
# Tables should exist if we got a session
|
||||
assert db_session is not None
|
||||
|
||||
# Try to query each table
|
||||
from sqlalchemy import text
|
||||
|
||||
result = await db_session.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
|
||||
tables = [row[0] for row in result.fetchall()]
|
||||
|
||||
assert "providers" in tables
|
||||
assert "api_keys" in tables
|
||||
assert "projects" in tables
|
||||
assert "model_aliases" in tables
|
||||
assert "request_logs" in tables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_returns_session(self, setup_test_db):
|
||||
"""Test that get_db returns a valid database session."""
|
||||
async for session in get_db():
|
||||
assert session is not None
|
||||
break
|
||||
Loading…
x
Reference in New Issue
Block a user