Implement a unified LLM Gateway supporting multiple API formats and providers: Features: - OpenAI Chat Completions, Responses API, and Anthropic Messages API - Provider adapters for OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock - Model aliasing with weighted round-robin load balancing - Virtual API keys with RPM/TPM rate limiting - Budget control at key and project levels - Request logging, usage statistics, and audit logs - Fallback/retry with circuit breaker pattern - Admin CRUD APIs for providers, projects, keys, models, usage - Provider health checks Tech stack: - FastAPI with async SQLAlchemy 2.0 - SQLite with aiosqlite - bcrypt for API key hashing, AES-256 for provider key encryption - Docker containerization Tests: 18 passing integration tests for admin API endpoints Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
174 lines
5.9 KiB
Python
174 lines
5.9 KiB
Python
"""Router for model alias resolution and routing."""
|
|
import json
|
|
import random
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.model_alias import ModelAlias
|
|
from app.models.provider import Provider
|
|
|
|
|
|
@dataclass
|
|
class RoutingResult:
|
|
"""Result of routing decision."""
|
|
|
|
provider: str
|
|
model: str
|
|
provider_config: dict[str, Any] | None = None
|
|
fallback_chain: list[dict[str, str]] | None = None
|
|
|
|
|
|
class Router:
    """Router for resolving model aliases and making routing decisions."""

    def __init__(self, db: AsyncSession):
        """Keep the async database session used for alias and provider lookups."""
        self.db = db

    async def resolve_model(self, model_alias: str) -> RoutingResult:
        """
        Resolve a model alias to a provider and model.

        An enabled ``ModelAlias`` row is looked up first. When none exists,
        the value is interpreted as a direct ``"provider/model"`` reference
        (e.g. ``"openai/gpt-4"``); only the first ``/`` separates the two
        parts, so the model part may itself contain slashes.

        Args:
            model_alias: The model alias (or direct reference) to resolve.

        Returns:
            RoutingResult with provider, model, and optional fallback chain.

        Raises:
            ValueError: If the alias is not found and the value is not a
                well-formed ``"provider/model"`` reference, or if the stored
                routing config is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass).
        """
        result = await self.db.execute(
            select(ModelAlias).where(
                ModelAlias.alias == model_alias,
                ModelAlias.enabled == True,  # noqa: E712 - builds a SQL expression
            )
        )
        alias = result.scalar_one_or_none()

        if alias is None:
            provider, separator, model = model_alias.partition("/")
            # Require both halves: "openai/" or "/gpt-4" would otherwise
            # yield a RoutingResult with an empty provider or model.
            if separator and provider and model:
                return RoutingResult(provider=provider, model=model)
            raise ValueError(f"Model alias '{model_alias}' not found")

        # The config is stored as JSON text. Anything that does not decode to
        # an object (null, list, scalar) is treated as "no config", so the
        # alias' own provider/model are used instead of failing on `.get`.
        parsed = json.loads(alias.routing_config) if alias.routing_config else None
        routing_config: dict[str, Any] = parsed if isinstance(parsed, dict) else {}

        if alias.routing_type == "load_balance":
            return await self._route_load_balance(alias, routing_config)
        if alias.routing_type == "fallback":
            return self._route_fallback(alias, routing_config)

        # "simple" and any unrecognised routing type use the alias directly.
        return RoutingResult(provider=alias.provider, model=alias.model)

    async def _route_load_balance(
        self, alias: ModelAlias, routing_config: dict[str, Any]
    ) -> RoutingResult:
        """Pick one healthy provider at random, proportionally to its weight."""
        default = RoutingResult(provider=alias.provider, model=alias.model)

        providers = routing_config.get("providers") or []
        if not providers:
            return default

        healthy_providers = await self._filter_healthy_providers(providers)
        if not healthy_providers:
            # Every configured provider is unhealthy: use the alias default.
            return default

        weights = [p.get("weight", 1) for p in healthy_providers]
        if sum(weights) > 0:
            chosen = random.choices(healthy_providers, weights=weights, k=1)[0]
        else:
            # random.choices rejects a non-positive total weight; use the
            # first healthy entry in that case.
            chosen = healthy_providers[0]

        return RoutingResult(
            provider=chosen["provider"],
            model=chosen.get("model", alias.model),
        )

    def _route_fallback(
        self, alias: ModelAlias, routing_config: dict[str, Any]
    ) -> RoutingResult:
        """Return the primary target together with its fallback chain."""
        # `or` also covers an explicit JSON null for either key.
        primary = routing_config.get("primary") or {}
        fallback = routing_config.get("fallback") or []

        return RoutingResult(
            provider=primary.get("provider", alias.provider),
            model=primary.get("model", alias.model),
            fallback_chain=fallback,
        )

    async def _filter_healthy_providers(
        self, providers: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Filter out unhealthy providers.

        An entry is kept when a ``Provider`` row with the same name exists,
        is enabled, and has ``health_status == "healthy"``. Entries without
        a ``"provider"`` value are dropped. The order of ``providers`` is
        preserved.
        """
        names = list({p["provider"] for p in providers if p.get("provider")})
        if not names:
            return []

        # One IN query for all names instead of one query per entry.
        result = await self.db.execute(
            select(Provider.name).where(
                Provider.name.in_(names),
                Provider.enabled == True,  # noqa: E712 - builds a SQL expression
                Provider.health_status == "healthy",
            )
        )
        healthy_names = set(result.scalars().all())

        return [p for p in providers if p.get("provider") in healthy_names]

    async def get_fallback_provider(
        self, failed_provider: str, fallback_chain: list[dict[str, str]] | None
    ) -> RoutingResult | None:
        """
        Get the next fallback provider after a failure.

        If ``failed_provider`` appears in the chain, the entry after its
        first occurrence is returned, and the result's ``fallback_chain`` is
        the rest of the chain starting at that entry. If it does not appear
        (for example, it was the primary), the first entry is returned with
        the whole chain. The chain is never restarted from the beginning:
        wrapping around would hand back providers that already failed and
        could cycle indefinitely.

        Args:
            failed_provider: The provider that failed.
            fallback_chain: List of fallback providers.

        Returns:
            RoutingResult for the next provider, or None if no more fallbacks.
        """
        if not fallback_chain:
            return None

        for index, entry in enumerate(fallback_chain):
            if entry.get("provider") != failed_provider:
                continue

            remaining = fallback_chain[index + 1 :]
            if not remaining:
                # The failed provider was the last entry: chain exhausted.
                return None

            next_fallback = remaining[0]
            return RoutingResult(
                provider=next_fallback.get("provider", ""),
                model=next_fallback.get("model", ""),
                fallback_chain=remaining,
            )

        # The failed provider is not part of the chain, so start at its head.
        first = fallback_chain[0]
        return RoutingResult(
            provider=first.get("provider", ""),
            model=first.get("model", ""),
            fallback_chain=fallback_chain,
        )
|