feat: implement LLM Gateway with multi-provider support

Implement a unified LLM Gateway supporting multiple API formats and providers:

Features:
- OpenAI Chat Completions, Responses API, and Anthropic Messages API
- Provider adapters for OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock
- Model aliasing with weighted round-robin load balancing
- Virtual API keys with RPM/TPM rate limiting
- Budget control at key and project levels
- Request logging, usage statistics, and audit logs
- Fallback/retry with circuit breaker pattern
- Admin CRUD APIs for providers, projects, keys, models, usage
- Provider health checks

Tech stack:
- FastAPI with async SQLAlchemy 2.0
- SQLite with aiosqlite
- bcrypt for API key hashing, AES-256 for provider key encryption
- Docker containerization

Tests: 18 passing integration tests for admin API endpoints

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
root 2026-05-01 15:39:21 +08:00
parent 8f550a2100
commit 8348520bdf
63 changed files with 5691 additions and 0 deletions

llm-gateway/.env.example Normal file

@ -0,0 +1,29 @@
# Application
APP_NAME=LLM Gateway
DEBUG=false
LOG_LEVEL=INFO
# Server
HOST=0.0.0.0
PORT=8000
# Database
DATABASE_URL=sqlite:///data/gateway.db
# Security
# Generate with: python -c "import secrets; print(secrets.token_hex(32))"
MASTER_KEY=your-master-key-here-at-least-32-characters
# Rate Limiting
RATE_LIMIT_WINDOW_SECONDS=60
# GLOBAL_RPM_LIMIT=1000
# GLOBAL_TPM_LIMIT=1000000
# Retry
MAX_RETRIES=3
RETRY_INITIAL_DELAY=1.0
RETRY_MAX_DELAY=30.0
# Health Check
HEALTH_CHECK_INTERVAL=30
HEALTH_CHECK_TIMEOUT=10

llm-gateway/.gitignore vendored Normal file

@ -0,0 +1,65 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
venv/
ENV/
env/
.venv/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
# Type checking
.mypy_cache/
.dmypy.json
dmypy.json
# Environment
.env
.env.local
.env.*.local
# Database
*.db
*.sqlite
*.sqlite3
data/
# Logs
*.log
logs/
# OS
.DS_Store
Thumbs.db

llm-gateway/Dockerfile Normal file

@ -0,0 +1,28 @@
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*
# Copy requirements first for caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create data directory
RUN mkdir -p /app/data
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1
# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

llm-gateway/README.md Normal file

@ -0,0 +1,41 @@
# LLM Gateway
A unified LLM Gateway supporting multiple providers and API formats.
## Features
- **Multi-format API**: OpenAI Chat Completions, OpenAI Responses API, Anthropic Messages API
- **Multi-provider support**: OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock
- **Model aliasing and routing**: Flexible model mapping with load balancing
- **Rate limiting**: RPM/TPM limits at multiple levels
- **Budget control**: Key- and project-level spending limits
- **High availability**: Fallback, retry, and circuit breaker
- **Observability**: Request logging and usage statistics
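
The load balancing listed above can be pictured as a weighted round-robin over the upstream targets behind one model alias. The following is a minimal sketch only: the `Target` record, the `pick_target` helper, the example model name, and the choice of the "smooth" weighted round-robin variant are all assumptions for illustration, and the gateway's own routing code may work differently.

```python
from dataclasses import dataclass


@dataclass
class Target:
    """Hypothetical upstream target behind one model alias."""

    provider: str
    model: str
    weight: int
    current: int = 0  # running score used by the smooth algorithm


def pick_target(targets: list[Target]) -> Target:
    """Pick the next target using smooth weighted round-robin."""
    total = sum(t.weight for t in targets)
    # Every target earns its weight on each round...
    for t in targets:
        t.current += t.weight
    # ...and the current leader is chosen and pays back the total weight.
    chosen = max(targets, key=lambda t: t.current)
    chosen.current -= total
    return chosen


# Hypothetical alias with a 3:1 split between two providers.
targets = [
    Target("openai", "gpt-4o", weight=3),
    Target("azure", "gpt-4o", weight=1),
]
picks = [pick_target(targets).provider for _ in range(8)]
```

Over any full cycle of four picks, the target with weight 3 is chosen three times and the target with weight 1 once, and the heavier target's picks are spread out rather than sent in a burst.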
## Quick Start
```bash
# Install dependencies
pip install -r requirements.txt
# Run the server
uvicorn app.main:app --reload
```
## API Endpoints
### LLM APIs
- `POST /v1/chat/completions` - OpenAI-compatible chat completions
- `POST /v1/responses` - OpenAI Responses API
- `POST /v1/messages` - Anthropic Messages API
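
A client call to the chat completions endpoint might look like the sketch below. It uses only the standard library and rests on several assumptions: the gateway listens on `http://localhost:8000` (the port from `.env.example`), the virtual API key is sent as a Bearer token, and `my-alias` is a hypothetical model alias. `build_chat_request` and `send` are illustrative helpers, not part of the gateway.

```python
import json
import urllib.request

GATEWAY_URL = "http://localhost:8000"  # assumption: PORT=8000 as in .env.example
VIRTUAL_KEY = "sk-example"  # hypothetical virtual API key


def build_chat_request(model: str, prompt: str) -> urllib.request.Request:
    """Build an OpenAI-style chat completion request for the gateway."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        f"{GATEWAY_URL}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            # Assumption: the gateway reads the key from a Bearer token.
            "Authorization": f"Bearer {VIRTUAL_KEY}",
        },
        method="POST",
    )


def send(request: urllib.request.Request) -> dict:
    """Send the request and decode the JSON reply (needs a running gateway)."""
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.load(response)


# Building the request does not touch the network; call send(req) to execute it.
req = build_chat_request("my-alias", "Hello")
```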
### Admin APIs
- `GET|POST|PUT|DELETE /admin/providers` - Provider management
- `GET|POST|PUT|DELETE /admin/keys` - API Key management
- `GET|POST|PUT|DELETE /admin/projects` - Project management
- `GET|POST|PUT|DELETE /admin/models/aliases` - Model alias management
- `GET /admin/usage/stats` - Usage statistics
## Configuration
See `.env.example` for configuration options.
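
The retry settings in `.env.example` (`MAX_RETRIES`, `RETRY_INITIAL_DELAY`, `RETRY_MAX_DELAY`) suggest an exponential backoff schedule. Here is a minimal sketch of how such a schedule could be derived, assuming the delay doubles on every attempt and is capped at the maximum. The helper names are hypothetical and the gateway's own retry logic may differ.

```python
def load_retry_settings(env: dict[str, str]) -> tuple[int, float, float]:
    """Read the retry settings, falling back to the .env.example defaults."""
    return (
        int(env.get("MAX_RETRIES", "3")),
        float(env.get("RETRY_INITIAL_DELAY", "1.0")),
        float(env.get("RETRY_MAX_DELAY", "30.0")),
    )


def backoff_delays(max_retries: int, initial: float, cap: float) -> list[float]:
    """Return the wait before each retry: doubling delays, capped at `cap`."""
    return [min(initial * (2 ** attempt), cap) for attempt in range(max_retries)]


# With the defaults, the three retries wait 1, 2 and 4 seconds.
max_retries, initial, cap = load_retry_settings({})
delays = backoff_delays(max_retries, initial, cap)
```

Pass `dict(os.environ)` instead of `{}` to pick up values from the real environment.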


@ -0,0 +1 @@
# app module


@ -0,0 +1,69 @@
"""Provider adapters package."""

from typing import Type

from app.adapters.anthropic import AnthropicAdapter
from app.adapters.azure import AzureOpenAIAdapter
from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
from app.adapters.bedrock import BedrockAdapter
from app.adapters.gemini import GeminiAdapter
from app.adapters.openai import OpenAIAdapter

# Registry of provider adapters
ADAPTER_REGISTRY: dict[str, Type[BaseAdapter]] = {
    "openai": OpenAIAdapter,
    "anthropic": AnthropicAdapter,
    "azure": AzureOpenAIAdapter,
    "gemini": GeminiAdapter,
    "bedrock": BedrockAdapter,
    "google": GeminiAdapter,  # Alias
    "aws": BedrockAdapter,  # Alias
}


def get_adapter_class(provider: str) -> Type[BaseAdapter]:
    """
    Get the adapter class for a provider.

    Args:
        provider: Provider name.

    Returns:
        Adapter class.

    Raises:
        ValueError: If provider is not supported.
    """
    provider_lower = provider.lower()
    if provider_lower not in ADAPTER_REGISTRY:
        raise ValueError(f"Unsupported provider: {provider}")
    return ADAPTER_REGISTRY[provider_lower]


def create_adapter(config: ProviderConfig) -> BaseAdapter:
    """
    Create an adapter instance for a provider.

    Args:
        config: Provider configuration.

    Returns:
        Adapter instance.
    """
    adapter_class = get_adapter_class(config.name)
    return adapter_class(config)


__all__ = [
    "BaseAdapter",
    "HealthStatus",
    "ProviderConfig",
    "OpenAIAdapter",
    "AnthropicAdapter",
    "AzureOpenAIAdapter",
    "GeminiAdapter",
    "BedrockAdapter",
    "get_adapter_class",
    "create_adapter",
    "ADAPTER_REGISTRY",
]


@ -0,0 +1,206 @@
"""Anthropic provider adapter."""

import json
import time
import uuid
from typing import AsyncIterator

import httpx

from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
from app.core.fallback import RetryableError, PermanentError, classify_error
from app.core.transformer import RequestTransformer
from app.schemas.anthropic import (
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
    AnthropicTextBlock,
    AnthropicUsage,
)
from app.schemas.openai import (
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequest,
    OpenAIChatCompletionResponse,
)

# Anthropic stop reasons mapped to OpenAI finish reasons
FINISH_REASON_MAP = {
    "end_turn": "stop",
    "stop_sequence": "stop",
    "max_tokens": "length",
    "tool_use": "tool_calls",
}


class AnthropicAdapter(BaseAdapter):
    """Adapter for Anthropic Claude API."""

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self.transformer = RequestTransformer()

    def get_headers(self) -> dict[str, str]:
        """Get headers for Anthropic API requests."""
        headers = {
            **super().get_headers(),
            "x-api-key": self.config.api_key,
            "anthropic-version": self.config.api_version or "2023-06-01",
        }
        return headers

    async def chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> OpenAIChatCompletionResponse:
        """
        Execute an OpenAI-format chat completion request via Anthropic.

        Converts request to Anthropic format, calls Anthropic, converts response back.
        """
        # Convert OpenAI request to Anthropic format
        anthropic_request = self.transformer.openai_to_anthropic(request)
        # Call Anthropic
        anthropic_response = await self.messages(anthropic_request)
        # Convert response back to OpenAI format
        return self.transformer.anthropic_response_to_openai(
            anthropic_response, request.model
        )

    async def stream_chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Execute a streaming chat completion request via Anthropic."""
        # Convert to Anthropic format
        anthropic_request = self.transformer.openai_to_anthropic(request)
        anthropic_request.stream = True
        url = f"{self.config.api_base}/v1/messages"
        payload = anthropic_request.model_dump(exclude_none=True)

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                async with client.stream(
                    "POST",
                    url,
                    headers=self.get_headers(),
                    json=payload,
                ) as response:
                    if response.status_code >= 400:
                        # Decode the body so the message holds text, not a bytes repr
                        body = (await response.aread()).decode(errors="replace")
                        error = classify_error(
                            Exception(f"{response.status_code}: {body}")
                        )
                        raise error

                    message_id = None
                    model = request.model
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        data = line[6:]
                        try:
                            event = json.loads(data)
                        except json.JSONDecodeError:
                            continue

                        event_type = event.get("type")
                        if event_type == "message_start":
                            message_id = event.get("message", {}).get("id")
                        elif event_type == "content_block_delta":
                            delta = event.get("delta", {})
                            text = delta.get("text", "")
                            chunk = OpenAIChatCompletionChunk(
                                id=message_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
                                object="chat.completion.chunk",
                                created=int(time.time()),
                                model=model,
                                choices=[
                                    {
                                        "index": 0,
                                        "delta": {"content": text, "role": "assistant"},
                                        "finish_reason": None,
                                    }
                                ],
                            )
                            yield chunk
                        elif event_type == "message_delta":
                            delta = event.get("delta", {})
                            stop_reason = delta.get("stop_reason")
                            if stop_reason:
                                # Map every known stop reason instead of treating
                                # anything other than end_turn as "length"
                                finish_reason = FINISH_REASON_MAP.get(stop_reason, "stop")
                                chunk = OpenAIChatCompletionChunk(
                                    id=message_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
                                    object="chat.completion.chunk",
                                    created=int(time.time()),
                                    model=model,
                                    choices=[
                                        {
                                            "index": 0,
                                            "delta": {},
                                            "finish_reason": finish_reason,
                                        }
                                    ],
                                )
                                yield chunk
            except httpx.TimeoutException as e:
                raise RetryableError(str(e), error_type="timeout") from e
            except httpx.ConnectError as e:
                raise RetryableError(str(e), error_type="connection_error") from e

    async def messages(
        self,
        request: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse:
        """Execute an Anthropic Messages API request."""
        url = f"{self.config.api_base}/v1/messages"
        payload = request.model_dump(exclude_none=True)

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                response = await client.post(
                    url,
                    headers=self.get_headers(),
                    json=payload,
                )
                if response.status_code >= 400:
                    error = classify_error(Exception(f"{response.status_code}: {response.text}"))
                    raise error
                data = response.json()
                return AnthropicMessagesResponse(**data)
            except httpx.TimeoutException as e:
                raise RetryableError(str(e), error_type="timeout") from e
            except httpx.ConnectError as e:
                raise RetryableError(str(e), error_type="connection_error") from e

    async def check_health(self) -> HealthStatus:
        """Check Anthropic API health by making a minimal request."""
        # This check sends a real one-token request to /v1/messages,
        # so every health check consumes a small amount of quota.
        url = f"{self.config.api_base}/v1/messages"
        # Minimal request
        payload = {
            "model": "claude-3-haiku-20240307",
            "max_tokens": 1,
            "messages": [{"role": "user", "content": "hi"}],
        }

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.post(
                    url,
                    headers=self.get_headers(),
                    json=payload,
                )
                if response.status_code == 200:
                    return HealthStatus.HEALTHY
                elif response.status_code >= 500:
                    return HealthStatus.UNHEALTHY
                else:
                    return HealthStatus.DEGRADED
            except Exception:
                return HealthStatus.UNHEALTHY
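
The streaming path above turns Anthropic stream events into OpenAI-style chunks. The core of that mapping can be sketched on plain dictionaries, as below. `map_event` and `FINISH_REASONS` are hypothetical names, and the stop-reason table is an assumption about how each Anthropic stop reason should surface to OpenAI clients.

```python
from typing import Optional

# Assumed mapping from Anthropic stop reasons to OpenAI finish reasons.
FINISH_REASONS = {
    "end_turn": "stop",
    "stop_sequence": "stop",
    "max_tokens": "length",
    "tool_use": "tool_calls",
}


def map_event(event: dict) -> Optional[dict]:
    """Map one Anthropic stream event to an OpenAI-style choice, or None."""
    event_type = event.get("type")
    if event_type == "content_block_delta":
        text = event.get("delta", {}).get("text", "")
        return {"index": 0, "delta": {"content": text}, "finish_reason": None}
    if event_type == "message_delta":
        stop_reason = event.get("delta", {}).get("stop_reason")
        if stop_reason:
            finish = FINISH_REASONS.get(stop_reason, "stop")
            return {"index": 0, "delta": {}, "finish_reason": finish}
    # message_start, message_stop, ping and similar events carry no text
    return None


events = [
    {"type": "message_start", "message": {"id": "msg_1"}},
    {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}},
    {"type": "content_block_delta", "delta": {"type": "text_delta", "text": " there"}},
    {"type": "message_delta", "delta": {"stop_reason": "max_tokens"}},
    {"type": "message_stop"},
]
mapped = [m for m in (map_event(e) for e in events) if m is not None]
```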


@ -0,0 +1,79 @@
"""Azure OpenAI provider adapter."""

from typing import AsyncIterator

from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
from app.adapters.openai import OpenAIAdapter
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
from app.schemas.openai import (
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequest,
    OpenAIChatCompletionResponse,
)


class _AzureKeyAuthAdapter(OpenAIAdapter):
    """OpenAI adapter variant that authenticates with Azure's api-key header."""

    def get_headers(self) -> dict[str, str]:
        return {
            "Content-Type": "application/json",
            "api-key": self.config.api_key,  # Azure uses api-key header
        }


class AzureOpenAIAdapter(BaseAdapter):
    """Adapter for Azure OpenAI API."""

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        # Azure uses OpenAI-compatible API
        self._openai_adapter: OpenAIAdapter | None = None

    def _get_openai_adapter(self) -> OpenAIAdapter:
        """Get or create OpenAI adapter for Azure."""
        if self._openai_adapter is None:
            # Azure config should include deployment_name
            azure_config = self.config.config or {}
            deployment_name = azure_config.get("deployment_name", "")
            # Build Azure-specific API base
            api_base = self.config.api_base
            if deployment_name:
                api_base = f"{api_base}/openai/deployments/{deployment_name}"
            openai_config = ProviderConfig(
                name=self.config.name,
                api_base=api_base,
                api_key=self.config.api_key,
                api_version=self.config.api_version,
            )
            # Delegate through the key-auth variant so requests carry the
            # api-key header rather than an OpenAI-style Bearer token.
            # Note: api_version is passed along but not yet appended to the
            # request URL as an api-version query parameter.
            self._openai_adapter = _AzureKeyAuthAdapter(openai_config)
        return self._openai_adapter

    def get_headers(self) -> dict[str, str]:
        """Get headers for Azure OpenAI API requests."""
        return {
            "Content-Type": "application/json",
            "api-key": self.config.api_key,  # Azure uses api-key header
        }

    async def chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> OpenAIChatCompletionResponse:
        """Execute a chat completion request to Azure OpenAI."""
        adapter = self._get_openai_adapter()
        return await adapter.chat_completions(request)

    async def stream_chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Execute a streaming chat completion request to Azure OpenAI."""
        adapter = self._get_openai_adapter()
        async for chunk in adapter.stream_chat_completions(request):
            yield chunk

    async def messages(
        self,
        request: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse:
        """Execute an Anthropic Messages API request via Azure OpenAI."""
        adapter = self._get_openai_adapter()
        return await adapter.messages(request)

    async def check_health(self) -> HealthStatus:
        """Check Azure OpenAI API health."""
        # Use OpenAI-style health check
        adapter = self._get_openai_adapter()
        return await adapter.check_health()


@ -0,0 +1,135 @@
"""Base adapter interface for LLM providers."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator

from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
from app.schemas.openai import (
    OpenAIChatCompletionRequest,
    OpenAIChatCompletionResponse,
    OpenAIChatCompletionChunk,
)


class HealthStatus(Enum):
    """Provider health status."""

    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    DEGRADED = "degraded"


@dataclass
class ProviderConfig:
    """Configuration for a provider."""

    name: str
    api_base: str
    api_key: str
    api_version: str | None = None
    config: dict[str, Any] | None = None
    rpm_limit: int | None = None
    tpm_limit: int | None = None


class BaseAdapter(ABC):
    """Abstract base class for LLM provider adapters."""

    def __init__(self, config: ProviderConfig):
        self.config = config

    @abstractmethod
    async def chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> OpenAIChatCompletionResponse:
        """
        Execute a chat completion request.

        Args:
            request: OpenAI-format chat completion request.

        Returns:
            OpenAI-format chat completion response.
        """
        pass

    @abstractmethod
    async def stream_chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Execute a streaming chat completion request.

        Args:
            request: OpenAI-format chat completion request.

        Yields:
            OpenAI-format chat completion chunks.
        """
        pass

    @abstractmethod
    async def messages(
        self,
        request: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse:
        """
        Execute an Anthropic Messages API request.

        Args:
            request: Anthropic-format messages request.

        Returns:
            Anthropic-format messages response.
        """
        pass

    async def count_tokens(
        self,
        request: OpenAIChatCompletionRequest | AnthropicMessagesRequest,
    ) -> int:
        """
        Estimate token count for a request.

        Args:
            request: The request to count tokens for.

        Returns:
            Estimated token count.
        """
        # Default implementation: rough estimation from whitespace-separated
        # words. Non-string content (for example content blocks) is not counted.
        if isinstance(request, OpenAIChatCompletionRequest):
            total = 0
            for msg in request.messages:
                if isinstance(msg.content, str):
                    total += len(msg.content.split())
            return int(total * 1.3)  # Rough word-to-token ratio
        elif isinstance(request, AnthropicMessagesRequest):
            total = 0
            for msg in request.messages:
                if isinstance(msg.content, str):
                    total += len(msg.content.split())
            # Only a plain-string system prompt can be split into words
            if isinstance(request.system, str):
                total += len(request.system.split())
            return int(total * 1.3)
        return 0

    @abstractmethod
    async def check_health(self) -> HealthStatus:
        """
        Check the health of the provider.

        Returns:
            Health status of the provider.
        """
        pass

    def get_headers(self) -> dict[str, str]:
        """Get common headers for requests."""
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
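
The default `count_tokens` above is a word-count heuristic: whitespace-separated words multiplied by 1.3. The worked example below replays that arithmetic on two strings. `estimate_tokens` is a hypothetical standalone helper, not a function from the codebase.

```python
def estimate_tokens(texts: list[str], ratio: float = 1.3) -> int:
    """Estimate tokens as whitespace-separated words times a fixed ratio."""
    words = sum(len(text.split()) for text in texts)
    return int(words * ratio)


# 5 words + 6 words = 11 words; 11 * 1.3 = 14.3, truncated to 14.
estimate = estimate_tokens(
    [
        "You are a helpful assistant.",
        "Summarize this article in two sentences.",
    ]
)
```

Because the heuristic splits on whitespace, it will undercount text in languages written without spaces between words, so it is best read as a rough figure for rate limiting rather than as a billing-grade count.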


@ -0,0 +1,240 @@
"""AWS Bedrock provider adapter."""

import asyncio
import json
import time
import uuid
from typing import AsyncIterator, Any

try:
    import boto3
    from botocore.config import Config
    from botocore.exceptions import ClientError

    HAS_BOTO3 = True
except ImportError:
    HAS_BOTO3 = False

from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
from app.core.fallback import RetryableError, classify_error
from app.core.transformer import RequestTransformer
from app.schemas.anthropic import (
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from app.schemas.openai import (
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequest,
    OpenAIChatCompletionResponse,
    OpenAIChatCompletionChoice,
    OpenAIChatMessage,
    OpenAIUsage,
)

# Anthropic stop reasons mapped to OpenAI finish reasons
FINISH_REASON_MAP = {
    "end_turn": "stop",
    "max_tokens": "length",
    "stop_sequence": "stop",
    "tool_use": "tool_calls",
}


class BedrockAdapter(BaseAdapter):
    """Adapter for AWS Bedrock API."""

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self.transformer = RequestTransformer()
        self._client = None

    def _get_region(self) -> str:
        """Get the configured AWS region."""
        aws_config = self.config.config or {}
        return aws_config.get("region", "us-east-1")

    def _get_client(self):
        """Get or create Bedrock runtime client."""
        if not HAS_BOTO3:
            # Fail with a clear message instead of returning None
            raise RuntimeError("boto3 is required to use the Bedrock adapter")
        if self._client is None:
            self._client = boto3.client(
                "bedrock-runtime",
                region_name=self._get_region(),
                config=Config(
                    retries={"max_attempts": 3, "mode": "adaptive"},
                    connect_timeout=10,
                    read_timeout=120,
                ),
            )
        return self._client

    def _openai_to_bedrock_anthropic(
        self, request: OpenAIChatCompletionRequest
    ) -> tuple[str, dict[str, Any]]:
        """Convert OpenAI request to the Anthropic body expected by InvokeModel."""
        messages = []
        system = None
        for msg in request.messages:
            if msg.role == "system":
                system = msg.content if isinstance(msg.content, str) else None
            else:
                messages.append({
                    "role": msg.role,
                    # Anthropic content blocks need an explicit "type"
                    "content": (
                        [{"type": "text", "text": msg.content}]
                        if isinstance(msg.content, str)
                        else msg.content
                    ),
                })

        bedrock_request: dict[str, Any] = {
            # Required for Anthropic models invoked through InvokeModel
            "anthropic_version": "bedrock-2023-05-31",
            "messages": messages,
            "max_tokens": request.max_tokens or 4096,
        }
        if system:
            bedrock_request["system"] = system
        if request.temperature is not None:
            bedrock_request["temperature"] = request.temperature
        if request.top_p is not None:
            bedrock_request["top_p"] = request.top_p

        # Return model ID and request body
        model_id = request.model
        return model_id, bedrock_request

    def _bedrock_to_openai(
        self,
        response: dict[str, Any],
        model: str,
    ) -> OpenAIChatCompletionResponse:
        """Convert Bedrock Anthropic response to OpenAI format."""
        # InvokeModel returns the Anthropic Messages shape: top-level
        # "content", "stop_reason", and snake_case usage fields.
        content = ""
        for block in response.get("content", []):
            if "text" in block:
                content += block["text"]

        stop_reason = response.get("stop_reason", "end_turn")
        finish_reason = FINISH_REASON_MAP.get(stop_reason, "stop")

        usage = response.get("usage", {})
        prompt_tokens = usage.get("input_tokens", 0)
        completion_tokens = usage.get("output_tokens", 0)
        return OpenAIChatCompletionResponse(
            id=f"bedrock-{uuid.uuid4().hex[:24]}",
            object="chat.completion",
            created=int(time.time()),
            model=model,
            choices=[
                OpenAIChatCompletionChoice(
                    index=0,
                    message=OpenAIChatMessage(role="assistant", content=content),
                    finish_reason=finish_reason,
                )
            ],
            usage=OpenAIUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )

    async def chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> OpenAIChatCompletionResponse:
        """Execute a chat completion request to Bedrock."""
        client = self._get_client()
        model_id, bedrock_request = self._openai_to_bedrock_anthropic(request)
        try:
            # boto3 is synchronous; run the call in a worker thread so it
            # does not block the event loop.
            response = await asyncio.to_thread(
                client.invoke_model,
                modelId=model_id,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(bedrock_request),
            )
            response_body = json.loads(response["body"].read())
            return self._bedrock_to_openai(response_body, request.model)
        except ClientError as e:
            error = classify_error(Exception(str(e)))
            raise error from e

    async def stream_chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Execute a streaming chat completion request to Bedrock."""
        client = self._get_client()
        model_id, bedrock_request = self._openai_to_bedrock_anthropic(request)
        try:
            response = await asyncio.to_thread(
                client.invoke_model_with_response_stream,
                modelId=model_id,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(bedrock_request),
            )
            chunk_id = f"bedrock-{uuid.uuid4().hex[:24]}"
            events = iter(response["body"])
            while True:
                # Reading the next event blocks, so do it in a worker thread
                event = await asyncio.to_thread(next, events, None)
                if event is None:
                    break
                chunk_data = json.loads(event["chunk"]["bytes"])
                if chunk_data.get("type") == "content_block_delta":
                    delta = chunk_data.get("delta", {})
                    text = delta.get("text", "")
                    if text:
                        chunk = OpenAIChatCompletionChunk(
                            id=chunk_id,
                            object="chat.completion.chunk",
                            created=int(time.time()),
                            model=request.model,
                            choices=[
                                {
                                    "index": 0,
                                    "delta": {"content": text},
                                    "finish_reason": None,
                                }
                            ],
                        )
                        yield chunk
                elif chunk_data.get("type") == "message_delta":
                    stop_reason = chunk_data.get("delta", {}).get("stop_reason")
                    if stop_reason:
                        finish_reason = FINISH_REASON_MAP.get(stop_reason, "stop")
                        chunk = OpenAIChatCompletionChunk(
                            id=chunk_id,
                            object="chat.completion.chunk",
                            created=int(time.time()),
                            model=request.model,
                            choices=[
                                {
                                    "index": 0,
                                    "delta": {},
                                    "finish_reason": finish_reason,
                                }
                            ],
                        )
                        yield chunk
        except ClientError as e:
            error = classify_error(Exception(str(e)))
            raise error from e

    async def messages(
        self,
        request: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse:
        """Execute an Anthropic Messages API request via Bedrock."""
        openai_request = self.transformer.anthropic_to_openai(request)
        openai_response = await self.chat_completions(openai_request)
        return self.transformer.openai_response_to_anthropic(openai_response)

    async def check_health(self) -> HealthStatus:
        """Check Bedrock API health."""
        if not HAS_BOTO3:
            return HealthStatus.UNHEALTHY
        try:
            # list_foundation_models lives on the "bedrock" control-plane
            # client, not on "bedrock-runtime".
            control_client = boto3.client("bedrock", region_name=self._get_region())
            await asyncio.to_thread(control_client.list_foundation_models)
            return HealthStatus.HEALTHY
        except Exception:
            return HealthStatus.UNHEALTHY


@ -0,0 +1,236 @@
"""Google Gemini provider adapter."""

import json
import time
import uuid
from typing import AsyncIterator, Any

import httpx

from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
from app.core.fallback import RetryableError, classify_error
from app.core.transformer import RequestTransformer
from app.schemas.anthropic import (
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
    AnthropicTextBlock,
    AnthropicUsage,
)
from app.schemas.openai import (
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequest,
    OpenAIChatCompletionResponse,
    OpenAIChatCompletionChoice,
    OpenAIChatMessage,
    OpenAIUsage,
)

# Gemini finish reasons mapped to OpenAI finish reasons
FINISH_REASON_MAP = {
    "STOP": "stop",
    "MAX_TOKENS": "length",
    "SAFETY": "content_filter",
}


class GeminiAdapter(BaseAdapter):
    """Adapter for Google Gemini API."""

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self.transformer = RequestTransformer()

    def get_headers(self) -> dict[str, str]:
        """Get headers for Gemini API requests."""
        return {
            "Content-Type": "application/json",
            # Send the key as a header so it stays out of URLs and access logs
            "x-goog-api-key": self.config.api_key,
        }

    def _openai_to_gemini(self, request: OpenAIChatCompletionRequest) -> dict[str, Any]:
        """Convert OpenAI request to Gemini format."""
        contents = []
        system_instruction = None
        for msg in request.messages:
            if msg.role == "system":
                system_instruction = {"parts": [{"text": msg.content}]}
            else:
                role = "user" if msg.role == "user" else "model"
                if isinstance(msg.content, str):
                    parts = [{"text": msg.content}]
                else:
                    # Content may be None (for example on tool-call messages)
                    parts = [{"text": str(part)} for part in (msg.content or [])]
                contents.append({"role": role, "parts": parts})

        gemini_request: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {},
        }
        if system_instruction:
            gemini_request["systemInstruction"] = system_instruction
        if request.temperature is not None:
            gemini_request["generationConfig"]["temperature"] = request.temperature
        if request.max_tokens is not None:
            gemini_request["generationConfig"]["maxOutputTokens"] = request.max_tokens
        if request.top_p is not None:
            gemini_request["generationConfig"]["topP"] = request.top_p
        if request.stop:
            if isinstance(request.stop, str):
                gemini_request["generationConfig"]["stopSequences"] = [request.stop]
            else:
                gemini_request["generationConfig"]["stopSequences"] = request.stop
        return gemini_request

    def _gemini_to_openai(
        self,
        response: dict[str, Any],
        model: str,
    ) -> OpenAIChatCompletionResponse:
        """Convert Gemini response to OpenAI format."""
        candidates = response.get("candidates", [])
        content = ""
        finish_reason = "stop"
        if candidates:
            candidate = candidates[0]
            content_parts = candidate.get("content", {}).get("parts", [])
            content = "".join(p.get("text", "") for p in content_parts)
            finish_reason = FINISH_REASON_MAP.get(
                candidate.get("finishReason", "STOP"), "stop"
            )

        usage = response.get("usageMetadata", {})
        return OpenAIChatCompletionResponse(
            id=f"gemini-{uuid.uuid4().hex[:24]}",
            object="chat.completion",
            created=int(time.time()),
            model=model,
            choices=[
                OpenAIChatCompletionChoice(
                    index=0,
                    message=OpenAIChatMessage(role="assistant", content=content),
                    finish_reason=finish_reason,
                )
            ],
            usage=OpenAIUsage(
                prompt_tokens=usage.get("promptTokenCount", 0),
                completion_tokens=usage.get("candidatesTokenCount", 0),
                total_tokens=usage.get("totalTokenCount", 0),
            ),
        )

    async def chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> OpenAIChatCompletionResponse:
        """Execute a chat completion request to Gemini."""
        # Extract model name
        model = request.model
        url = f"{self.config.api_base}/v1beta/models/{model}:generateContent"
        gemini_request = self._openai_to_gemini(request)

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                response = await client.post(
                    url,
                    headers=self.get_headers(),
                    json=gemini_request,
                )
                if response.status_code >= 400:
                    error = classify_error(Exception(f"{response.status_code}: {response.text}"))
                    raise error
                data = response.json()
                return self._gemini_to_openai(data, model)
            except httpx.TimeoutException as e:
                raise RetryableError(str(e), error_type="timeout") from e
            except httpx.ConnectError as e:
                raise RetryableError(str(e), error_type="connection_error") from e

    async def stream_chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Execute a streaming chat completion request to Gemini."""
        model = request.model
        url = f"{self.config.api_base}/v1beta/models/{model}:streamGenerateContent?alt=sse"
        gemini_request = self._openai_to_gemini(request)
        # Use one id for every chunk of this stream
        chunk_id = f"gemini-{uuid.uuid4().hex[:24]}"

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                async with client.stream(
                    "POST",
                    url,
                    headers=self.get_headers(),
                    json=gemini_request,
                ) as response:
                    if response.status_code >= 400:
                        body = (await response.aread()).decode(errors="replace")
                        error = classify_error(
                            Exception(f"{response.status_code}: {body}")
                        )
                        raise error

                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        try:
                            gemini_response = json.loads(line[6:])
                        except json.JSONDecodeError:
                            continue
                        candidates = gemini_response.get("candidates") or []
                        if not candidates:
                            continue
                        candidate = candidates[0]
                        parts = candidate.get("content", {}).get("parts", [])
                        text = "".join(p.get("text", "") for p in parts)
                        # Report a finish reason only when Gemini sends one,
                        # so intermediate chunks do not look like the last one
                        raw_finish = candidate.get("finishReason")
                        finish_reason = (
                            FINISH_REASON_MAP.get(raw_finish, "stop") if raw_finish else None
                        )
                        chunk = OpenAIChatCompletionChunk(
                            id=chunk_id,
                            object="chat.completion.chunk",
                            created=int(time.time()),
                            model=model,
                            choices=[
                                {
                                    "index": 0,
                                    "delta": {"content": text},
                                    "finish_reason": finish_reason,
                                }
                            ],
                        )
                        yield chunk
            except httpx.TimeoutException as e:
                raise RetryableError(str(e), error_type="timeout") from e
            except httpx.ConnectError as e:
                raise RetryableError(str(e), error_type="connection_error") from e

    async def messages(
        self,
        request: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse:
        """Execute an Anthropic Messages API request via Gemini."""
        openai_request = self.transformer.anthropic_to_openai(request)
        openai_response = await self.chat_completions(openai_request)
        return self.transformer.openai_response_to_anthropic(openai_response)

    async def check_health(self) -> HealthStatus:
        """Check Gemini API health."""
        url = f"{self.config.api_base}/v1beta/models"

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.get(url, headers=self.get_headers())
                if response.status_code == 200:
                    return HealthStatus.HEALTHY
                elif response.status_code >= 500:
                    return HealthStatus.UNHEALTHY
                else:
                    return HealthStatus.DEGRADED
            except Exception:
                return HealthStatus.UNHEALTHY


@ -0,0 +1,134 @@
"""OpenAI provider adapter."""

import json
from typing import Any, AsyncIterator

import httpx

from app.adapters.base import BaseAdapter, HealthStatus, ProviderConfig
from app.core.fallback import RetryableError, PermanentError, classify_error
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
from app.schemas.openai import (
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequest,
    OpenAIChatCompletionResponse,
)
from app.core.transformer import RequestTransformer


class OpenAIAdapter(BaseAdapter):
    """Adapter for OpenAI API."""

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self.transformer = RequestTransformer()

    def get_headers(self) -> dict[str, str]:
        """Get headers for OpenAI API requests."""
        return {
            **super().get_headers(),
            "Authorization": f"Bearer {self.config.api_key}",
        }

    async def chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> OpenAIChatCompletionResponse:
        """Execute a chat completion request to OpenAI."""
        url = f"{self.config.api_base}/chat/completions"
        payload = request.model_dump(exclude_none=True)

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                response = await client.post(
                    url,
                    headers=self.get_headers(),
                    json=payload,
                )
                if response.status_code >= 400:
                    error = classify_error(Exception(f"{response.status_code}: {response.text}"))
                    raise error
                data = response.json()
                return OpenAIChatCompletionResponse(**data)
            except httpx.TimeoutException as e:
                raise RetryableError(str(e), error_type="timeout") from e
            except httpx.ConnectError as e:
                raise RetryableError(str(e), error_type="connection_error") from e

    async def stream_chat_completions(
        self,
        request: OpenAIChatCompletionRequest,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Execute a streaming chat completion request to OpenAI."""
        url = f"{self.config.api_base}/chat/completions"
        payload = request.model_dump(exclude_none=True)
        payload["stream"] = True

        async with httpx.AsyncClient(timeout=120.0) as client:
            try:
                async with client.stream(
                    "POST",
                    url,
                    headers=self.get_headers(),
                    json=payload,
                ) as response:
                    if response.status_code >= 400:
                        # Decode the body so the message holds text, not a bytes repr
                        body = (await response.aread()).decode(errors="replace")
                        error = classify_error(
                            Exception(f"{response.status_code}: {body}")
                        )
                        raise error

                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data = line[6:]
                            if data == "[DONE]":
                                break
                            try:
                                chunk = json.loads(data)
                                yield OpenAIChatCompletionChunk(**chunk)
                            except json.JSONDecodeError:
                                continue
            except httpx.TimeoutException as e:
                raise RetryableError(str(e), error_type="timeout") from e
            except httpx.ConnectError as e:
                raise RetryableError(str(e), error_type="connection_error") from e

    async def messages(
        self,
        request: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse:
        """
        Execute an Anthropic Messages API request via OpenAI.

        Converts request to OpenAI format, calls OpenAI, converts response back.
        """
        # Convert Anthropic request to OpenAI format
        openai_request = self.transformer.anthropic_to_openai(request)
        # Call OpenAI
        openai_response = await self.chat_completions(openai_request)
        # Convert response back to Anthropic format
        return self.transformer.openai_response_to_anthropic(openai_response)

    async def check_health(self) -> HealthStatus:
        """Check OpenAI API health by listing models."""
        url = f"{self.config.api_base}/models"

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.get(url, headers=self.get_headers())
                if response.status_code == 200:
                    return HealthStatus.HEALTHY
                elif response.status_code >= 500:
                    return HealthStatus.UNHEALTHY
                else:
                    return HealthStatus.DEGRADED
            except Exception:
                return HealthStatus.UNHEALTHY
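
The streaming method above reads the response as server-sent events: it keeps only `data: ` lines, stops at the `[DONE]` sentinel, and skips payloads that are not valid JSON. That parsing step can be tried in isolation with the sketch below. `parse_sse_chunks` is a hypothetical helper working on plain strings, and the sample lines are made up for illustration.

```python
import json
from typing import Iterable, Iterator


def parse_sse_chunks(lines: Iterable[str]) -> Iterator[dict]:
    """Yield decoded JSON payloads from OpenAI-style SSE lines."""
    prefix = "data: "
    for line in lines:
        if not line.startswith(prefix):
            continue  # blank separators, comments, other SSE fields
        data = line[len(prefix):]
        if data == "[DONE]":
            break  # end-of-stream sentinel
        try:
            yield json.loads(data)
        except json.JSONDecodeError:
            continue  # ignore malformed payloads


sample = [
    'data: {"choices": [{"index": 0, "delta": {"content": "Hel"}}]}',
    "",
    'data: {"choices": [{"index": 0, "delta": {"content": "lo"}}]}',
    "data: not-json",
    "data: [DONE]",
    'data: {"choices": [{"index": 0, "delta": {"content": "ignored"}}]}',
]
text = "".join(
    chunk["choices"][0]["delta"]["content"] for chunk in parse_sse_chunks(sample)
)
```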


@ -0,0 +1 @@
# api module


@ -0,0 +1 @@
# admin module

View File

@ -0,0 +1,68 @@
"""Health check API endpoints."""
from datetime import datetime
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import get_settings
from app.db.database import get_db
from app.models.provider import Provider
router = APIRouter(tags=["Health"])
@router.get("/health")
async def health_check() -> dict:
"""Basic health check endpoint."""
settings = get_settings()
return {
"status": "healthy",
"app": settings.app_name,
"timestamp": datetime.utcnow().isoformat(),
}
@router.get("/admin/providers/{provider_id}/health")
async def provider_health_check(
provider_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> dict:
"""Check health status of a specific provider."""
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
return {
"status": "not_found",
"provider_id": provider_id,
}
return {
"status": provider.health_status,
"provider_id": provider_id,
"provider_name": provider.name,
"last_check": provider.last_health_check.isoformat() if provider.last_health_check else None,
"enabled": provider.enabled,
}
@router.get("/ready")
async def readiness_check(
db: Annotated[AsyncSession, Depends(get_db)],
) -> dict:
"""Readiness check - verify database connection."""
try:
# Simple query to verify database connection
await db.execute(select(1))
return {
"status": "ready",
"database": "connected",
}
except Exception as e:
return {
"status": "not_ready",
"database": "disconnected",
"error": str(e),
}

View File

@ -0,0 +1,233 @@
"""API Key management API endpoints."""
import json
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.models.api_key import APIKey
from app.models.project import Project
from app.schemas.api_key import (
APIKeyCreate,
APIKeyCreateResponse,
APIKeyListResponse,
APIKeyResponse,
APIKeyUpdate,
)
from app.utils.crypto import generate_api_key
router = APIRouter(prefix="/keys", tags=["Admin - Keys"])
@router.post("", response_model=APIKeyCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_api_key(
data: APIKeyCreate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> dict:
"""Create a new API key."""
# Verify project exists if specified
if data.project_id:
result = await db.execute(select(Project).where(Project.id == data.project_id))
if not result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project '{data.project_id}' not found",
)
# Generate API key
full_key, key_hash, key_prefix = generate_api_key(data.prefix)
# Create API key record
api_key = APIKey(
key_hash=key_hash,
key_prefix=key_prefix,
name=data.name,
project_id=data.project_id,
enabled=data.enabled,
expires_at=data.expires_at,
rpm_limit=data.rpm_limit,
tpm_limit=data.tpm_limit,
budget_limit=data.budget_limit,
budget_period=data.budget_period,
allowed_models=json.dumps(data.allowed_models) if data.allowed_models else None,
)
db.add(api_key)
await db.flush()
await db.refresh(api_key)
# Return response with full key (only time it's shown)
return {
"id": api_key.id,
"key": full_key, # Full key - shown only once!
"key_prefix": api_key.key_prefix,
"name": api_key.name,
"project_id": api_key.project_id,
"enabled": api_key.enabled,
"expires_at": api_key.expires_at,
"rpm_limit": api_key.rpm_limit,
"tpm_limit": api_key.tpm_limit,
"budget_limit": api_key.budget_limit,
"budget_period": api_key.budget_period,
"allowed_models": json.loads(api_key.allowed_models) if api_key.allowed_models else None,
"current_usage": api_key.current_usage,
"total_requests": api_key.total_requests,
"created_at": api_key.created_at,
"updated_at": api_key.updated_at,
}
@router.get("", response_model=APIKeyListResponse)
async def list_api_keys(
db: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(ge=1)] = 1,
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
project_id: str | None = None,
enabled: bool | None = None,
) -> APIKeyListResponse:
"""List all API keys."""
query = select(APIKey)
count_query = select(func.count()).select_from(APIKey)
if project_id:
query = query.where(APIKey.project_id == project_id)
count_query = count_query.where(APIKey.project_id == project_id)
if enabled is not None:
query = query.where(APIKey.enabled == enabled)
count_query = count_query.where(APIKey.enabled == enabled)
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size).order_by(APIKey.created_at.desc())
result = await db.execute(query)
keys = result.scalars().all()
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return APIKeyListResponse(
keys=[
APIKeyResponse(
id=k.id,
key_prefix=k.key_prefix,
name=k.name,
project_id=k.project_id,
enabled=k.enabled,
expires_at=k.expires_at,
rpm_limit=k.rpm_limit,
tpm_limit=k.tpm_limit,
budget_limit=k.budget_limit,
budget_period=k.budget_period,
allowed_models=json.loads(k.allowed_models) if k.allowed_models else None,
current_usage=k.current_usage,
total_requests=k.total_requests,
created_at=k.created_at,
updated_at=k.updated_at,
)
for k in keys
],
total=total,
page=page,
page_size=page_size,
)
@router.get("/{key_id}", response_model=APIKeyResponse)
async def get_api_key(
key_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> APIKeyResponse:
"""Get an API key by ID."""
result = await db.execute(select(APIKey).where(APIKey.id == key_id))
key = result.scalar_one_or_none()
if not key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"API key '{key_id}' not found",
)
return APIKeyResponse(
id=key.id,
key_prefix=key.key_prefix,
name=key.name,
project_id=key.project_id,
enabled=key.enabled,
expires_at=key.expires_at,
rpm_limit=key.rpm_limit,
tpm_limit=key.tpm_limit,
budget_limit=key.budget_limit,
budget_period=key.budget_period,
allowed_models=json.loads(key.allowed_models) if key.allowed_models else None,
current_usage=key.current_usage,
total_requests=key.total_requests,
created_at=key.created_at,
updated_at=key.updated_at,
)
@router.put("/{key_id}", response_model=APIKeyResponse)
async def update_api_key(
key_id: str,
data: APIKeyUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> APIKeyResponse:
"""Update an API key."""
result = await db.execute(select(APIKey).where(APIKey.id == key_id))
key = result.scalar_one_or_none()
if not key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"API key '{key_id}' not found",
)
update_data = data.model_dump(exclude_unset=True)
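if "allowed_models" in update_data:
# Store NULL, not the JSON string "null", when the restriction is cleared;
# this matches how create_api_key serializes the field
allowed_models = update_data.pop("allowed_models")
key.allowed_models = json.dumps(allowed_models) if allowed_models else None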
for k, v in update_data.items():
setattr(key, k, v)
await db.flush()
await db.refresh(key)
return APIKeyResponse(
id=key.id,
key_prefix=key.key_prefix,
name=key.name,
project_id=key.project_id,
enabled=key.enabled,
expires_at=key.expires_at,
rpm_limit=key.rpm_limit,
tpm_limit=key.tpm_limit,
budget_limit=key.budget_limit,
budget_period=key.budget_period,
allowed_models=json.loads(key.allowed_models) if key.allowed_models else None,
current_usage=key.current_usage,
total_requests=key.total_requests,
created_at=key.created_at,
updated_at=key.updated_at,
)
@router.delete("/{key_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_api_key(
key_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
"""Delete an API key."""
result = await db.execute(select(APIKey).where(APIKey.id == key_id))
key = result.scalar_one_or_none()
if not key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"API key '{key_id}' not found",
)
await db.delete(key)
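# Illustrative sketch, not part of the module: the allowed_models column is a
# JSON-encoded list stored as text, and the endpoints above repeat the same
# dumps/loads expressions. The two helper names below are hypothetical; they
# only show the round trip, including the empty and NULL cases.
import json
from typing import List, Optional


def dump_allowed_models(models: Optional[List[str]]) -> Optional[str]:
    """Serialize the list for storage; an empty or missing list becomes NULL."""
    return json.dumps(models) if models else None


def load_allowed_models(raw: Optional[str]) -> Optional[List[str]]:
    """Parse the stored text; NULL, an empty string, or "null" mean no restriction."""
    if not raw:
        return None
    return json.loads(raw)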

View File

@ -0,0 +1,169 @@
"""Model alias management API endpoints."""
import json
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.models.model_alias import ModelAlias
from app.schemas.model_alias import (
ModelAliasCreate,
ModelAliasListResponse,
ModelAliasResponse,
ModelAliasUpdate,
)
router = APIRouter(prefix="/models/aliases", tags=["Admin - Models"])
@router.post("", response_model=ModelAliasResponse, status_code=status.HTTP_201_CREATED)
async def create_model_alias(
data: ModelAliasCreate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> ModelAlias:
"""Create a new model alias."""
# Check if alias already exists
result = await db.execute(select(ModelAlias).where(ModelAlias.alias == data.alias))
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Model alias '{data.alias}' already exists",
)
# Create model alias
alias = ModelAlias(
alias=data.alias,
provider=data.provider,
model=data.model,
enabled=data.enabled,
routing_type=data.routing_type,
routing_config=json.dumps(data.routing_config.model_dump()) if data.routing_config else None,
input_price_per_1k=data.input_price_per_1k,
output_price_per_1k=data.output_price_per_1k,
)
db.add(alias)
await db.flush()
await db.refresh(alias)
return alias
@router.get("", response_model=ModelAliasListResponse)
async def list_model_aliases(
db: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(ge=1)] = 1,
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
provider: str | None = None,
enabled: bool | None = None,
) -> ModelAliasListResponse:
"""List all model aliases."""
query = select(ModelAlias)
count_query = select(func.count()).select_from(ModelAlias)
if provider:
query = query.where(ModelAlias.provider == provider)
count_query = count_query.where(ModelAlias.provider == provider)
if enabled is not None:
query = query.where(ModelAlias.enabled == enabled)
count_query = count_query.where(ModelAlias.enabled == enabled)
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size).order_by(ModelAlias.created_at.desc())
result = await db.execute(query)
aliases = result.scalars().all()
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return ModelAliasListResponse(
aliases=[
ModelAliasResponse(
id=a.id,
alias=a.alias,
provider=a.provider,
model=a.model,
enabled=a.enabled,
routing_type=a.routing_type,
routing_config=json.loads(a.routing_config) if a.routing_config else None,
input_price_per_1k=a.input_price_per_1k,
output_price_per_1k=a.output_price_per_1k,
created_at=a.created_at,
updated_at=a.updated_at,
)
for a in aliases
],
total=total,
page=page,
page_size=page_size,
)
@router.get("/{alias_id}", response_model=ModelAliasResponse)
async def get_model_alias(
alias_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> ModelAlias:
"""Get a model alias by ID."""
result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
alias = result.scalar_one_or_none()
if not alias:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model alias '{alias_id}' not found",
)
return alias
@router.put("/{alias_id}", response_model=ModelAliasResponse)
async def update_model_alias(
alias_id: str,
data: ModelAliasUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> ModelAlias:
"""Update a model alias."""
result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
alias = result.scalar_one_or_none()
if not alias:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model alias '{alias_id}' not found",
)
update_data = data.model_dump(exclude_unset=True)
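if "routing_config" in update_data:
# Always serialize here so a dict never reaches the text column via setattr
routing_config = update_data.pop("routing_config")
alias.routing_config = json.dumps(routing_config) if routing_config else None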
for key, value in update_data.items():
setattr(alias, key, value)
await db.flush()
await db.refresh(alias)
return alias
@router.delete("/{alias_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_model_alias(
alias_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
"""Delete a model alias."""
result = await db.execute(select(ModelAlias).where(ModelAlias.id == alias_id))
alias = result.scalar_one_or_none()
if not alias:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model alias '{alias_id}' not found",
)
await db.delete(alias)

View File

@ -0,0 +1,132 @@
"""Project management API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.models.project import Project
from app.schemas.project import (
ProjectCreate,
ProjectListResponse,
ProjectResponse,
ProjectUpdate,
)
router = APIRouter(prefix="/projects", tags=["Admin - Projects"])
@router.post("", response_model=ProjectResponse, status_code=status.HTTP_201_CREATED)
async def create_project(
data: ProjectCreate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> Project:
"""Create a new project."""
project = Project(
name=data.name,
description=data.description,
budget_limit=data.budget_limit,
budget_period=data.budget_period,
enabled=data.enabled,
)
db.add(project)
await db.flush()
await db.refresh(project)
return project
@router.get("", response_model=ProjectListResponse)
async def list_projects(
db: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(ge=1)] = 1,
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
enabled: bool | None = None,
) -> ProjectListResponse:
"""List all projects."""
query = select(Project)
count_query = select(func.count()).select_from(Project)
if enabled is not None:
query = query.where(Project.enabled == enabled)
count_query = count_query.where(Project.enabled == enabled)
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size).order_by(Project.created_at.desc())
result = await db.execute(query)
projects = result.scalars().all()
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return ProjectListResponse(
projects=[ProjectResponse.model_validate(p) for p in projects],
total=total,
page=page,
page_size=page_size,
)
@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(
project_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> Project:
"""Get a project by ID."""
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project '{project_id}' not found",
)
return project
@router.put("/{project_id}", response_model=ProjectResponse)
async def update_project(
project_id: str,
data: ProjectUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> Project:
"""Update a project."""
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project '{project_id}' not found",
)
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(project, key, value)
await db.flush()
await db.refresh(project)
return project
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_project(
project_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
"""Delete a project."""
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project '{project_id}' not found",
)
await db.delete(project)

View File

@ -0,0 +1,192 @@
"""Provider management API endpoints."""
import json
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.models.provider import Provider
from app.schemas.common import PaginationParams
from app.schemas.provider import (
ProviderCreate,
ProviderListResponse,
ProviderResponse,
ProviderUpdate,
)
from app.utils.crypto import encrypt_value
router = APIRouter(prefix="/providers", tags=["Admin - Providers"])
@router.post("", response_model=ProviderResponse, status_code=status.HTTP_201_CREATED)
async def create_provider(
data: ProviderCreate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> Provider:
"""Create a new provider."""
# Check if provider already exists
result = await db.execute(select(Provider).where(Provider.name == data.name))
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Provider '{data.name}' already exists",
)
# Encrypt API key
encrypted_key = encrypt_value(data.api_key)
# Create provider
provider = Provider(
name=data.name,
api_base=data.api_base,
api_key_encrypted=encrypted_key,
api_version=data.api_version,
config=json.dumps(data.config) if data.config else None,
rpm_limit=data.rpm_limit,
tpm_limit=data.tpm_limit,
enabled=data.enabled,
)
db.add(provider)
await db.flush()
await db.refresh(provider)
# Parse the stored JSON config for the response (the config is not encrypted)
if provider.config:
provider.config = json.loads(provider.config)
return provider
@router.get("", response_model=ProviderListResponse)
async def list_providers(
db: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(ge=1)] = 1,
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
enabled: bool | None = None,
) -> ProviderListResponse:
"""List all providers."""
# Build query
query = select(Provider)
count_query = select(func.count()).select_from(Provider)
if enabled is not None:
query = query.where(Provider.enabled == enabled)
count_query = count_query.where(Provider.enabled == enabled)
# Apply pagination
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size).order_by(Provider.created_at.desc())
# Execute queries
result = await db.execute(query)
providers = result.scalars().all()
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# The JSON config is parsed once per provider while building the response below
return ProviderListResponse(
providers=[
ProviderResponse(
id=p.id,
name=p.name,
enabled=p.enabled,
api_base=p.api_base,
api_version=p.api_version,
config=json.loads(p.config) if p.config else None,
rpm_limit=p.rpm_limit,
tpm_limit=p.tpm_limit,
health_status=p.health_status,
last_health_check=p.last_health_check,
created_at=p.created_at,
updated_at=p.updated_at,
)
for p in providers
],
total=total,
page=page,
page_size=page_size,
)
@router.get("/{provider_id}", response_model=ProviderResponse)
async def get_provider(
provider_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> Provider:
"""Get a provider by ID."""
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_id}' not found",
)
if provider.config:
provider.config = json.loads(provider.config)
return provider
@router.put("/{provider_id}", response_model=ProviderResponse)
async def update_provider(
provider_id: str,
data: ProviderUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
) -> Provider:
"""Update a provider."""
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_id}' not found",
)
# Update fields
update_data = data.model_dump(exclude_unset=True)
if "api_key" in update_data:
provider.api_key_encrypted = encrypt_value(update_data.pop("api_key"))
if "config" in update_data:
provider.config = json.dumps(update_data.pop("config"))
for key, value in update_data.items():
setattr(provider, key, value)
await db.flush()
await db.refresh(provider)
if provider.config:
provider.config = json.loads(provider.config)
return provider
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider(
provider_id: str,
db: Annotated[AsyncSession, Depends(get_db)],
) -> None:
"""Delete a provider."""
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_id}' not found",
)
await db.delete(provider)

View File

@ -0,0 +1,221 @@
"""Usage statistics API endpoints."""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.models.usage import RequestLog, UsageStatsHourly
router = APIRouter(prefix="/usage", tags=["Admin - Usage"])
@router.get("/stats")
async def get_usage_stats(
db: Annotated[AsyncSession, Depends(get_db)],
start_date: Annotated[datetime | None, Query()] = None,
end_date: Annotated[datetime | None, Query()] = None,
group_by: Annotated[str, Query(pattern="^(model|provider|project|key)$")] = "model",
project_id: str | None = None,
key_id: str | None = None,
) -> dict[str, Any]:
"""Get usage statistics."""
# Default to last 7 days
if not start_date:
start_date = datetime.utcnow() - timedelta(days=7)
if not end_date:
end_date = datetime.utcnow()
# Build base query. The "project" and "key" groupings use the same
# project_id and virtual_key_id columns as the filters below.
group_column = getattr(
UsageStatsHourly,
{"project": "project_id", "key": "virtual_key_id"}.get(group_by, group_by),
)
query = select(
group_column.label("group"),
func.sum(UsageStatsHourly.request_count).label("requests"),
func.sum(UsageStatsHourly.input_tokens).label("input_tokens"),
func.sum(UsageStatsHourly.output_tokens).label("output_tokens"),
func.sum(UsageStatsHourly.total_tokens).label("total_tokens"),
func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
func.avg(UsageStatsHourly.avg_latency_ms).label("avg_latency_ms"),
func.sum(UsageStatsHourly.error_count).label("errors"),
).where(
UsageStatsHourly.timestamp >= start_date,
UsageStatsHourly.timestamp <= end_date,
)
if project_id:
query = query.where(UsageStatsHourly.project_id == project_id)
if key_id:
query = query.where(UsageStatsHourly.virtual_key_id == key_id)
query = query.group_by(group_column)
result = await db.execute(query)
rows = result.all()
# Calculate totals
total_query = select(
func.sum(UsageStatsHourly.request_count).label("requests"),
func.sum(UsageStatsHourly.input_tokens).label("input_tokens"),
func.sum(UsageStatsHourly.output_tokens).label("output_tokens"),
func.sum(UsageStatsHourly.total_tokens).label("total_tokens"),
func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
func.sum(UsageStatsHourly.error_count).label("errors"),
).where(
UsageStatsHourly.timestamp >= start_date,
UsageStatsHourly.timestamp <= end_date,
)
if project_id:
total_query = total_query.where(UsageStatsHourly.project_id == project_id)
if key_id:
total_query = total_query.where(UsageStatsHourly.virtual_key_id == key_id)
total_result = await db.execute(total_query)
total_row = total_result.one()
return {
"period": {
"start": start_date.isoformat(),
"end": end_date.isoformat(),
},
"totals": {
"requests": total_row.requests or 0,
"input_tokens": total_row.input_tokens or 0,
"output_tokens": total_row.output_tokens or 0,
"total_tokens": total_row.total_tokens or 0,
"cost_usd": float(total_row.cost_usd or 0),
"errors": total_row.errors or 0,
},
"by_" + group_by: [
{
group_by: row.group,
"requests": row.requests or 0,
"input_tokens": row.input_tokens or 0,
"output_tokens": row.output_tokens or 0,
"total_tokens": row.total_tokens or 0,
"cost_usd": float(row.cost_usd or 0),
"avg_latency_ms": int(row.avg_latency_ms or 0),
"errors": row.errors or 0,
}
for row in rows
],
}
@router.get("/logs")
async def get_usage_logs(
db: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(ge=1)] = 1,
page_size: Annotated[int, Query(ge=1, le=100)] = 50,
project_id: str | None = None,
key_id: str | None = None,
provider: str | None = None,
model: str | None = None,
status_code: int | None = None,
) -> dict[str, Any]:
"""Get request logs."""
query = select(RequestLog)
count_query = select(func.count()).select_from(RequestLog)
# Apply filters
if project_id:
query = query.where(RequestLog.project_id == project_id)
count_query = count_query.where(RequestLog.project_id == project_id)
if key_id:
query = query.where(RequestLog.virtual_key_id == key_id)
count_query = count_query.where(RequestLog.virtual_key_id == key_id)
if provider:
query = query.where(RequestLog.provider == provider)
count_query = count_query.where(RequestLog.provider == provider)
if model:
query = query.where(RequestLog.model == model)
count_query = count_query.where(RequestLog.model == model)
if status_code:
query = query.where(RequestLog.status_code == status_code)
count_query = count_query.where(RequestLog.status_code == status_code)
# Apply pagination
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size).order_by(RequestLog.timestamp.desc())
result = await db.execute(query)
logs = result.scalars().all()
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return {
"logs": [
{
"id": log.id,
"timestamp": log.timestamp.isoformat(),
"key_id": log.virtual_key_id,
"project_id": log.project_id,
"provider": log.provider,
"model": log.model,
"model_alias": log.model_alias,
"request_type": log.request_type,
"input_tokens": log.input_tokens,
"output_tokens": log.output_tokens,
"total_tokens": log.total_tokens,
"status_code": log.status_code,
"latency_ms": log.latency_ms,
"finish_reason": log.finish_reason,
"cost_usd": float(log.cost_usd or 0),
}
for log in logs
],
"total": total,
"page": page,
"page_size": page_size,
}
@router.get("/costs")
async def get_cost_breakdown(
db: Annotated[AsyncSession, Depends(get_db)],
start_date: Annotated[datetime | None, Query()] = None,
end_date: Annotated[datetime | None, Query()] = None,
group_by: Annotated[str, Query(pattern="^(model|provider|project|key)$")] = "provider",
) -> dict[str, Any]:
"""Get cost breakdown."""
# Default to current month
if not start_date:
start_date = datetime.utcnow().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
if not end_date:
end_date = datetime.utcnow()
# The "project" and "key" groupings map to the project_id and virtual_key_id columns
group_column = getattr(
UsageStatsHourly,
{"project": "project_id", "key": "virtual_key_id"}.get(group_by, group_by),
)
query = select(
group_column.label("group"),
func.sum(UsageStatsHourly.cost_usd).label("cost_usd"),
func.sum(UsageStatsHourly.request_count).label("requests"),
func.sum(UsageStatsHourly.total_tokens).label("tokens"),
).where(
UsageStatsHourly.timestamp >= start_date,
UsageStatsHourly.timestamp <= end_date,
).group_by(group_column)
result = await db.execute(query)
rows = result.all()
total_cost = sum(float(row.cost_usd or 0) for row in rows)
return {
"period": {
"start": start_date.isoformat(),
"end": end_date.isoformat(),
},
"total_cost_usd": total_cost,
"breakdown": [
{
group_by: row.group,
"cost_usd": float(row.cost_usd or 0),
"percentage": float(row.cost_usd or 0) / total_cost * 100 if total_cost > 0 else 0,
"requests": row.requests or 0,
"tokens": row.tokens or 0,
}
for row in sorted(rows, key=lambda r: r.cost_usd or 0, reverse=True)
],
}
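# Illustrative sketch, not part of the module: the percentage calculation from
# get_cost_breakdown above as a pure function over already-aggregated costs.
# The name `cost_breakdown` and the dict input are assumptions for the example;
# the endpoint itself works on SQLAlchemy result rows.
from typing import Any, Dict, List


def cost_breakdown(costs: Dict[str, float]) -> List[Dict[str, Any]]:
    """Return per-group cost and share of the total, most expensive first."""
    total = sum(costs.values())
    ordered = sorted(costs.items(), key=lambda item: item[1], reverse=True)
    return [
        {
            "group": name,
            "cost_usd": cost,
            # Guard against division by zero when nothing was spent
            "percentage": cost / total * 100 if total > 0 else 0.0,
        }
        for name, cost in ordered
    ]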

View File

@ -0,0 +1 @@
# v1 module

View File

@ -0,0 +1,327 @@
"""Chat Completions API endpoint (OpenAI-compatible)."""
import time
import uuid
from datetime import datetime
from decimal import Decimal
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.adapters import ProviderConfig, create_adapter
from app.core.budget import BudgetController, BudgetExceeded
from app.core.circuit_breaker import get_circuit_breaker
from app.core.fallback import RetryExecutor, RetryConfig, RetryableError, classify_error
from app.core.load_balancer import LoadBalancer
from app.core.rate_limiter import RateLimiter, RateLimitExceeded
from app.core.router import Router, RoutingResult
from app.db.database import get_db
from app.models.api_key import APIKey
from app.models.provider import Provider
from app.models.usage import RequestLog
from app.schemas.openai import (
OpenAIChatCompletionRequest,
OpenAIChatCompletionResponse,
)
from app.utils.crypto import decrypt_value, verify_api_key
from app.utils.logging import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/v1", tags=["Chat"])
async def authenticate(
authorization: str | None = Header(None),
x_api_key: str | None = Header(None),
db: AsyncSession = Depends(get_db),
) -> APIKey:
"""Authenticate request using virtual API key."""
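# Extract key from header: prefer a Bearer token and fall back to X-API-Key,
# including when the Authorization header uses another scheme
key = None
if authorization and authorization.startswith("Bearer "):
key = authorization[7:]
elif x_api_key:
key = x_api_key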
if not key:
raise HTTPException(
status_code=401,
detail={"error": {"type": "authentication_error", "message": "Missing API key"}},
)
# Find and verify key
result = await db.execute(select(APIKey))
api_keys = result.scalars().all()
for api_key in api_keys:
if verify_api_key(key, api_key.key_hash):
if not api_key.enabled:
raise HTTPException(
status_code=403,
detail={"error": {"type": "permission_error", "message": "API key is disabled"}},
)
if api_key.expires_at and api_key.expires_at < datetime.utcnow():
raise HTTPException(
status_code=403,
detail={"error": {"type": "permission_error", "message": "API key has expired"}},
)
return api_key
raise HTTPException(
status_code=401,
detail={"error": {"type": "authentication_error", "message": "Invalid API key"}},
)
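# Illustrative sketch, not part of the module: the header handling at the top
# of authenticate() as a pure function, which makes the precedence between the
# two headers easy to test. The name `extract_api_key` is hypothetical, and
# falling back to X-API-Key when Authorization is not a Bearer token is an
# assumption about the intended behaviour.
from typing import Optional


def extract_api_key(
    authorization: Optional[str],
    x_api_key: Optional[str],
) -> Optional[str]:
    """Return the presented key, or None when neither header carries one."""
    bearer = "Bearer "
    if authorization and authorization.startswith(bearer):
        return authorization[len(bearer):]
    return x_api_key or None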
@router.post("/chat/completions")
async def chat_completions(
request: OpenAIChatCompletionRequest,
db: Annotated[AsyncSession, Depends(get_db)],
api_key: Annotated[APIKey, Depends(authenticate)],
) -> OpenAIChatCompletionResponse:
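"""Execute a chat completion request."""
# Bind json for the whole function. The conditional import further down makes
# json a local name, which would otherwise be unbound at the provider config
# step whenever the key has no allowed_models.
import json
start_time = time.time()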
routing_result: RoutingResult | None = None
response: OpenAIChatCompletionResponse | None = None
error: Exception | None = None
try:
# Initialize services
router_service = Router(db)
rate_limiter = RateLimiter(db)
budget_controller = BudgetController(db)
load_balancer = LoadBalancer(db)
# Check rate limits
try:
await rate_limiter.check_and_increment(
f"key:{api_key.id}",
rpm_limit=api_key.rpm_limit,
tpm_limit=api_key.tpm_limit,
)
except RateLimitExceeded as e:
raise HTTPException(
status_code=429,
detail={
"error": {
"type": "rate_limit_error",
"message": str(e),
"details": {
"limit": e.limit,
"remaining": e.remaining,
"reset_at": e.reset_at.isoformat(),
},
}
},
)
# Check budget
try:
await budget_controller.check_budget(api_key.id)
except BudgetExceeded as e:
raise HTTPException(
status_code=402,
detail={
"error": {
"type": "budget_exceeded_error",
"message": str(e),
"details": {
"limit": float(e.limit),
"current_usage": float(e.current_usage),
"scope": e.scope,
},
}
},
)
# Resolve model routing
routing_result = await router_service.resolve_model(request.model)
# Check if model is allowed
if api_key.allowed_models:
import json
allowed = json.loads(api_key.allowed_models)
if request.model not in allowed:
raise HTTPException(
status_code=403,
detail={
"error": {
"type": "permission_error",
"message": f"Model '{request.model}' is not allowed for this API key",
}
},
)
# Get provider config
result = await db.execute(
select(Provider).where(Provider.name == routing_result.provider, Provider.enabled.is_(True))
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(
status_code=503,
detail={
"error": {
"type": "provider_error",
"message": f"Provider '{routing_result.provider}' not found or disabled",
}
},
)
# Check circuit breaker
circuit_breaker = get_circuit_breaker()
if not circuit_breaker.is_available(provider.name):
raise HTTPException(
status_code=503,
detail={
"error": {
"type": "service_unavailable",
"message": f"Provider '{provider.name}' is currently unavailable",
}
},
)
# Create adapter
provider_config = ProviderConfig(
name=provider.name,
api_base=provider.api_base,
api_key=decrypt_value(provider.api_key_encrypted),
api_version=provider.api_version,
config=json.loads(provider.config) if provider.config else None,
)
adapter = create_adapter(provider_config)
# Update request model to actual model
request.model = routing_result.model
# Execute request with retry
retry_executor = RetryExecutor(RetryConfig())
async def execute():
return await adapter.chat_completions(request)
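response = await retry_executor.execute(execute, provider.name, "chat_completions")
# RetryExecutor.execute already records the success with the circuit breaker;
# recording it again here would double-count successes in the half-open state.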
# Record usage
if response.usage:
cost = await _calculate_cost(db, routing_result, response.usage)
await budget_controller.record_usage(
api_key.id,
cost,
response.usage.prompt_tokens,
response.usage.completion_tokens,
)
# Record tokens for rate limiting
await rate_limiter.record_tokens(
f"key:{api_key.id}",
response.usage.total_tokens,
)
return response
except HTTPException:
raise
except RetryableError as e:
error = e
logger.error(
"Chat completion failed",
model=request.model,
provider=routing_result.provider if routing_result else None,
error=str(e),
)
raise HTTPException(
status_code=502,
detail={
"error": {
"type": "provider_error",
"message": str(e),
}
},
)
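except Exception as e:
error = e
# The full traceback is logged here; the client gets a generic message so
# internal details are not leaked in the response body
logger.exception("Unexpected error in chat completions")
raise HTTPException(
status_code=500,
detail={
"error": {
"type": "internal_error",
"message": "Internal server error",
}
},
)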
finally:
# Log request
await _log_request(
db,
api_key.id,
routing_result,
request,
response,
error,
start_time,
)
async def _calculate_cost(
db: AsyncSession,
routing_result: RoutingResult,
usage: Any,
) -> Decimal:
"""Calculate cost for a request."""
from app.models.model_alias import ModelAlias
# Try to find pricing from model alias
result = await db.execute(
select(ModelAlias).where(ModelAlias.alias == routing_result.model)
)
alias = result.scalar_one_or_none()
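if alias and alias.input_price_per_1k and alias.output_price_per_1k:
# Do the arithmetic in Decimal so it works whether the prices are stored as
# Decimal or float, and so the result carries no float rounding error
thousand = Decimal(1000)
input_cost = Decimal(usage.prompt_tokens) / thousand * Decimal(str(alias.input_price_per_1k))
output_cost = Decimal(usage.completion_tokens) / thousand * Decimal(str(alias.output_price_per_1k))
return input_cost + output_cost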
# Default pricing (rough estimate)
return Decimal("0.001")
async def _log_request(
db: AsyncSession,
key_id: str,
routing_result: RoutingResult | None,
request: OpenAIChatCompletionRequest,
response: OpenAIChatCompletionResponse | None,
error: Exception | None,
start_time: float,
) -> None:
"""Log request details."""
latency_ms = int((time.time() - start_time) * 1000)
log = RequestLog(
virtual_key_id=key_id,
project_id=None,  # Not populated yet, so project-scoped usage queries will not match this row
provider=routing_result.provider if routing_result else "unknown",
model=request.model,
model_alias=request.model,
request_type="chat",
input_tokens=response.usage.prompt_tokens if response and response.usage else 0,
output_tokens=response.usage.completion_tokens if response and response.usage else 0,
total_tokens=response.usage.total_tokens if response and response.usage else 0,
status_code=200 if response and not error else 500,
latency_ms=latency_ms,
finish_reason=response.choices[0].finish_reason if response and response.choices else None,
# Placeholder flat estimate; this is not the value computed by _calculate_cost
cost_usd=Decimal("0.001") if response and response.usage else Decimal("0"),
)
db.add(log)
await db.commit()
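
The per-1k pricing rule used by `_calculate_cost` can be tried out in isolation. The following is a minimal sketch, assuming prices are quoted in USD per 1,000 tokens; `estimate_cost` and its parameters are hypothetical names for illustration and are not part of the gateway.

```python
from decimal import Decimal


def estimate_cost(
    prompt_tokens: int,
    completion_tokens: int,
    input_price_per_1k: Decimal,
    output_price_per_1k: Decimal,
) -> Decimal:
    """Return the request cost in USD using exact decimal arithmetic."""
    thousand = Decimal(1000)
    # Token counts are converted to Decimal before dividing, so no float is involved
    input_cost = Decimal(prompt_tokens) / thousand * input_price_per_1k
    output_cost = Decimal(completion_tokens) / thousand * output_price_per_1k
    return input_cost + output_cost


# Hypothetical prices, chosen only to make the arithmetic easy to follow
cost = estimate_cost(1500, 500, Decimal("0.01"), Decimal("0.03"))
print(cost)
```

Keeping every operand a `Decimal` avoids mixing `float` and `Decimal`, which Python rejects with a `TypeError` for arithmetic operators.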

@ -0,0 +1,53 @@
"""Anthropic Messages API endpoint."""
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.v1.chat import authenticate
from app.core.transformer import RequestTransformer
from app.db.database import get_db
from app.models.api_key import APIKey
from app.schemas.anthropic import AnthropicMessagesRequest, AnthropicMessagesResponse
router = APIRouter(prefix="/v1", tags=["Messages"])
@router.post("/messages")
async def messages(
request: AnthropicMessagesRequest,
db: Annotated[AsyncSession, Depends(get_db)],
api_key: Annotated[APIKey, Depends(authenticate)],
) -> AnthropicMessagesResponse:
"""
Execute an Anthropic Messages API request.
This endpoint accepts Anthropic-format requests and forwards them
to the appropriate provider (Anthropic, OpenAI, etc.).
"""
from app.api.v1.chat import chat_completions
# Convert the Anthropic request to OpenAI format
transformer = RequestTransformer()
openai_request = transformer.anthropic_to_openai(request)
# Reuse the injected session rather than opening a second one via get_db()
response = await chat_completions(openai_request, db, api_key)
# Convert the response back to Anthropic format
return transformer.openai_response_to_anthropic(response)

@ -0,0 +1,72 @@
"""OpenAI Responses API endpoint."""
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.v1.chat import authenticate
from app.db.database import get_db
from app.models.api_key import APIKey
from app.schemas.openai import OpenAIResponseRequest
router = APIRouter(prefix="/v1", tags=["Responses"])
@router.post("/responses")
async def responses(
request: OpenAIResponseRequest,
db: Annotated[AsyncSession, Depends(get_db)],
api_key: Annotated[APIKey, Depends(authenticate)],
) -> dict:
"""
Execute an OpenAI Responses API request.
This is the new OpenAI Responses API format.
Converts to chat completions internally.
"""
from app.schemas.openai import OpenAIChatCompletionRequest, OpenAIChatMessage
from app.api.v1.chat import chat_completions
# Convert Responses API format to Chat Completions format
messages = []
# Add instructions as system message if present
if request.instructions:
messages.append(OpenAIChatMessage(role="system", content=request.instructions))
# Add input as user message
if isinstance(request.input, str):
messages.append(OpenAIChatMessage(role="user", content=request.input))
else:
# Structured input is flattened to its string form; content parts are not mapped individually
messages.append(OpenAIChatMessage(role="user", content=str(request.input)))
chat_request = OpenAIChatCompletionRequest(
model=request.model,
messages=messages,
max_tokens=request.max_output_tokens,
temperature=request.temperature,
tools=request.tools,
tool_choice=request.tool_choice,
)
# Reuse the injected session rather than opening a second one via get_db()
response = await chat_completions(chat_request, db, api_key)
# Convert to Responses API format
return {
"id": response.id,
"object": "response",
"created": response.created,
"model": response.model,
"output": response.choices[0].message.content if response.choices else "",
"usage": {
"input_tokens": response.usage.prompt_tokens if response.usage else 0,
"output_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
},
}
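
The request mapping performed by this endpoint (instructions become a system message, input becomes a user message) can be sketched with plain dictionaries. This is an illustration under stated assumptions: `to_chat_messages` is a hypothetical helper, and it uses dicts in place of the gateway's `OpenAIChatMessage` schema.

```python
from typing import Any, Optional


def to_chat_messages(instructions: Optional[str], user_input: Any) -> list[dict[str, str]]:
    """Map Responses-style fields onto a Chat Completions message list."""
    messages: list[dict[str, str]] = []
    if instructions:
        # Instructions become the system message
        messages.append({"role": "system", "content": instructions})
    # Non-string input is flattened with str(), mirroring the endpoint's behaviour
    content = user_input if isinstance(user_input, str) else str(user_input)
    messages.append({"role": "user", "content": content})
    return messages


messages = to_chat_messages("Answer briefly.", "What is a circuit breaker?")
print(messages)
```

Flattening structured input with `str()` loses the distinction between content parts, so a fuller mapping would probably translate each part separately.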

73
llm-gateway/app/config.py Normal file

@ -0,0 +1,73 @@
"""Application configuration using Pydantic Settings."""
import secrets
from functools import lru_cache
from pathlib import Path
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
# Application
app_name: str = "LLM Gateway"
debug: bool = False
log_level: str = "INFO"
# Server
host: str = "0.0.0.0"
port: int = 8000
# Database
database_url: str = Field(
default="sqlite:///data/gateway.db",
description="SQLite database URL",
)
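# Security
# If MASTER_KEY is not set, a random key is generated for each process, so values
# encrypted under an earlier key cannot be decrypted after a restart.
# Set MASTER_KEY explicitly outside local development.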
master_key: str = Field(
default_factory=lambda: secrets.token_hex(32),
description="Master key for encrypting provider API keys",
)
# Rate Limiting
rate_limit_window_seconds: int = 60
global_rpm_limit: int | None = None
global_tpm_limit: int | None = None
# Retry
max_retries: int = 3
retry_initial_delay: float = 1.0
retry_max_delay: float = 30.0
retry_exponential_base: float = 2.0
# Health Check
health_check_interval: int = 30
health_check_timeout: int = 10
@field_validator("master_key")
@classmethod
def validate_master_key(cls, v: str) -> str:
"""Ensure master key is at least 32 characters."""
if len(v) < 32:
raise ValueError("Master key must be at least 32 characters")
return v
@staticmethod
def generate_master_key() -> str:
"""Generate a secure random master key."""
return secrets.token_hex(32)
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
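
The master key rule above (at least 32 characters, generated with `secrets.token_hex(32)`) can be checked without Pydantic. The sketch below uses only the standard library; `MIN_KEY_LENGTH` and the free function `validate_master_key` are hypothetical stand-ins for the validator on `Settings`.

```python
import secrets

MIN_KEY_LENGTH = 32  # assumed to match the validator's threshold


def validate_master_key(value: str) -> str:
    """Reject keys shorter than the minimum length, otherwise return the key."""
    if len(value) < MIN_KEY_LENGTH:
        raise ValueError("Master key must be at least 32 characters")
    return value


# token_hex(32) draws 32 random bytes and encodes them as 64 hex characters
generated = secrets.token_hex(32)
validate_master_key(generated)

try:
    validate_master_key("too-short")
    rejected = False
except ValueError:
    rejected = True
```

A key generated this way should be stored in the environment once and reused, since a fresh key on each start would not match data encrypted earlier.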

@ -0,0 +1 @@
# core module

@ -0,0 +1,160 @@
"""Budget controller for spending limits."""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.api_key import APIKey
from app.models.project import Project
from app.models.usage import RequestLog
class BudgetExceeded(Exception):
"""Exception raised when budget is exceeded."""
def __init__(
self,
budget_type: str,
limit: Decimal,
current_usage: Decimal,
scope: str,
):
self.budget_type = budget_type
self.limit = limit
self.current_usage = current_usage
self.scope = scope
super().__init__(f"Budget exceeded: {scope} {budget_type}")
class BudgetController:
"""Controller for budget limits at key and project levels."""
def __init__(self, db: AsyncSession):
self.db = db
async def check_budget(
self,
key_id: str,
estimated_cost: Decimal = Decimal("0"),
) -> dict[str, Any]:
"""
Check budget limits before making a request.
Args:
key_id: The API key ID.
estimated_cost: Estimated cost for this request.
Returns:
Dict with budget status.
Raises:
BudgetExceeded: If budget is exceeded.
"""
# Get API key
result = await self.db.execute(select(APIKey).where(APIKey.id == key_id))
api_key = result.scalar_one_or_none()
if not api_key:
raise ValueError(f"API key '{key_id}' not found")
now = datetime.utcnow()
# Check key-level budget
if api_key.budget_limit:
period_start = self._get_period_start(api_key.budget_period, now)
current_usage = await self._get_usage_for_period(key_id, None, period_start, now)
if current_usage + estimated_cost > api_key.budget_limit:
raise BudgetExceeded(
budget_type=api_key.budget_period or "total",
limit=api_key.budget_limit,
current_usage=current_usage,
scope=f"key:{api_key.name}",
)
# Check soft limit (90% of hard limit)
soft_limit = api_key.budget_limit * Decimal("0.9")
if current_usage + estimated_cost > soft_limit:
# Soft limit reached: the request is still allowed.
# No warning is emitted here yet.
pass
# Check project-level budget
if api_key.project_id:
result = await self.db.execute(
select(Project).where(Project.id == api_key.project_id)
)
project = result.scalar_one_or_none()
if project and project.budget_limit:
period_start = self._get_period_start(project.budget_period, now)
current_usage = await self._get_usage_for_period(
None, project.id, period_start, now
)
if current_usage + estimated_cost > project.budget_limit:
raise BudgetExceeded(
budget_type=project.budget_period or "total",
limit=project.budget_limit,
current_usage=current_usage,
scope=f"project:{project.name}",
)
return {
"key_budget_limit": float(api_key.budget_limit) if api_key.budget_limit else None,
"key_current_usage": float(api_key.current_usage),
"project_budget_limit": None,
"project_current_usage": None,
}
async def record_usage(
self,
key_id: str,
cost: Decimal,
input_tokens: int,
output_tokens: int,
) -> None:
"""Record actual usage after request completes."""
result = await self.db.execute(select(APIKey).where(APIKey.id == key_id))
api_key = result.scalar_one_or_none()
if api_key:
api_key.current_usage += cost
api_key.total_requests += 1
def _get_period_start(self, period: str | None, now: datetime) -> datetime:
"""Get the start of the budget period."""
if period == "daily":
return now.replace(hour=0, minute=0, second=0, microsecond=0)
elif period == "weekly":
start = now - timedelta(days=now.weekday())
return start.replace(hour=0, minute=0, second=0, microsecond=0)
elif period == "monthly":
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
else:
# No period means all-time budget
return datetime.min
async def _get_usage_for_period(
self,
key_id: str | None,
project_id: str | None,
period_start: datetime,
now: datetime,
) -> Decimal:
"""Get total usage for a period."""
query = select(func.sum(RequestLog.cost_usd)).where(
RequestLog.timestamp >= period_start,
RequestLog.timestamp <= now,
)
if key_id:
query = query.where(RequestLog.virtual_key_id == key_id)
if project_id:
query = query.where(RequestLog.project_id == project_id)
result = await self.db.execute(query)
total = result.scalar()
return Decimal(str(total or 0))
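
The budget window boundaries computed by `_get_period_start` can be illustrated on a fixed timestamp. This is a minimal sketch, assuming naive datetimes as in the controller; `period_start` is a hypothetical free-function version of the method.

```python
from datetime import datetime, timedelta
from typing import Optional


def period_start(period: Optional[str], now: datetime) -> datetime:
    """Return the first instant of the budget period that contains `now`."""
    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if period == "daily":
        return midnight
    if period == "weekly":
        # weekday() is 0 for Monday, so this steps back to Monday 00:00
        return midnight - timedelta(days=now.weekday())
    if period == "monthly":
        return midnight.replace(day=1)
    # No period means an all-time budget
    return datetime.min


now = datetime(2026, 5, 1, 15, 39, 21)
daily = period_start("daily", now)
weekly = period_start("weekly", now)
monthly = period_start("monthly", now)
```

Usage for a period is then the sum of logged costs with a timestamp between the returned start and `now`.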

@ -0,0 +1,168 @@
"""Circuit breaker for provider fault tolerance."""
import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class CircuitState(Enum):
"""Circuit breaker states."""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, reject all requests
HALF_OPEN = "half_open" # Testing if recovered
@dataclass
class CircuitStats:
"""Statistics for circuit breaker."""
failure_count: int = 0
success_count: int = 0
last_failure_time: float = 0
last_success_time: float = 0
state: CircuitState = CircuitState.CLOSED
last_state_change: float = field(default_factory=time.time)
class CircuitBreaker:
"""
Circuit breaker for provider fault tolerance.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Too many failures, reject all requests
- HALF_OPEN: Testing if provider has recovered
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: float = 30.0,
success_threshold: int = 2,
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of consecutive failures before opening circuit (any success resets the count).
recovery_timeout: Seconds to wait before trying half-open state.
success_threshold: Number of successes in half-open to close circuit.
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.success_threshold = success_threshold
self._circuits: dict[str, CircuitStats] = {}
def get_stats(self, provider: str) -> CircuitStats:
"""Get or create circuit stats for a provider."""
if provider not in self._circuits:
self._circuits[provider] = CircuitStats()
return self._circuits[provider]
def is_available(self, provider: str) -> bool:
"""
Check if a provider is available (circuit is not open).
Args:
provider: Provider name.
Returns:
True if requests should be allowed.
"""
stats = self.get_stats(provider)
now = time.time()
if stats.state == CircuitState.CLOSED:
return True
elif stats.state == CircuitState.OPEN:
# Check if recovery timeout has passed
if now - stats.last_state_change >= self.recovery_timeout:
# Transition to half-open
stats.state = CircuitState.HALF_OPEN
stats.success_count = 0
stats.last_state_change = now
return True
return False
elif stats.state == CircuitState.HALF_OPEN:
return True
return True
def record_success(self, provider: str) -> None:
"""
Record a successful request.
Args:
provider: Provider name.
"""
stats = self.get_stats(provider)
now = time.time()
stats.success_count += 1
stats.last_success_time = now
stats.failure_count = 0 # Reset failure count on success
if stats.state == CircuitState.HALF_OPEN:
if stats.success_count >= self.success_threshold:
# Transition to closed
stats.state = CircuitState.CLOSED
stats.last_state_change = now
def record_failure(self, provider: str) -> None:
"""
Record a failed request.
Args:
provider: Provider name.
"""
stats = self.get_stats(provider)
now = time.time()
stats.failure_count += 1
stats.last_failure_time = now
if stats.state == CircuitState.HALF_OPEN:
# Immediately open on failure in half-open
stats.state = CircuitState.OPEN
stats.last_state_change = now
elif stats.state == CircuitState.CLOSED:
if stats.failure_count >= self.failure_threshold:
# Transition to open
stats.state = CircuitState.OPEN
stats.last_state_change = now
def get_state(self, provider: str) -> CircuitState:
"""Get current circuit state for a provider."""
return self.get_stats(provider).state
def reset(self, provider: str) -> None:
"""Reset circuit breaker for a provider."""
if provider in self._circuits:
del self._circuits[provider]
def reset_all(self) -> None:
"""Reset all circuit breakers."""
self._circuits.clear()
# Global circuit breaker instance
_circuit_breaker: CircuitBreaker | None = None
def get_circuit_breaker() -> CircuitBreaker:
"""Get the global circuit breaker instance."""
global _circuit_breaker
if _circuit_breaker is None:
# Thresholds are fixed here; they are not read from application settings
_circuit_breaker = CircuitBreaker(
failure_threshold=5,
recovery_timeout=30.0,
)
return _circuit_breaker
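
The three-state cycle described in the `CircuitBreaker` docstring can be walked through with a reduced, single-provider model. This sketch is not the class above: `MiniBreaker` is a hypothetical simplification that takes the current time as an argument so the transitions are deterministic.

```python
from enum import Enum


class State(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class MiniBreaker:
    """Single-provider circuit breaker driven by an explicit clock value."""

    def __init__(self, failure_threshold=5, recovery_timeout=30.0, success_threshold=2):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.success_threshold = success_threshold
        self.state = State.CLOSED
        self.failures = 0
        self.successes = 0
        self.changed_at = 0.0

    def is_available(self, now: float) -> bool:
        # After the recovery timeout an open circuit lets trial requests through
        if self.state is State.OPEN and now - self.changed_at >= self.recovery_timeout:
            self.state, self.successes, self.changed_at = State.HALF_OPEN, 0, now
        return self.state is not State.OPEN

    def record_success(self, now: float) -> None:
        self.failures = 0
        self.successes += 1
        if self.state is State.HALF_OPEN and self.successes >= self.success_threshold:
            self.state, self.changed_at = State.CLOSED, now

    def record_failure(self, now: float) -> None:
        self.failures += 1
        tripped = self.state is State.CLOSED and self.failures >= self.failure_threshold
        # Any failure while half-open reopens the circuit immediately
        if self.state is State.HALF_OPEN or tripped:
            self.state, self.changed_at = State.OPEN, now


breaker = MiniBreaker(failure_threshold=2, recovery_timeout=30.0, success_threshold=2)
breaker.record_failure(now=0.0)
breaker.record_failure(now=1.0)          # second consecutive failure opens the circuit
blocked = not breaker.is_available(now=10.0)
probing = breaker.is_available(now=31.0)  # timeout elapsed, circuit is half-open
breaker.record_success(now=32.0)
breaker.record_success(now=33.0)          # enough successes to close again
```

Because each success counts toward `success_threshold`, recording the same success twice would close a half-open circuit after a single real request.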

@ -0,0 +1,223 @@
"""Fallback and retry logic for provider failures."""
import asyncio
import random
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, TypeVar
from app.config import get_settings
from app.core.circuit_breaker import CircuitBreaker, get_circuit_breaker
from app.utils.logging import get_logger
logger = get_logger(__name__)
T = TypeVar("T")
class RetryableError(Exception):
"""Error that can be retried."""
def __init__(self, message: str, error_type: str = "unknown"):
self.message = message
self.error_type = error_type
super().__init__(message)
class PermanentError(Exception):
"""Error that should not be retried."""
def __init__(self, message: str, error_type: str = "unknown"):
self.message = message
self.error_type = error_type
super().__init__(message)
@dataclass
class RetryConfig:
"""Retry configuration."""
max_retries: int = 3
initial_delay: float = 1.0
max_delay: float = 30.0
exponential_base: float = 2.0
jitter: bool = True
retryable_errors: set[str] | None = None
def __post_init__(self):
if self.retryable_errors is None:
self.retryable_errors = {
"rate_limit_exceeded",
"timeout",
"service_unavailable",
"internal_error",
"connection_error",
"overloaded",
}
class RetryExecutor:
"""Executor for retrying failed operations."""
def __init__(self, config: RetryConfig | None = None):
self.config = config or RetryConfig()
self.settings = get_settings()
async def execute(
self,
operation: Callable[[], T],
provider: str,
operation_name: str = "operation",
) -> T:
"""
Execute an operation with retry logic.
Args:
operation: Async function to execute.
provider: Provider name for circuit breaker.
operation_name: Name for logging.
Returns:
Result of the operation.
Raises:
PermanentError: If operation fails permanently.
RetryableError: If all retries are exhausted.
"""
circuit_breaker = get_circuit_breaker()
last_error: Exception | None = None
for attempt in range(self.config.max_retries + 1):
# Check circuit breaker
if not circuit_breaker.is_available(provider):
raise RetryableError(
f"Circuit breaker open for provider '{provider}'",
error_type="circuit_open",
)
try:
result = await operation()
circuit_breaker.record_success(provider)
return result
except PermanentError as e:
logger.warning(
"Permanent error in %s",
operation_name,
provider=provider,
attempt=attempt,
error=e.message,
)
circuit_breaker.record_failure(provider)
raise
except RetryableError as e:
last_error = e
circuit_breaker.record_failure(provider)
if attempt < self.config.max_retries:
delay = self._calculate_delay(attempt)
logger.warning(
"Retryable error in %s, retrying in %.2fs",
operation_name,
delay,
provider=provider,
attempt=attempt,
error=e.message,
error_type=e.error_type,
)
await asyncio.sleep(delay)
else:
logger.error(
"All retries exhausted for %s",
operation_name,
provider=provider,
attempts=attempt + 1,
error=e.message,
)
except Exception as e:
# Unknown error - treat as retryable
last_error = RetryableError(str(e), error_type="unknown")
circuit_breaker.record_failure(provider)
if attempt < self.config.max_retries:
delay = self._calculate_delay(attempt)
logger.warning(
"Unknown error in %s, retrying in %.2fs",
operation_name,
delay,
provider=provider,
attempt=attempt,
error=str(e),
)
await asyncio.sleep(delay)
# All retries exhausted
if last_error:
raise last_error
raise RetryableError("All retries exhausted", error_type="max_retries")
def _calculate_delay(self, attempt: int) -> float:
"""Calculate delay for a retry attempt with exponential backoff and jitter."""
delay = self.config.initial_delay * (self.config.exponential_base**attempt)
if self.config.jitter:
# Scale by a random factor between 0.5 and 1.5 to de-synchronise retries
delay *= random.uniform(0.5, 1.5)
# Clamp last so that jitter cannot push the delay past max_delay
return min(delay, self.config.max_delay)
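def classify_error(error: Exception) -> Exception:
"""
Classify an error as retryable or permanent.
Classification is a case-insensitive substring match on the error message.
Retryable patterns are checked first, so a message that matches both lists
is treated as retryable.
Args:
error: The error to classify.
Returns:
RetryableError or PermanentError.
"""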
error_str = str(error).lower()
# Retryable errors
retryable_patterns = [
"rate limit",
"429",
"timeout",
"503",
"502",
"service unavailable",
"internal error",
"500",
"connection",
"overloaded",
"capacity",
]
for pattern in retryable_patterns:
if pattern in error_str:
return RetryableError(str(error), error_type=pattern.replace(" ", "_"))
# Permanent errors
permanent_patterns = [
"401",
"unauthorized",
"invalid api key",
"403",
"forbidden",
"400",
"invalid request",
"404",
"not found",
"model not found",
]
for pattern in permanent_patterns:
if pattern in error_str:
return PermanentError(str(error), error_type=pattern.replace(" ", "_"))
# Default to retryable
return RetryableError(str(error), error_type="unknown")
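
The backoff schedule produced by `_calculate_delay` can be seen by evaluating the formula for successive attempts. This is a minimal sketch using the default values from `RetryConfig`; `backoff_delay` is a hypothetical standalone function, and it applies the cap after jitter so that `max_delay` is a hard upper bound.

```python
import random
from typing import Optional


def backoff_delay(
    attempt: int,
    initial_delay: float = 1.0,
    exponential_base: float = 2.0,
    max_delay: float = 30.0,
    rng: Optional[random.Random] = None,
) -> float:
    """Exponential backoff with optional jitter, capped at max_delay."""
    delay = initial_delay * (exponential_base ** attempt)
    if rng is not None:
        # Jitter scales the delay by a random factor between 0.5 and 1.5
        delay *= rng.uniform(0.5, 1.5)
    return min(delay, max_delay)


# Without jitter the schedule doubles until it reaches the cap
schedule = [backoff_delay(n) for n in range(6)]
print(schedule)

# With jitter, a seeded generator keeps the example repeatable
jittered = [backoff_delay(n, rng=random.Random(n)) for n in range(12)]
```

Jitter spreads out clients that failed at the same moment, so they do not all retry against the provider in the same instant.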

@ -0,0 +1,137 @@
"""Load balancer for distributing requests across providers."""
import random
from dataclasses import dataclass
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.provider import Provider
@dataclass
class ProviderInfo:
"""Information about a provider for load balancing."""
name: str
weight: int
health_status: str
current_requests: int = 0
class LoadBalancer:
"""Load balancer for distributing requests across providers."""
def __init__(self, db: AsyncSession):
self.db = db
async def select_provider(
self,
providers: list[dict[str, Any]],
strategy: str = "weighted_round_robin",
) -> dict[str, Any] | None:
"""
Select a provider based on the specified strategy.
Args:
providers: List of provider configs with weights.
strategy: Load balancing strategy.
Returns:
Selected provider config, or None if none available.
"""
if not providers:
return None
if len(providers) == 1:
return providers[0]
# Filter healthy providers
healthy_providers = await self._filter_healthy_providers(providers)
if not healthy_providers:
return providers[0] # Fallback to first if all unhealthy
if strategy == "weighted_round_robin":
return self._weighted_round_robin(healthy_providers)
elif strategy == "random":
return self._random_select(healthy_providers)
elif strategy == "least_connections":
return await self._least_connections(healthy_providers)
else:
return self._weighted_round_robin(healthy_providers)
def _weighted_round_robin(
self, providers: list[dict[str, Any]]
) -> dict[str, Any]:
"""Select a provider by weighted random choice (no rotation state is kept between calls)."""
total_weight = sum(p.get("weight", 1) for p in providers)
r = random.uniform(0, total_weight)
cumulative = 0
for provider in providers:
cumulative += provider.get("weight", 1)
if r <= cumulative:
return provider
return providers[0]
def _random_select(self, providers: list[dict[str, Any]]) -> dict[str, Any]:
"""Select provider randomly."""
return random.choice(providers)
async def _least_connections(
self, providers: list[dict[str, Any]]
) -> dict[str, Any]:
"""Select provider with least connections (placeholder)."""
# In a real implementation, this would track active connections
# For now, use random as a fallback
return self._random_select(providers)
async def _filter_healthy_providers(
self, providers: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Filter out unhealthy providers."""
healthy = []
for p in providers:
provider_name = p.get("provider")
if not provider_name:
continue
result = await self.db.execute(
select(Provider).where(
Provider.name == provider_name,
Provider.enabled == True,
Provider.health_status == "healthy",
)
)
provider = result.scalar_one_or_none()
if provider:
healthy.append(p)
return healthy
async def get_provider_health(
self, provider_name: str
) -> dict[str, Any]:
"""Get health status for a provider."""
result = await self.db.execute(
select(Provider).where(Provider.name == provider_name)
)
provider = result.scalar_one_or_none()
if not provider:
return {
"name": provider_name,
"status": "not_found",
"available": False,
}
return {
"name": provider.name,
"status": provider.health_status,
"available": provider.enabled and provider.health_status == "healthy",
"last_check": provider.last_health_check.isoformat() if provider.last_health_check else None,
}
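
The weighted selection used by the load balancer can be exercised on its own to see how weights translate into traffic share. This is a sketch under stated assumptions: `pick_weighted` is a hypothetical free function, the provider names are placeholders, and a seeded `random.Random` is passed in to make the run repeatable.

```python
import random
from collections import Counter
from typing import Any


def pick_weighted(providers: list[dict[str, Any]], rng: random.Random) -> dict[str, Any]:
    """Pick one provider with probability proportional to its weight."""
    total = sum(p.get("weight", 1) for p in providers)
    r = rng.uniform(0, total)
    cumulative = 0
    for p in providers:
        cumulative += p.get("weight", 1)
        # The first provider whose cumulative weight reaches r is selected
        if r <= cumulative:
            return p
    return providers[0]


rng = random.Random(42)
providers = [
    {"provider": "primary", "weight": 9},    # placeholder names
    {"provider": "secondary", "weight": 1},
]
counts = Counter(pick_weighted(providers, rng)["provider"] for _ in range(1000))
print(counts)
```

Each call is independent, so the split is only approximately 9 to 1 over many requests; a strict round-robin would need to keep a position between calls.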

@ -0,0 +1,210 @@
"""Rate limiter for RPM/TPM limits."""
import time
from datetime import datetime, timedelta
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import get_settings
from app.models.usage import RateLimitCounter
class RateLimitExceeded(Exception):
"""Exception raised when rate limit is exceeded."""
def __init__(
self,
limit_type: str,
limit: int,
remaining: int,
reset_at: datetime,
):
self.limit_type = limit_type
self.limit = limit
self.remaining = remaining
self.reset_at = reset_at
super().__init__(f"Rate limit exceeded: {limit_type}")
class RateLimiter:
"""Rate limiter for RPM and TPM limits."""
def __init__(self, db: AsyncSession):
self.db = db
self.settings = get_settings()
self.window_seconds = self.settings.rate_limit_window_seconds
async def check_and_increment(
self,
key: str,
rpm_limit: int | None = None,
tpm_limit: int | None = None,
estimated_tokens: int = 0,
) -> dict[str, Any]:
"""
Check rate limits and increment counters.
Args:
key: Base key for rate limiting (e.g., "key:{key_id}"); the ":rpm" and ":tpm" suffixes are appended internally.
rpm_limit: Requests per minute limit.
tpm_limit: Tokens per minute limit.
estimated_tokens: Estimated tokens for this request.
Returns:
Dict with remaining limits.
Raises:
RateLimitExceeded: If rate limit is exceeded.
"""
now = datetime.utcnow()
window_start = now - timedelta(seconds=self.window_seconds)
# Check global limits
if self.settings.global_rpm_limit:
await self._check_limit(
"global:rpm",
self.settings.global_rpm_limit,
window_start,
now,
1,
"request",
)
if self.settings.global_tpm_limit and estimated_tokens > 0:
await self._check_limit(
"global:tpm",
self.settings.global_tpm_limit,
window_start,
now,
estimated_tokens,
"token",
)
# Check specific limits
if rpm_limit:
await self._check_limit(
f"{key}:rpm",
rpm_limit,
window_start,
now,
1,
"request",
)
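# Always consult the token counter: record_tokens() adds the actual usage after
# each request, so the check must run even when no estimate is supplied.
if tpm_limit: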
await self._check_limit(
f"{key}:tpm",
tpm_limit,
window_start,
now,
estimated_tokens,
"token",
)
# Get remaining limits
remaining = await self._get_remaining(key, rpm_limit, tpm_limit, window_start, now)
return remaining
async def _check_limit(
self,
key: str,
limit: int,
window_start: datetime,
now: datetime,
increment: int,
counter_type: str,
) -> None:
"""Check if a limit would be exceeded and increment counter."""
# Get current counter
result = await self.db.execute(
select(RateLimitCounter).where(RateLimitCounter.key == key)
)
counter = result.scalar_one_or_none()
if counter:
# Check if window has expired
if counter.window_start < window_start:
# Reset counter for new window
counter.request_count = 0
counter.token_count = 0
counter.window_start = now
# Check limit
current = (
counter.request_count if counter_type == "request" else counter.token_count
)
if current >= limit:
reset_at = counter.window_start + timedelta(seconds=counter.window_duration)
raise RateLimitExceeded(
limit_type=f"{counter_type}_per_{self.window_seconds}s",
limit=limit,
remaining=0,
reset_at=reset_at,
)
# Increment counter
if counter_type == "request":
counter.request_count += increment
else:
counter.token_count += increment
else:
# Create new counter
counter = RateLimitCounter(
key=key,
request_count=1 if counter_type == "request" else 0,
token_count=increment if counter_type == "token" else 0,
window_start=now,
window_duration=self.window_seconds,
)
self.db.add(counter)
async def _get_remaining(
self,
key: str,
rpm_limit: int | None,
tpm_limit: int | None,
window_start: datetime,
now: datetime,
) -> dict[str, Any]:
"""Get remaining limits for a key."""
remaining = {}
if rpm_limit:
result = await self.db.execute(
select(RateLimitCounter).where(RateLimitCounter.key == f"{key}:rpm")
)
counter = result.scalar_one_or_none()
remaining["rpm"] = rpm_limit - (counter.request_count if counter else 0)
if tpm_limit:
result = await self.db.execute(
select(RateLimitCounter).where(RateLimitCounter.key == f"{key}:tpm")
)
counter = result.scalar_one_or_none()
remaining["tpm"] = tpm_limit - (counter.token_count if counter else 0)
remaining["reset_at"] = (now + timedelta(seconds=self.window_seconds)).isoformat()
return remaining
async def record_tokens(self, key: str, tokens: int) -> None:
"""Record actual token usage after request completes."""
tpm_key = f"{key}:tpm"
result = await self.db.execute(
select(RateLimitCounter).where(RateLimitCounter.key == tpm_key)
)
counter = result.scalar_one_or_none()
if counter:
counter.token_count += tokens
else:
counter = RateLimitCounter(
key=tpm_key,
request_count=0,
token_count=tokens,
window_start=datetime.utcnow(),
window_duration=self.window_seconds,
)
self.db.add(counter)
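
The counter logic in `_check_limit` is a fixed-window scheme: count events in the current window, reject once the limit is reached, and start over when the window expires. The sketch below keeps the counters in memory instead of the `RateLimitCounter` table; `FixedWindowLimiter` and `Window` are hypothetical names, and the clock is passed in explicitly.

```python
from dataclasses import dataclass


@dataclass
class Window:
    start: float
    count: int = 0


class FixedWindowLimiter:
    """In-memory fixed-window counter keyed by an arbitrary string."""

    def __init__(self, limit: int, window_seconds: float = 60.0):
        self.limit = limit
        self.window_seconds = window_seconds
        self._windows: dict[str, Window] = {}

    def allow(self, key: str, now: float) -> bool:
        window = self._windows.get(key)
        # Start a fresh window on first use or once the old one has expired
        if window is None or now - window.start >= self.window_seconds:
            window = Window(start=now)
            self._windows[key] = window
        if window.count >= self.limit:
            return False
        window.count += 1
        return True


limiter = FixedWindowLimiter(limit=2, window_seconds=60.0)
results = [
    limiter.allow("key:demo:rpm", now=0.0),
    limiter.allow("key:demo:rpm", now=1.0),
    limiter.allow("key:demo:rpm", now=2.0),   # over the limit in this window
    limiter.allow("key:demo:rpm", now=60.0),  # new window
]
print(results)
```

A fixed window is simple to store in one row per key, at the cost of allowing a burst of up to twice the limit across a window boundary.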

@ -0,0 +1,173 @@
"""Router for model alias resolution and routing."""
import json
import random
from dataclasses import dataclass
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.model_alias import ModelAlias
from app.models.provider import Provider
@dataclass
class RoutingResult:
"""Result of routing decision."""
provider: str
model: str
provider_config: dict[str, Any] | None = None
fallback_chain: list[dict[str, str]] | None = None
class Router:
"""Router for resolving model aliases and making routing decisions."""
def __init__(self, db: AsyncSession):
self.db = db
async def resolve_model(self, model_alias: str) -> RoutingResult:
"""
Resolve a model alias to a provider and model.
Args:
model_alias: The model alias to resolve.
Returns:
RoutingResult with provider, model, and optional fallback chain.
Raises:
ValueError: If the model alias is not found.
"""
# Look up the model alias
result = await self.db.execute(
select(ModelAlias).where(ModelAlias.alias == model_alias, ModelAlias.enabled == True)
)
alias = result.scalar_one_or_none()
if not alias:
# If no alias found, treat as direct provider/model reference
# Format: "provider/model" (e.g., "openai/gpt-4")
if "/" in model_alias:
provider, model = model_alias.split("/", 1)
return RoutingResult(provider=provider, model=model)
raise ValueError(f"Model alias '{model_alias}' not found")
# Parse routing config
routing_config = json.loads(alias.routing_config) if alias.routing_config else None
if alias.routing_type == "simple":
return RoutingResult(
provider=alias.provider,
model=alias.model,
)
elif alias.routing_type == "load_balance":
# Weighted random selection
providers = routing_config.get("providers", []) if routing_config else []
if not providers:
return RoutingResult(provider=alias.provider, model=alias.model)
# Filter healthy providers
healthy_providers = await self._filter_healthy_providers(providers)
if not healthy_providers:
# Fallback to default if all unhealthy
return RoutingResult(provider=alias.provider, model=alias.model)
# Weighted random selection
total_weight = sum(p.get("weight", 1) for p in healthy_providers)
r = random.uniform(0, total_weight)
cumulative = 0
for p in healthy_providers:
cumulative += p.get("weight", 1)
if r <= cumulative:
return RoutingResult(
provider=p["provider"],
model=p.get("model", alias.model),
)
# Fallback
return RoutingResult(
provider=healthy_providers[0]["provider"],
model=healthy_providers[0].get("model", alias.model),
)
elif alias.routing_type == "fallback":
# Return primary with fallback chain
primary = (
routing_config.get("primary", {}) if routing_config else {}
)
fallback = routing_config.get("fallback", []) if routing_config else []
return RoutingResult(
provider=primary.get("provider", alias.provider),
model=primary.get("model", alias.model),
fallback_chain=fallback,
)
# Default to simple routing
return RoutingResult(provider=alias.provider, model=alias.model)
async def _filter_healthy_providers(
self, providers: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Filter out unhealthy providers."""
healthy = []
for p in providers:
provider_name = p.get("provider")
if not provider_name:
continue
result = await self.db.execute(
select(Provider).where(
Provider.name == provider_name,
Provider.enabled == True,
Provider.health_status == "healthy",
)
)
provider = result.scalar_one_or_none()
if provider:
healthy.append(p)
return healthy
async def get_fallback_provider(
self, failed_provider: str, fallback_chain: list[dict[str, str]] | None
) -> RoutingResult | None:
"""
Get the next fallback provider after a failure.
Args:
failed_provider: The provider that failed.
fallback_chain: List of fallback providers.
Returns:
RoutingResult for the next provider, or None if no more fallbacks.
"""
if not fallback_chain:
return None
# Find the failed provider in the chain and return the next one
for i, fallback in enumerate(fallback_chain):
if fallback.get("provider") == failed_provider and i + 1 < len(fallback_chain):
next_fallback = fallback_chain[i + 1]
return RoutingResult(
provider=next_fallback.get("provider", ""),
model=next_fallback.get("model", ""),
fallback_chain=fallback_chain[i + 1 :],
)
# If failed provider not in chain or at the end, try from the beginning
for fallback in fallback_chain:
if fallback.get("provider") != failed_provider:
return RoutingResult(
provider=fallback.get("provider", ""),
model=fallback.get("model", ""),
fallback_chain=fallback_chain,
)
return None
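
The `load_balance` branch of `Router.resolve_model` picks a provider by weighted random selection. A minimal standalone sketch of that step follows, with the database lookup and health filtering left out. `pick_weighted` and its `rng` parameter are illustrative names that do not appear in the gateway code.

```python
import random
from typing import Any


def pick_weighted(providers: list[dict[str, Any]], rng: random.Random) -> dict[str, Any]:
    """Pick one entry with probability proportional to its "weight" (default 1)."""
    if not providers:
        raise ValueError("providers must not be empty")
    total_weight = sum(p.get("weight", 1) for p in providers)
    r = rng.uniform(0, total_weight)
    cumulative = 0.0
    for p in providers:
        cumulative += p.get("weight", 1)
        if r <= cumulative:
            return p
    # Only reachable through floating-point drift; fall back to the last entry.
    return providers[-1]
```

Passing a seeded `random.Random` keeps the selection reproducible in tests. With weights 3 and 1, the first provider should receive roughly three quarters of the traffic over many calls.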


@ -0,0 +1,228 @@
"""Request transformer for converting between API formats."""
import time
import uuid
from app.schemas.anthropic import (
AnthropicContentBlock,
AnthropicMessage,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicTextBlock,
AnthropicUsage,
)
from app.schemas.openai import (
OpenAIChatCompletionChoice,
OpenAIChatCompletionRequest,
OpenAIChatCompletionResponse,
OpenAIChatMessage,
OpenAIUsage,
)
class RequestTransformer:
"""Transform requests between different API formats."""
@staticmethod
def openai_to_anthropic(
request: OpenAIChatCompletionRequest,
) -> AnthropicMessagesRequest:
"""Convert OpenAI Chat Completions request to Anthropic Messages request."""
messages = []
system = None
for msg in request.messages:
if msg.role == "system":
# Extract system message
system = msg.content if isinstance(msg.content, str) else None
elif msg.role in ("user", "assistant"):
# Convert content to Anthropic format
if isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, list):
# Handle multi-modal content
content = []
for part in msg.content:
if part.get("type") == "text":
content.append(
AnthropicContentBlock(type="text", text=part.get("text", ""))
)
elif part.get("type") == "image_url":
# Convert OpenAI image format to Anthropic
content.append(
AnthropicContentBlock(
type="image",
source={
"type": "url",
"url": part.get("image_url", {}).get("url", ""),
},
)
)
else:
content.append(
AnthropicContentBlock(type="text", text=str(part))
)
else:
content = str(msg.content) if msg.content else ""
messages.append(
AnthropicMessage(
role=msg.role, # type: ignore
content=content,
)
)
# Handle max_tokens
max_tokens = request.max_tokens or 4096
return AnthropicMessagesRequest(
model=request.model,
messages=messages,
system=system,
max_tokens=max_tokens,
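# AnthropicMessagesRequest limits temperature to [0, 1]; clamp OpenAI's wider [0, 2] range to avoid a validation error.
temperature=min(request.temperature, 1.0) if request.temperature is not None else None,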
top_p=request.top_p,
stop_sequences=[request.stop] if isinstance(request.stop, str) else request.stop,
stream=request.stream,
tools=request.tools,
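# AnthropicMessagesRequest.tool_choice only accepts a dict, so OpenAI string values ("auto", "none", "required") are dropped here.
tool_choice=request.tool_choice if isinstance(request.tool_choice, dict) else None,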
)
@staticmethod
def anthropic_to_openai(
request: AnthropicMessagesRequest,
model: str | None = None,
) -> OpenAIChatCompletionRequest:
"""Convert Anthropic Messages request to OpenAI Chat Completions request."""
messages = []
# Add system message if present
if request.system:
messages.append(
OpenAIChatMessage(
role="system",
content=request.system,
)
)
# Convert messages
for msg in request.messages:
if isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, list):
# Convert Anthropic content blocks to OpenAI format
content = []
for block in msg.content:
if hasattr(block, "type") and block.type == "text":
content.append({"type": "text", "text": block.text or ""})
elif isinstance(block, dict) and block.get("type") == "text":
content.append({"type": "text", "text": block.get("text", "")})
else:
content.append({"type": "text", "text": str(block)})
else:
content = str(msg.content)
messages.append(
OpenAIChatMessage(
role=msg.role, # type: ignore
content=content,
)
)
return OpenAIChatCompletionRequest(
model=model or request.model,
messages=messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
tools=request.tools,
tool_choice=request.tool_choice,
)
@staticmethod
def anthropic_response_to_openai(
response: AnthropicMessagesResponse,
model: str,
) -> OpenAIChatCompletionResponse:
"""Convert Anthropic Messages response to OpenAI Chat Completions response."""
# Extract text content
content = ""
for block in response.content:
if hasattr(block, "type") and block.type == "text":
content += block.text or ""
elif isinstance(block, dict) and block.get("type") == "text":
content += block.get("text", "")
# Map stop reasons
finish_reason_map = {
"end_turn": "stop",
"max_tokens": "length",
"stop_sequence": "stop",
"tool_use": "tool_calls",
}
finish_reason = finish_reason_map.get(response.stop_reason or "", "stop")
return OpenAIChatCompletionResponse(
id=response.id,
object="chat.completion",
created=int(time.time()),
model=model,
choices=[
OpenAIChatCompletionChoice(
index=0,
message=OpenAIChatMessage(
role="assistant",
content=content,
),
finish_reason=finish_reason,
)
],
usage=OpenAIUsage(
prompt_tokens=response.usage.input_tokens,
completion_tokens=response.usage.output_tokens,
total_tokens=response.usage.input_tokens + response.usage.output_tokens,
)
if response.usage
else None,
)
@staticmethod
def openai_response_to_anthropic(
response: OpenAIChatCompletionResponse,
) -> AnthropicMessagesResponse:
"""Convert OpenAI Chat Completions response to Anthropic Messages response."""
# Extract content from choices
content = []
for choice in response.choices:
if choice.message.content:
content.append(
AnthropicTextBlock(
type="text",
text=choice.message.content,
)
)
# Map finish reasons
stop_reason_map = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
}
stop_reason = stop_reason_map.get(
response.choices[0].finish_reason or "", "end_turn"
)
return AnthropicMessagesResponse(
id=response.id or f"msg_{uuid.uuid4().hex[:24]}",
type="message",
role="assistant",
content=content,
model=response.model,
stop_reason=stop_reason,
usage=AnthropicUsage(
input_tokens=response.usage.prompt_tokens if response.usage else 0,
output_tokens=response.usage.completion_tokens if response.usage else 0,
),
)


@ -0,0 +1 @@
# db module


@ -0,0 +1,70 @@
"""Database connection and session management."""
from pathlib import Path
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import get_settings
class Base(DeclarativeBase):
"""Base class for all models."""
pass
# Create async engine
def get_engine():
"""Create database engine based on settings."""
settings = get_settings()
db_url = settings.database_url
# Ensure data directory exists for SQLite
if db_url.startswith("sqlite:///"):
db_path = db_url.replace("sqlite:///", "")
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# Use aiosqlite for async SQLite
db_url = db_url.replace("sqlite:///", "sqlite+aiosqlite:///")
return create_async_engine(db_url, echo=settings.debug)
engine = get_engine()
# Session factory
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def init_db():
"""Initialize database tables."""
# Import all models to ensure they are registered
from app.models import ( # noqa: F401
provider,
api_key,
project,
model_alias,
usage,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_db():
"""Dependency to get database session."""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()

61
llm-gateway/app/main.py Normal file

@ -0,0 +1,61 @@
"""FastAPI application entry point."""
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.admin import health, keys, models, projects, providers, usage
from app.api.v1 import chat, messages, responses
from app.config import get_settings
from app.db.database import init_db
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan handler."""
# Startup - skip if testing
if not os.environ.get("TESTING"):
await init_db()
yield
# Shutdown
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
settings = get_settings()
app = FastAPI(
title=settings.app_name,
description="A unified LLM Gateway supporting multiple providers and API formats",
version="0.1.0",
debug=settings.debug,
lifespan=lifespan,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
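# NOTE: wildcard origins combined with allow_credentials=True allow credentialed requests from any site; restrict allow_origins in production.
allow_origins=["*"],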
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Admin API routers
app.include_router(providers.router, prefix="/admin")
app.include_router(projects.router, prefix="/admin")
app.include_router(keys.router, prefix="/admin")
app.include_router(models.router, prefix="/admin")
app.include_router(usage.router, prefix="/admin")
app.include_router(health.router)
# LLM API routers
app.include_router(chat.router)
app.include_router(messages.router)
app.include_router(responses.router)
return app
app = create_app()


@ -0,0 +1 @@
# middleware module


@ -0,0 +1,17 @@
"""Models package."""
from app.models.provider import Provider
from app.models.project import Project
from app.models.api_key import APIKey
from app.models.model_alias import ModelAlias
from app.models.usage import RequestLog, UsageStatsHourly, AuditLog, RateLimitCounter
__all__ = [
"Provider",
"Project",
"APIKey",
"ModelAlias",
"RequestLog",
"UsageStatsHourly",
"AuditLog",
"RateLimitCounter",
]


@ -0,0 +1,48 @@
"""API Key model."""
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.database import Base
class APIKey(Base):
"""Virtual API Key for authentication and usage tracking."""
__tablename__ = "api_keys"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
key_hash: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
key_prefix: Mapped[str] = mapped_column(String(20), nullable=False) # e.g., "sk-proj_abc..."
name: Mapped[str] = mapped_column(String(100), nullable=False)
# Project relationship
project_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("projects.id"))
project: Mapped["Project"] = relationship("Project", backref="api_keys")
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
expires_at: Mapped[datetime | None] = mapped_column(DateTime)
# Rate Limits
rpm_limit: Mapped[int | None] = mapped_column(Integer)
tpm_limit: Mapped[int | None] = mapped_column(Integer)
# Budget
budget_limit: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
budget_period: Mapped[str | None] = mapped_column(String(20)) # daily, weekly, monthly
# Permissions
allowed_models: Mapped[str | None] = mapped_column(Text) # JSON array
# Usage Stats
current_usage: Mapped[Decimal] = mapped_column(Numeric(10, 2), default=Decimal("0"))
total_requests: Mapped[int] = mapped_column(Integer, default=0)
# Timestamps
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)


@ -0,0 +1,36 @@
"""Model alias model."""
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from sqlalchemy import Boolean, DateTime, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.db.database import Base
class ModelAlias(Base):
"""Model alias for routing and cost tracking."""
__tablename__ = "model_aliases"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
alias: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(50), nullable=False)
model: Mapped[str] = mapped_column(String(100), nullable=False)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# Routing: simple, load_balance, fallback
routing_type: Mapped[str] = mapped_column(String(20), default="simple")
routing_config: Mapped[str | None] = mapped_column(Text) # JSON
# Pricing (per 1K tokens)
input_price_per_1k: Mapped[Decimal | None] = mapped_column(Numeric(10, 6))
output_price_per_1k: Mapped[Decimal | None] = mapped_column(Numeric(10, 6))
# Timestamps
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)


@ -0,0 +1,31 @@
"""Project model."""
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from sqlalchemy import Boolean, DateTime, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.db.database import Base
class Project(Base):
"""Project for organizing API keys and tracking usage."""
__tablename__ = "projects"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
description: Mapped[str | None] = mapped_column(Text)
# Budget
budget_limit: Mapped[Decimal | None] = mapped_column(Numeric(10, 2))
budget_period: Mapped[str | None] = mapped_column(String(20)) # daily, weekly, monthly
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# Timestamps
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)


@ -0,0 +1,40 @@
"""Provider model."""
from datetime import datetime
from uuid import uuid4
from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.db.database import Base
class Provider(Base):
"""LLM Provider configuration."""
__tablename__ = "providers"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False, index=True)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# API Configuration
api_base: Mapped[str] = mapped_column(String(500), nullable=False)
api_key_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
api_version: Mapped[str | None] = mapped_column(String(50))
# Provider-specific config (JSON)
config: Mapped[str | None] = mapped_column(Text)
# Rate Limits
rpm_limit: Mapped[int | None] = mapped_column(Integer)
tpm_limit: Mapped[int | None] = mapped_column(Integer)
# Health Status
health_status: Mapped[str] = mapped_column(String(20), default="healthy")
last_health_check: Mapped[datetime | None] = mapped_column(DateTime)
# Timestamps
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
)


@ -0,0 +1,110 @@
"""Usage tracking models."""
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.db.database import Base
class RequestLog(Base):
"""Request log for tracking individual API calls."""
__tablename__ = "request_logs"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True)
# Keys and Projects
virtual_key_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("api_keys.id"), index=True
)
project_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("projects.id"), index=True
)
# Request details
provider: Mapped[str] = mapped_column(String(50), nullable=False)
model: Mapped[str] = mapped_column(String(100), nullable=False)
model_alias: Mapped[str | None] = mapped_column(String(100))
request_type: Mapped[str] = mapped_column(String(20)) # chat, completion, embedding
# Tokens
input_tokens: Mapped[int] = mapped_column(Integer)
output_tokens: Mapped[int] = mapped_column(Integer)
total_tokens: Mapped[int] = mapped_column(Integer)
# Response
status_code: Mapped[int] = mapped_column(Integer)
latency_ms: Mapped[int] = mapped_column(Integer)
finish_reason: Mapped[str | None] = mapped_column(String(50))
# Cost
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6))
# Request metadata (JSON)
request_metadata: Mapped[str | None] = mapped_column(Text)
class UsageStatsHourly(Base):
"""Hourly aggregated usage statistics."""
__tablename__ = "usage_stats_hourly"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
timestamp: Mapped[datetime] = mapped_column(DateTime, nullable=False, index=True)
virtual_key_id: Mapped[str | None] = mapped_column(String(36), index=True)
project_id: Mapped[str | None] = mapped_column(String(36), index=True)
provider: Mapped[str] = mapped_column(String(50))
model: Mapped[str] = mapped_column(String(100))
# Aggregates
request_count: Mapped[int] = mapped_column(Integer, default=0)
input_tokens: Mapped[int] = mapped_column(BigInteger, default=0)
output_tokens: Mapped[int] = mapped_column(BigInteger, default=0)
total_tokens: Mapped[int] = mapped_column(BigInteger, default=0)
cost_usd: Mapped[Decimal] = mapped_column(Numeric(10, 6), default=Decimal("0"))
avg_latency_ms: Mapped[int | None] = mapped_column(Integer)
error_count: Mapped[int] = mapped_column(Integer, default=0)
class AuditLog(Base):
"""Audit log for tracking admin operations."""
__tablename__ = "audit_logs"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
timestamp: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True)
# Who
actor: Mapped[str] = mapped_column(String(100), nullable=False)
# What
action: Mapped[str] = mapped_column(String(50), nullable=False) # create, update, delete
# Which
resource: Mapped[str] = mapped_column(String(50), nullable=False) # provider, key, project
resource_id: Mapped[str | None] = mapped_column(String(36))
# Changes
changes: Mapped[str | None] = mapped_column(Text) # JSON with before/after
# Request info
ip_address: Mapped[str | None] = mapped_column(String(50))
user_agent: Mapped[str | None] = mapped_column(String(500))
class RateLimitCounter(Base):
"""Rate limit counter for tracking usage windows."""
__tablename__ = "rate_limit_counters"
key: Mapped[str] = mapped_column(String(200), primary_key=True)
request_count: Mapped[int] = mapped_column(Integer, default=0)
token_count: Mapped[int] = mapped_column(BigInteger, default=0)
window_start: Mapped[datetime] = mapped_column(DateTime, nullable=False)
window_duration: Mapped[int] = mapped_column(Integer, nullable=False) # seconds


@ -0,0 +1 @@
# schemas module


@ -0,0 +1,107 @@
"""Anthropic Messages API request and response schemas."""
from typing import Any, Literal
from pydantic import BaseModel, Field
# === Request Models ===
class AnthropicContentBlock(BaseModel):
"""Anthropic content block."""
type: Literal["text", "image", "tool_use", "tool_result"]
text: str | None = None
source: dict[str, Any] | None = None # For images
tool_use_id: str | None = None # For tool_result
content: str | list[dict[str, Any]] | None = None # For tool_result
name: str | None = None # For tool_use
input: dict[str, Any] | None = None # For tool_use
class AnthropicMessage(BaseModel):
"""Anthropic message."""
role: Literal["user", "assistant"]
content: str | list[AnthropicContentBlock]
class AnthropicMessagesRequest(BaseModel):
"""Anthropic Messages API request."""
model: str
messages: list[AnthropicMessage]
system: str | None = None
max_tokens: int = Field(..., ge=1)
temperature: float | None = Field(None, ge=0, le=1)
top_p: float | None = Field(None, ge=0, le=1)
top_k: int | None = Field(None, ge=0)
stop_sequences: list[str] | None = None
stream: bool = False
tools: list[dict[str, Any]] | None = None
tool_choice: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
# === Response Models ===
class AnthropicTextBlock(BaseModel):
"""Anthropic text content block."""
type: Literal["text"] = "text"
text: str
class AnthropicToolUseBlock(BaseModel):
"""Anthropic tool use content block."""
type: Literal["tool_use"] = "tool_use"
id: str
name: str
input: dict[str, Any]
class AnthropicUsage(BaseModel):
"""Anthropic token usage."""
input_tokens: int
output_tokens: int
cache_creation_input_tokens: int | None = None
cache_read_input_tokens: int | None = None
class AnthropicMessagesResponse(BaseModel):
"""Anthropic Messages API response."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[AnthropicTextBlock | AnthropicToolUseBlock]
model: str
stop_reason: Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None = None
stop_sequence: str | None = None
usage: AnthropicUsage
class AnthropicStreamMessage(BaseModel):
"""Anthropic streaming message start."""
type: Literal["message_start"] = "message_start"
message: AnthropicMessagesResponse
class AnthropicContentBlockDelta(BaseModel):
"""Anthropic streaming content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: dict[str, Any]
class AnthropicMessageDelta(BaseModel):
"""Anthropic streaming message delta."""
type: Literal["message_delta"] = "message_delta"
delta: dict[str, Any]
usage: AnthropicUsage | None = None


@ -0,0 +1,71 @@
"""API Key schemas for API requests and responses."""
from datetime import datetime
from decimal import Decimal
from pydantic import BaseModel, Field
from app.schemas.common import BaseResponse
class APIKeyCreate(BaseModel):
"""Schema for creating an API key."""
name: str = Field(..., min_length=1, max_length=100)
project_id: str | None = Field(None, description="Associated project ID")
prefix: str = Field(default="sk", max_length=10, description="Key prefix")
expires_at: datetime | None = Field(None, description="Expiration date")
rpm_limit: int | None = Field(None, ge=1)
tpm_limit: int | None = Field(None, ge=1)
budget_limit: Decimal | None = Field(None, ge=0)
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
allowed_models: list[str] | None = Field(None, description="List of allowed model aliases")
enabled: bool = Field(True)
class APIKeyUpdate(BaseModel):
"""Schema for updating an API key."""
name: str | None = Field(None, min_length=1, max_length=100)
project_id: str | None = None
expires_at: datetime | None = None
rpm_limit: int | None = Field(None, ge=1)
tpm_limit: int | None = Field(None, ge=1)
budget_limit: Decimal | None = Field(None, ge=0)
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
allowed_models: list[str] | None = None
enabled: bool | None = None
class APIKeyResponse(BaseResponse):
"""Schema for API key response."""
id: str
key_prefix: str
name: str
project_id: str | None
enabled: bool
expires_at: datetime | None
rpm_limit: int | None
tpm_limit: int | None
budget_limit: Decimal | None
budget_period: str | None
allowed_models: list[str] | None
current_usage: Decimal
total_requests: int
created_at: datetime
updated_at: datetime
class APIKeyCreateResponse(APIKeyResponse):
"""Schema for API key creation response (includes the full key)."""
key: str = Field(..., description="Full API key (shown only once)")
class APIKeyListResponse(BaseResponse):
"""Schema for API key list response."""
keys: list[APIKeyResponse]
total: int
page: int
page_size: int


@ -0,0 +1,30 @@
"""Common schemas for API requests and responses."""
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class BaseResponse(BaseModel):
"""Base response model."""
model_config = ConfigDict(from_attributes=True)
class ErrorResponse(BaseModel):
"""Error response model."""
error: dict[str, Any]
class PaginationParams(BaseModel):
"""Pagination parameters."""
page: int = Field(default=1, ge=1)
page_size: int = Field(default=20, ge=1, le=100)
@property
def offset(self) -> int:
"""Calculate offset for database query."""
return (self.page - 1) * self.page_size
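
The `offset` property is plain arithmetic over 1-based page numbers. A small worked sketch follows, using a hypothetical `page_bounds` helper that also derives the page count a list endpoint might report alongside `total`.

```python
def page_bounds(page: int, page_size: int, total: int) -> tuple[int, int, int]:
    """Return (offset, limit, total_pages) for a 1-based page number."""
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be >= 1")
    offset = (page - 1) * page_size
    total_pages = -(-total // page_size)  # ceiling division without floats
    return offset, page_size, total_pages
```

With `page=3`, `page_size=20`, and 45 rows, the query skips 40 rows, reads up to 20, and there are 3 pages in total.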


@ -0,0 +1,71 @@
"""Model alias schemas for API requests and responses."""
from datetime import datetime
from decimal import Decimal
from typing import Any
from pydantic import BaseModel, Field
from app.schemas.common import BaseResponse
class RoutingConfig(BaseModel):
"""Routing configuration for load balancing and fallback."""
providers: list[dict[str, Any]] | None = Field(
None, description="List of providers with weights for load balancing"
)
primary: dict[str, str] | None = Field(None, description="Primary provider for fallback")
fallback: list[dict[str, str]] | None = Field(None, description="Fallback providers")
class ModelAliasCreate(BaseModel):
"""Schema for creating a model alias."""
alias: str = Field(..., min_length=1, max_length=100, description="User-facing model name")
provider: str = Field(..., min_length=1, max_length=50, description="Provider name")
model: str = Field(..., min_length=1, max_length=100, description="Actual model name")
routing_type: str = Field(
default="simple", pattern="^(simple|load_balance|fallback)$"
)
routing_config: RoutingConfig | None = None
input_price_per_1k: Decimal | None = Field(None, ge=0)
output_price_per_1k: Decimal | None = Field(None, ge=0)
enabled: bool = Field(True)
class ModelAliasUpdate(BaseModel):
"""Schema for updating a model alias."""
alias: str | None = Field(None, min_length=1, max_length=100)
provider: str | None = Field(None, min_length=1, max_length=50)
model: str | None = Field(None, min_length=1, max_length=100)
routing_type: str | None = Field(None, pattern="^(simple|load_balance|fallback)$")
routing_config: RoutingConfig | None = None
input_price_per_1k: Decimal | None = Field(None, ge=0)
output_price_per_1k: Decimal | None = Field(None, ge=0)
enabled: bool | None = None
class ModelAliasResponse(BaseResponse):
"""Schema for model alias response."""
id: str
alias: str
provider: str
model: str
enabled: bool
routing_type: str
routing_config: dict[str, Any] | None
input_price_per_1k: Decimal | None
output_price_per_1k: Decimal | None
created_at: datetime
updated_at: datetime
class ModelAliasListResponse(BaseResponse):
"""Schema for model alias list response."""
aliases: list[ModelAliasResponse]
total: int
page: int
page_size: int


@ -0,0 +1,100 @@
"""OpenAI API request and response schemas."""
from typing import Any, Literal
from pydantic import BaseModel, Field
# === Request Models ===
class OpenAIChatMessage(BaseModel):
"""OpenAI chat message."""
role: Literal["system", "user", "assistant", "tool", "function"]
content: str | list[dict[str, Any]] | None = None
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
tool_call_id: str | None = None
class OpenAIChatCompletionRequest(BaseModel):
"""OpenAI Chat Completions API request."""
model: str
messages: list[OpenAIChatMessage]
temperature: float | None = Field(None, ge=0, le=2)
top_p: float | None = Field(None, ge=0, le=1)
n: int | None = Field(None, ge=1)
stream: bool = False
stop: str | list[str] | None = None
max_tokens: int | None = Field(None, ge=1)
presence_penalty: float | None = Field(None, ge=-2, le=2)
frequency_penalty: float | None = Field(None, ge=-2, le=2)
logit_bias: dict[str, float] | None = None
user: str | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
response_format: dict[str, Any] | None = None
class OpenAIResponseRequest(BaseModel):
"""OpenAI Responses API request (new format)."""
model: str
input: str | list[dict[str, Any]]
instructions: str | None = None
temperature: float | None = Field(None, ge=0, le=2)
max_output_tokens: int | None = Field(None, ge=1)
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
# === Response Models ===
class OpenAIChatCompletionChoice(BaseModel):
"""OpenAI chat completion choice."""
index: int
message: OpenAIChatMessage
finish_reason: str | None
class OpenAIUsage(BaseModel):
"""Token usage information."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class OpenAIChatCompletionResponse(BaseModel):
"""OpenAI Chat Completions API response."""
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
choices: list[OpenAIChatCompletionChoice]
usage: OpenAIUsage | None = None
system_fingerprint: str | None = None
class OpenAIStreamChoice(BaseModel):
"""OpenAI streaming choice."""
index: int
delta: dict[str, Any]
finish_reason: str | None
class OpenAIChatCompletionChunk(BaseModel):
"""OpenAI Chat Completions streaming chunk."""
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
choices: list[OpenAIStreamChoice]
system_fingerprint: str | None = None


@ -0,0 +1,49 @@
"""Project schemas for API requests and responses."""
from datetime import datetime
from decimal import Decimal
from pydantic import BaseModel, Field
from app.schemas.common import BaseResponse
class ProjectCreate(BaseModel):
"""Schema for creating a project."""
name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=1000)
budget_limit: Decimal | None = Field(None, ge=0)
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
enabled: bool = Field(True)
class ProjectUpdate(BaseModel):
"""Schema for updating a project."""
name: str | None = Field(None, min_length=1, max_length=100)
description: str | None = Field(None, max_length=1000)
budget_limit: Decimal | None = Field(None, ge=0)
budget_period: str | None = Field(None, pattern="^(daily|weekly|monthly)$")
enabled: bool | None = None
class ProjectResponse(BaseResponse):
"""Schema for project response."""
id: str
name: str
description: str | None
budget_limit: Decimal | None
budget_period: str | None
enabled: bool
created_at: datetime
updated_at: datetime
class ProjectListResponse(BaseResponse):
"""Schema for project list response."""
projects: list[ProjectResponse]
total: int
page: int
page_size: int


@ -0,0 +1,58 @@
"""Provider schemas for API requests and responses."""

from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field

from app.schemas.common import BaseResponse


class ProviderCreate(BaseModel):
    """Schema for creating a provider."""

    name: str = Field(..., min_length=1, max_length=50, description="Provider name")
    api_base: str = Field(..., min_length=1, max_length=500, description="API base URL")
    api_key: str = Field(..., min_length=1, description="API key (will be encrypted)")
    api_version: str | None = Field(None, max_length=50, description="API version")
    config: dict[str, Any] | None = Field(None, description="Provider-specific configuration")
    rpm_limit: int | None = Field(None, ge=1, description="Requests per minute limit")
    tpm_limit: int | None = Field(None, ge=1, description="Tokens per minute limit")
    enabled: bool = Field(True, description="Whether provider is enabled")


class ProviderUpdate(BaseModel):
    """Schema for updating a provider. Every field is optional; name cannot be changed."""

    api_base: str | None = Field(None, min_length=1, max_length=500)
    api_key: str | None = Field(None, min_length=1)
    api_version: str | None = Field(None, max_length=50)
    config: dict[str, Any] | None = None
    rpm_limit: int | None = Field(None, ge=1)
    tpm_limit: int | None = Field(None, ge=1)
    enabled: bool | None = None


class ProviderResponse(BaseResponse):
    """Schema for provider response. The API key is deliberately not exposed."""

    id: str
    name: str
    enabled: bool
    api_base: str
    api_version: str | None
    config: dict[str, Any] | None
    rpm_limit: int | None
    tpm_limit: int | None
    health_status: str
    last_health_check: datetime | None
    created_at: datetime
    updated_at: datetime


class ProviderListResponse(BaseResponse):
    """Schema for provider list response."""

    providers: list[ProviderResponse]
    total: int
    page: int
    page_size: int


@ -0,0 +1 @@
# utils module


@ -0,0 +1,123 @@
"""Cryptographic utilities for key management."""

import base64
import os

import bcrypt
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from app.config import get_settings


def hash_api_key(key: str) -> str:
    """
    Hash an API key using bcrypt.

    Args:
        key: The plaintext API key to hash.

    Returns:
        The hashed key string.
    """
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(key.encode("utf-8"), salt)
    return hashed.decode("utf-8")


def verify_api_key(key: str, hashed: str) -> bool:
    """
    Verify an API key against its hash.

    Args:
        key: The plaintext API key to verify.
        hashed: The hashed key to compare against.

    Returns:
        True if the key matches, False otherwise.
    """
    try:
        return bcrypt.checkpw(key.encode("utf-8"), hashed.encode("utf-8"))
    except Exception:
        # A malformed hash is treated as a failed verification
        return False


def generate_api_key(prefix: str = "sk") -> tuple[str, str, str]:
    """
    Generate a new API key.

    Args:
        prefix: The prefix for the key (e.g., "sk", "sk-proj").

    Returns:
        Tuple of (full_key, key_hash, key_prefix).
    """
    # Generate random key part
    random_part = base64.urlsafe_b64encode(os.urandom(24)).decode("utf-8").rstrip("=")
    full_key = f"{prefix}_{random_part}"

    # Hash the key
    key_hash = hash_api_key(full_key)

    # Create prefix for display (first 12 chars)
    key_prefix = full_key[:12] + "..."

    return full_key, key_hash, key_prefix


def get_encryption_key() -> bytes:
    """
    Derive an encryption key from the master key.

    Returns:
        A URL-safe base64-encoded 32-byte key for Fernet.
    """
    settings = get_settings()
    master_key = settings.master_key.encode("utf-8")

    # Use a fixed salt for deterministic key derivation
    # In production, consider using a random salt stored securely
    salt = b"llm-gateway-encryption-salt-v1"

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=100000,
    )
    key = base64.urlsafe_b64encode(kdf.derive(master_key))
    return key


def encrypt_value(plaintext: str) -> str:
    """
    Encrypt a value using Fernet symmetric encryption.

    Args:
        plaintext: The value to encrypt.

    Returns:
        The encrypted value as a base64 string.
    """
    key = get_encryption_key()
    f = Fernet(key)
    encrypted = f.encrypt(plaintext.encode("utf-8"))
    return encrypted.decode("utf-8")


def decrypt_value(encrypted: str) -> str:
    """
    Decrypt a value using Fernet symmetric encryption.

    Args:
        encrypted: The encrypted value as a base64 string.

    Returns:
        The decrypted plaintext.
    """
    key = get_encryption_key()
    f = Fernet(key)
    decrypted = f.decrypt(encrypted.encode("utf-8"))
    return decrypted.decode("utf-8")
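The key derivation and key layout used in this module can be reproduced with the standard library alone, which is handy for checking the behavior without installing `cryptography` or `bcrypt`. This is a minimal sketch: `derive_fernet_key` and `new_api_key` are hypothetical names, and the assumption is that `hashlib.pbkdf2_hmac` with SHA-256 yields the same bytes as `PBKDF2HMAC` given identical inputs, since both implement standard PBKDF2.

```python
import base64
import hashlib
import os


def derive_fernet_key(master_key: str, salt: bytes, iterations: int = 100_000) -> bytes:
    """Derive 32 bytes with PBKDF2-HMAC-SHA256 and encode them as URL-safe base64."""
    raw = hashlib.pbkdf2_hmac(
        "sha256", master_key.encode("utf-8"), salt, iterations, dklen=32
    )
    return base64.urlsafe_b64encode(raw)


def new_api_key(prefix: str = "sk") -> tuple[str, str]:
    """Return (full_key, display_prefix) with the same layout as generate_api_key."""
    # 24 random bytes encode to exactly 32 base64 characters with no padding
    random_part = base64.urlsafe_b64encode(os.urandom(24)).decode("utf-8").rstrip("=")
    full_key = f"{prefix}_{random_part}"
    return full_key, full_key[:12] + "..."
```

Because the salt is fixed, the derived key depends only on `MASTER_KEY`, so stored ciphertexts stay decryptable across restarts but become unreadable if the master key changes. Note also that Fernet splits the 32-byte key into a 16-byte signing key and a 16-byte AES-128-CBC encryption key.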


@ -0,0 +1,65 @@
"""Structured logging configuration."""

import logging
import sys

import structlog
from structlog.types import Processor

from app.config import get_settings


def setup_logging() -> None:
    """Configure structured logging for the application."""
    settings = get_settings()
    # Normalize once so "debug" and "DEBUG" behave the same everywhere below
    log_level = settings.log_level.upper()

    # Shared processors for both stdlib and structlog
    shared_processors: list[Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
    ]

    # Configure structlog
    structlog.configure(
        processors=shared_processors
        + [
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Configure stdlib logging
    handler = logging.StreamHandler(sys.stdout)

    if settings.debug or log_level == "DEBUG":
        formatter = structlog.stdlib.ProcessorFormatter(
            foreign_pre_chain=shared_processors,
            processors=[
                # Drop the internal _record/_from_structlog keys before rendering
                structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                structlog.dev.ConsoleRenderer(colors=True),
            ],
        )
    else:
        formatter = structlog.stdlib.ProcessorFormatter(
            foreign_pre_chain=shared_processors,
            processors=[
                # Drop the internal _record/_from_structlog keys before rendering
                structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                structlog.processors.format_exc_info,
                structlog.processors.JSONRenderer(),
            ],
        )

    handler.setFormatter(formatter)

    # Clear existing handlers and add our handler
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    root_logger.setLevel(getattr(logging, log_level, logging.INFO))


def get_logger(name: str | None = None) -> structlog.stdlib.BoundLogger:
    """Get a structured logger instance."""
    return structlog.get_logger(name)
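The `getattr(logging, ..., logging.INFO)` lookup used for the root logger level falls back to INFO only when the attribute does not exist at all. A slightly stricter variant, sketched here with a hypothetical `resolve_log_level` helper that is not part of this commit, also rejects names that resolve to non-integer attributes of the `logging` module (for example `BASIC_FORMAT`, which is a string).

```python
import logging


def resolve_log_level(name: str, default: int = logging.INFO) -> int:
    """Map a level name such as "debug" to its numeric value, else return default."""
    level = getattr(logging, name.upper(), default)
    # Some uppercase attributes of the logging module are not levels at all
    return level if isinstance(level, int) else default
```

With this helper, `root_logger.setLevel(resolve_log_level(settings.log_level))` would behave the same for valid level names and degrade to INFO for anything else.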


@ -0,0 +1,33 @@
version: '3.8'

services:
  gateway:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./data:/app/data
    environment:
      - DATABASE_URL=sqlite:///data/gateway.db
      - MASTER_KEY=${MASTER_KEY}
      - DEBUG=${DEBUG:-false}
      - LOG_LEVEL=${LOG_LEVEL:-INFO}
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 10s

  # Optional: Redis for distributed rate limiting (future)
  # redis:
  #   image: redis:7-alpine
  #   ports:
  #     - "6379:6379"
  #   volumes:
  #     - redis_data:/data

volumes:
  gateway_data:
  # redis_data:


@ -0,0 +1,37 @@
# Web Framework
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
pydantic>=2.5.0
pydantic-settings>=2.1.0

# HTTP Client
httpx>=0.26.0

# Database
sqlalchemy>=2.0.0
aiosqlite>=0.19.0

# Security
bcrypt>=4.1.0
cryptography>=42.0.0

# Logging
structlog>=24.1.0

# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.0
pytest-cov>=4.1.0
# httpx (used by the test client) is already listed under HTTP Client

# Type Checking
mypy>=1.8.0

# Code Quality
ruff>=0.2.0

# Provider SDKs
openai>=1.12.0
anthropic>=0.18.0
google-generativeai>=0.3.0
boto3>=1.34.0


@ -0,0 +1 @@
# tests module


@ -0,0 +1,106 @@
"""Pytest configuration and fixtures."""

import asyncio
import os
import tempfile
from typing import AsyncGenerator, Generator

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# Set test environment before importing app
os.environ["TESTING"] = "1"

# Create a temp file for the test database
_test_db_fd, _test_db_path = tempfile.mkstemp(suffix=".db")
os.environ["DATABASE_URL"] = f"sqlite:///{_test_db_path}"
os.environ["MASTER_KEY"] = "test-master-key-for-testing-32cha"

# Now import the app modules
from app.db import database  # noqa: E402
from app.db.database import Base, get_db  # noqa: E402

# Import all models to register them with Base
from app.models import (  # noqa: E402, F401
    APIKey,
    AuditLog,
    ModelAlias,
    Project,
    Provider,
    RateLimitCounter,
    RequestLog,
    UsageStatsHourly,
)

# Create test engine
_test_db_url = f"sqlite+aiosqlite:///{_test_db_path}"
_test_engine = create_async_engine(_test_db_url, echo=False)
_test_session_factory = async_sessionmaker(
    _test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Override the global engine and session factory
database.engine = _test_engine
database.AsyncSessionLocal = _test_session_factory


@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Create an event loop for the test session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest_asyncio.fixture(scope="session")
async def setup_test_db() -> AsyncGenerator[None, None]:
    """Set up test database with all tables."""
    # Create all tables
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    yield

    # Clean up
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await _test_engine.dispose()
    os.close(_test_db_fd)
    os.unlink(_test_db_path)


@pytest_asyncio.fixture
async def db_session(setup_test_db: None) -> AsyncGenerator[AsyncSession, None]:
    """Provide a database session for tests."""
    async with _test_session_factory() as session:
        yield session


async def _get_test_db() -> AsyncGenerator[AsyncSession, None]:
    """Override dependency for test database."""
    async with _test_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


@pytest_asyncio.fixture
async def client(setup_test_db: None) -> AsyncGenerator[AsyncClient, None]:
    """Provide an async HTTP client for API tests."""
    from app.main import create_app

    test_app = create_app()

    # Override the database dependency
    test_app.dependency_overrides[get_db] = _get_test_db

    transport = ASGITransport(app=test_app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
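The temp-file lifecycle in this conftest (create with `mkstemp`, point the database URL at it, delete on teardown) can be tried in isolation with the standard library. The following is a sketch under that assumption; `temp_sqlite_db` and `list_tables` are hypothetical helpers, and it uses the synchronous `sqlite3` module rather than the async engine the fixtures rely on.

```python
import os
import sqlite3
import tempfile
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temp_sqlite_db() -> Iterator[str]:
    """Yield the path of a throwaway SQLite file and delete it afterwards."""
    fd, path = tempfile.mkstemp(suffix=".db")
    os.close(fd)  # SQLite reopens the file by path, so the descriptor can go
    try:
        yield path
    finally:
        if os.path.exists(path):
            os.unlink(path)


def list_tables(path: str) -> list[str]:
    """Return the table names stored in the SQLite file at `path`, sorted."""
    conn = sqlite3.connect(path)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        return [row[0] for row in rows]
    finally:
        conn.close()
```

Closing the descriptor immediately, rather than holding it until teardown as the session fixture does, avoids leaking it if teardown never runs.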


@ -0,0 +1 @@
# integration module


@ -0,0 +1,237 @@
"""Tests for Admin API endpoints."""

import pytest
from httpx import AsyncClient


class TestProviderAPI:
    """Test Provider management API."""

    @pytest.mark.asyncio
    async def test_create_provider(self, client: AsyncClient, setup_test_db):
        """Test creating a provider."""
        response = await client.post(
            "/admin/providers",
            json={
                "name": "openai",
                "api_base": "https://api.openai.com/v1",
                "api_key": "test-api-key-1234567890",
                "api_version": "v1",
                "rpm_limit": 1000,
                "tpm_limit": 100000,
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "openai"
        assert data["api_base"] == "https://api.openai.com/v1"
        assert data["health_status"] == "healthy"

    @pytest.mark.asyncio
    async def test_list_providers(self, client: AsyncClient, setup_test_db):
        """Test listing providers."""
        # Create a provider first
        await client.post(
            "/admin/providers",
            json={
                "name": "anthropic",
                "api_base": "https://api.anthropic.com",
                "api_key": "test-key",
            },
        )

        response = await client.get("/admin/providers")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        assert len(data["providers"]) >= 1

    @pytest.mark.asyncio
    async def test_get_provider(self, client: AsyncClient, setup_test_db):
        """Test getting a specific provider."""
        create_response = await client.post(
            "/admin/providers",
            json={
                "name": "gemini",
                "api_base": "https://generativelanguage.googleapis.com",
                "api_key": "test-key",
            },
        )
        provider_id = create_response.json()["id"]

        response = await client.get(f"/admin/providers/{provider_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "gemini"

    @pytest.mark.asyncio
    async def test_update_provider(self, client: AsyncClient, setup_test_db):
        """Test updating a provider."""
        create_response = await client.post(
            "/admin/providers",
            json={
                "name": "azure",
                "api_base": "https://example.openai.azure.com",
                "api_key": "test-key",
            },
        )
        provider_id = create_response.json()["id"]

        response = await client.put(
            f"/admin/providers/{provider_id}",
            json={"rpm_limit": 500, "enabled": False},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["rpm_limit"] == 500
        assert data["enabled"] is False

    @pytest.mark.asyncio
    async def test_delete_provider(self, client: AsyncClient, setup_test_db):
        """Test deleting a provider."""
        create_response = await client.post(
            "/admin/providers",
            json={
                "name": "bedrock",
                "api_base": "https://bedrock.us-east-1.amazonaws.com",
                "api_key": "test-key",
            },
        )
        provider_id = create_response.json()["id"]

        response = await client.delete(f"/admin/providers/{provider_id}")
        assert response.status_code == 204

        # Verify it's deleted
        get_response = await client.get(f"/admin/providers/{provider_id}")
        assert get_response.status_code == 404


class TestProjectAPI:
    """Test Project management API."""

    @pytest.mark.asyncio
    async def test_create_project(self, client: AsyncClient, setup_test_db):
        """Test creating a project."""
        response = await client.post(
            "/admin/projects",
            json={
                "name": "Test Project",
                "description": "A test project",
                "budget_limit": 100.00,
                "budget_period": "monthly",
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "Test Project"
        assert data["budget_limit"] == "100.00"

    @pytest.mark.asyncio
    async def test_list_projects(self, client: AsyncClient, setup_test_db):
        """Test listing projects."""
        await client.post(
            "/admin/projects",
            json={"name": "Project 1"},
        )

        response = await client.get("/admin/projects")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1


class TestAPIKeyAPI:
    """Test API Key management API."""

    @pytest.mark.asyncio
    async def test_create_api_key(self, client: AsyncClient, setup_test_db):
        """Test creating an API key."""
        response = await client.post(
            "/admin/keys",
            json={
                "name": "Test Key",
                "prefix": "sk-test",
                "rpm_limit": 100,
                "tpm_limit": 10000,
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "Test Key"
        assert "key" in data  # Full key shown on creation
        assert data["key"].startswith("sk-test_")
        assert data["key_prefix"].startswith("sk-test_")

    @pytest.mark.asyncio
    async def test_list_api_keys(self, client: AsyncClient, setup_test_db):
        """Test listing API keys."""
        await client.post(
            "/admin/keys",
            json={"name": "Key 1"},
        )

        response = await client.get("/admin/keys")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1


class TestModelAliasAPI:
    """Test Model Alias management API."""

    @pytest.mark.asyncio
    async def test_create_model_alias(self, client: AsyncClient, setup_test_db):
        """Test creating a model alias."""
        response = await client.post(
            "/admin/models/aliases",
            json={
                "alias": "gpt-4",
                "provider": "openai",
                "model": "gpt-4-turbo",
                "input_price_per_1k": 0.01,
                "output_price_per_1k": 0.03,
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert data["alias"] == "gpt-4"
        assert data["provider"] == "openai"
        assert data["model"] == "gpt-4-turbo"

    @pytest.mark.asyncio
    async def test_list_model_aliases(self, client: AsyncClient, setup_test_db):
        """Test listing model aliases."""
        await client.post(
            "/admin/models/aliases",
            json={
                "alias": "claude",
                "provider": "anthropic",
                "model": "claude-3-sonnet",
            },
        )

        response = await client.get("/admin/models/aliases")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1


class TestHealthAPI:
    """Test Health check API."""

    @pytest.mark.asyncio
    async def test_health_check(self, client: AsyncClient, setup_test_db):
        """Test health check endpoint."""
        response = await client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"

    @pytest.mark.asyncio
    async def test_readiness_check(self, client: AsyncClient, setup_test_db):
        """Test readiness check endpoint."""
        response = await client.get("/ready")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ready"
        assert data["database"] == "connected"


@ -0,0 +1 @@
# unit module


@ -0,0 +1,44 @@
"""Tests for configuration module."""

from pathlib import Path

import pytest

from app.config import Settings


class TestSettings:
    """Test configuration settings."""

    def test_default_settings(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
        """Test default configuration values."""
        # Change to temp directory for database path
        monkeypatch.chdir(tmp_path)

        settings = Settings()

        assert settings.app_name == "LLM Gateway"
        assert settings.debug is False
        assert settings.database_url.startswith("sqlite:///")
        assert settings.master_key is not None

    def test_custom_settings_from_env(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
        """Test loading settings from environment variables."""
        monkeypatch.chdir(tmp_path)
        monkeypatch.setenv("APP_NAME", "Custom Gateway")
        monkeypatch.setenv("DEBUG", "true")
        monkeypatch.setenv("DATABASE_URL", "sqlite:///custom.db")
        monkeypatch.setenv("MASTER_KEY", "test-master-key-12345678901234567890")  # 36 chars

        settings = Settings()

        assert settings.app_name == "Custom Gateway"
        assert settings.debug is True
        assert settings.database_url == "sqlite:///custom.db"
        assert settings.master_key == "test-master-key-12345678901234567890"

    def test_generate_master_key(self):
        """Test master key generation."""
        key = Settings.generate_master_key()
        assert len(key) == 64  # 32 bytes = 64 hex chars
        assert isinstance(key, str)
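The 64-character expectation in `test_generate_master_key` matches the command suggested in `.env.example`, `secrets.token_hex(32)`, which hex-encodes 32 random bytes. The sketch below assumes `Settings.generate_master_key` works that way; the function names here are hypothetical stand-ins, and the 32-character minimum is taken from the `.env.example` placeholder text rather than from the settings code.

```python
import secrets
import string


def generate_master_key(num_bytes: int = 32) -> str:
    """Return a random key as hex; each byte becomes two hex characters."""
    return secrets.token_hex(num_bytes)


def is_plausible_master_key(key: str, min_length: int = 32) -> bool:
    """Check only the length requirement (assumed, not taken from Settings)."""
    return len(key) >= min_length


def is_hex(value: str) -> bool:
    """True if every character is a hexadecimal digit."""
    return bool(value) and all(ch in string.hexdigits for ch in value)
```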


@ -0,0 +1,34 @@
"""Tests for database module."""

import pytest
from sqlalchemy import text

from app.db.database import get_db


class TestDatabase:
    """Test database initialization and connections."""

    @pytest.mark.asyncio
    async def test_init_db_creates_tables(self, db_session):
        """Test that all required tables exist after database setup."""
        # Tables are created by the setup_test_db fixture behind db_session
        assert db_session is not None

        # Look up every table name recorded in the SQLite schema
        result = await db_session.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
        tables = [row[0] for row in result.fetchall()]

        assert "providers" in tables
        assert "api_keys" in tables
        assert "projects" in tables
        assert "model_aliases" in tables
        assert "request_logs" in tables

    @pytest.mark.asyncio
    async def test_get_db_returns_session(self, setup_test_db):
        """Test that get_db returns a valid database session."""
        async for session in get_db():
            assert session is not None
            break