root 8348520bdf feat: implement LLM Gateway with multi-provider support
Implement a unified LLM Gateway supporting multiple API formats and providers:

Features:
- OpenAI Chat Completions, Responses API, and Anthropic Messages API
- Provider adapters for OpenAI, Anthropic, Azure OpenAI, Google Gemini, AWS Bedrock
- Model aliasing with weighted round-robin load balancing
- Virtual API keys with RPM/TPM rate limiting
- Budget control at key and project levels
- Request logging, usage statistics, and audit logs
- Fallback/retry with circuit breaker pattern
- Admin CRUD APIs for providers, projects, keys, models, usage
- Provider health checks

Tech stack:
- FastAPI with async SQLAlchemy 2.0
- SQLite with aiosqlite
- bcrypt for API key hashing, AES-256 for provider key encryption
- Docker containerization

Tests: 18 passing integration tests for admin API endpoints

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-01 15:39:21 +08:00

174 lines
5.9 KiB
Python

"""Router for model alias resolution and routing."""
import json
import random
from dataclasses import dataclass
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.model_alias import ModelAlias
from app.models.provider import Provider
@dataclass
class RoutingResult:
"""Result of routing decision."""
provider: str
model: str
provider_config: dict[str, Any] | None = None
fallback_chain: list[dict[str, str]] | None = None
class Router:
    """Resolve model aliases to concrete provider/model routing decisions.

    Aliases are looked up in the database through the session passed to the
    constructor. An alias is resolved according to its ``routing_type``:
    ``"simple"`` (and any unrecognised type) uses the alias' own provider and
    model, ``"load_balance"`` picks a healthy provider at random in proportion
    to its weight, and ``"fallback"`` returns a primary target together with
    a fallback chain.
    """

    def __init__(self, db: AsyncSession):
        """Store the async database session used for alias and provider lookups."""
        self.db = db

    async def resolve_model(self, model_alias: str) -> RoutingResult:
        """
        Resolve a model alias to a provider and model.

        Args:
            model_alias: The model alias to resolve. When no enabled alias
                matches, a string of the form "provider/model"
                (e.g., "openai/gpt-4") is treated as a direct reference.

        Returns:
            RoutingResult with provider, model, and optional fallback chain.

        Raises:
            ValueError: If the model alias is not found and it is not a
                well-formed "provider/model" reference (both sides non-empty).
                A malformed JSON ``routing_config`` also surfaces as
                ``json.JSONDecodeError``, which is a ``ValueError`` subclass.
        """
        result = await self.db.execute(
            select(ModelAlias).where(
                ModelAlias.alias == model_alias,
                ModelAlias.enabled == True,  # noqa: E712 - SQL expression, not a Python bool test
            )
        )
        alias = result.scalar_one_or_none()

        if alias is None:
            # No alias found: accept a direct "provider/model" reference, but
            # only when both halves are present so an empty provider or model
            # is never handed on.
            provider, separator, model = model_alias.partition("/")
            if separator and provider and model:
                return RoutingResult(provider=provider, model=model)
            raise ValueError(f"Model alias '{model_alias}' not found")

        # The config is parsed for every routing type, so malformed JSON is
        # reported even for simple aliases.
        routing_config = json.loads(alias.routing_config) if alias.routing_config else None
        if not isinstance(routing_config, dict):
            # Missing or non-object config: every branch falls back to the
            # alias' own provider/model.
            routing_config = {}

        if alias.routing_type == "load_balance":
            return await self._resolve_load_balanced(alias, routing_config)
        if alias.routing_type == "fallback":
            return self._resolve_with_fallback(alias, routing_config)

        # "simple" and any unrecognised routing type use the alias defaults.
        return RoutingResult(provider=alias.provider, model=alias.model)

    async def _resolve_load_balanced(
        self, alias: ModelAlias, routing_config: dict[str, Any]
    ) -> RoutingResult:
        """Pick one healthy provider from ``routing_config["providers"]`` by weight."""
        providers = routing_config.get("providers")
        if not isinstance(providers, list):
            providers = []

        healthy_providers = await self._filter_healthy_providers(providers) if providers else []
        if not healthy_providers:
            # Nothing configured, or every candidate is unhealthy: use the default.
            return RoutingResult(provider=alias.provider, model=alias.model)

        chosen = self._pick_weighted(healthy_providers)
        return RoutingResult(
            provider=chosen["provider"],
            model=chosen.get("model", alias.model),
        )

    def _resolve_with_fallback(
        self, alias: ModelAlias, routing_config: dict[str, Any]
    ) -> RoutingResult:
        """Build the primary target plus fallback chain for a "fallback" alias."""
        primary = routing_config.get("primary")
        if not isinstance(primary, dict):
            primary = {}
        return RoutingResult(
            provider=primary.get("provider", alias.provider),
            model=primary.get("model", alias.model),
            fallback_chain=routing_config.get("fallback", []),
        )

    @staticmethod
    def _pick_weighted(candidates: list[dict[str, Any]]) -> dict[str, Any]:
        """Return one entry of ``candidates`` chosen at random in proportion to its weight.

        An entry without a "weight" key counts as weight 1. Entries whose
        weight is not positive are never selected, unless no entry has a
        positive weight, in which case the first entry is returned.
        ``candidates`` must be non-empty.
        """
        eligible = [
            (entry, weight)
            for entry in candidates
            if (weight := entry.get("weight", 1)) > 0
        ]
        if not eligible:
            return candidates[0]
        pool, weights = zip(*eligible)
        return random.choices(pool, weights=weights, k=1)[0]

    async def _filter_healthy_providers(
        self, providers: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Return the entries of ``providers`` whose provider is enabled and healthy.

        Entries that are not dicts or that lack a "provider" value are dropped.
        The input order is preserved. All candidate names are checked with a
        single query instead of one query per entry.
        """
        candidates = [p for p in providers if isinstance(p, dict) and p.get("provider")]
        if not candidates:
            return []

        names = list({p["provider"] for p in candidates})
        result = await self.db.execute(
            select(Provider.name).where(
                Provider.name.in_(names),
                Provider.enabled == True,  # noqa: E712 - SQL expression, not a Python bool test
                Provider.health_status == "healthy",
            )
        )
        healthy_names = set(result.scalars().all())
        return [p for p in candidates if p["provider"] in healthy_names]

    async def get_fallback_provider(
        self, failed_provider: str, fallback_chain: list[dict[str, str]] | None
    ) -> RoutingResult | None:
        """
        Get the next fallback provider after a failure.

        Args:
            failed_provider: The provider that failed.
            fallback_chain: List of fallback providers.

        Returns:
            RoutingResult for the next provider, or None if no more fallbacks.
            The result's ``fallback_chain`` starts at the returned entry.

        NOTE(review): if ``failed_provider`` matches only the last entry of the
        chain, or is not in the chain, selection restarts from the first entry
        with a different provider. For entries that carry a "provider" key,
        passing the previous result's provider and ``fallback_chain`` back in
        shortens the chain by one per call until None is returned. A caller
        that always passes the same full chain could alternate between
        providers indefinitely.
        """
        if not fallback_chain:
            return None

        # Advance to the entry right after the first occurrence of the failed
        # provider, provided that occurrence is not the final entry.
        last_index = len(fallback_chain) - 1
        for index, entry in enumerate(fallback_chain):
            if index < last_index and entry.get("provider") == failed_provider:
                successor = fallback_chain[index + 1]
                return RoutingResult(
                    provider=successor.get("provider", ""),
                    model=successor.get("model", ""),
                    fallback_chain=fallback_chain[index + 1:],
                )

        # Failed provider absent from the chain, or present only as the last
        # entry: take the first entry that names a different provider.
        for entry in fallback_chain:
            if entry.get("provider") != failed_provider:
                return RoutingResult(
                    provider=entry.get("provider", ""),
                    model=entry.get("model", ""),
                    fallback_chain=fallback_chain,
                )
        return None