test(router): add router unit tests and fix test data isolation
- Add unit tests for Router model alias resolution - Fix async fixture configuration with pytest_asyncio - Add automatic rollback in db_session fixture for data isolation - Fix test_get_fallback_provider chain handling Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
8348520bdf
commit
681ad84674
@ -75,9 +75,13 @@ async def setup_test_db() -> AsyncGenerator[None, None]:
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(setup_test_db: None) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Provide a database session for tests."""
|
||||
"""Provide a database session for tests with automatic rollback."""
|
||||
async with _test_session_factory() as session:
|
||||
# Start a transaction
|
||||
await session.begin()
|
||||
yield session
|
||||
# Rollback after test to ensure data isolation
|
||||
await session.rollback()
|
||||
|
||||
|
||||
async def _get_test_db():
|
||||
|
||||
174
llm-gateway/tests/unit/test_router.py
Normal file
174
llm-gateway/tests/unit/test_router.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""Tests for router module."""
|
||||
import json
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.router import Router, RoutingResult
|
||||
from app.models.model_alias import ModelAlias
|
||||
from app.models.provider import Provider
|
||||
|
||||
|
||||
class TestRouter:
|
||||
"""Test Router class."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_data(self, db_session: AsyncSession):
|
||||
"""Set up test data."""
|
||||
# Create test provider
|
||||
provider = Provider(
|
||||
name="openai",
|
||||
api_base="https://api.openai.com/v1",
|
||||
api_key_encrypted="encrypted_key",
|
||||
enabled=True,
|
||||
health_status="healthy",
|
||||
)
|
||||
db_session.add(provider)
|
||||
await db_session.flush()
|
||||
|
||||
# Create simple alias
|
||||
simple_alias = ModelAlias(
|
||||
alias="gpt-4",
|
||||
provider="openai",
|
||||
model="gpt-4-turbo",
|
||||
routing_type="simple",
|
||||
enabled=True,
|
||||
)
|
||||
db_session.add(simple_alias)
|
||||
|
||||
# Create load balance alias
|
||||
lb_alias = ModelAlias(
|
||||
alias="gpt-smart",
|
||||
provider="openai",
|
||||
model="gpt-4-turbo",
|
||||
routing_type="load_balance",
|
||||
routing_config=json.dumps({
|
||||
"providers": [
|
||||
{"provider": "openai", "model": "gpt-4-turbo", "weight": 2},
|
||||
]
|
||||
}),
|
||||
enabled=True,
|
||||
)
|
||||
db_session.add(lb_alias)
|
||||
|
||||
# Create fallback alias
|
||||
fb_alias = ModelAlias(
|
||||
alias="gpt-fallback",
|
||||
provider="openai",
|
||||
model="gpt-4-turbo",
|
||||
routing_type="fallback",
|
||||
routing_config=json.dumps({
|
||||
"primary": {"provider": "openai", "model": "gpt-4-turbo"},
|
||||
"fallback": [
|
||||
{"provider": "anthropic", "model": "claude-3-opus"},
|
||||
]
|
||||
}),
|
||||
enabled=True,
|
||||
)
|
||||
db_session.add(fb_alias)
|
||||
|
||||
await db_session.flush()
|
||||
|
||||
return {
|
||||
"provider": provider,
|
||||
"simple_alias": simple_alias,
|
||||
"lb_alias": lb_alias,
|
||||
"fb_alias": fb_alias,
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_simple_alias(self, db_session: AsyncSession, setup_data):
|
||||
"""Test resolving a simple alias."""
|
||||
router = Router(db_session)
|
||||
result = await router.resolve_model("gpt-4")
|
||||
|
||||
assert result.provider == "openai"
|
||||
assert result.model == "gpt-4-turbo"
|
||||
assert result.fallback_chain is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_direct_provider_model(self, db_session: AsyncSession):
|
||||
"""Test resolving a direct provider/model reference."""
|
||||
router = Router(db_session)
|
||||
result = await router.resolve_model("anthropic/claude-3-opus")
|
||||
|
||||
assert result.provider == "anthropic"
|
||||
assert result.model == "claude-3-opus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_unknown_alias_raises_error(self, db_session: AsyncSession):
|
||||
"""Test that resolving an unknown alias raises an error."""
|
||||
router = Router(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await router.resolve_model("unknown-model")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_load_balance_alias(self, db_session: AsyncSession, setup_data):
|
||||
"""Test resolving a load balance alias."""
|
||||
router = Router(db_session)
|
||||
result = await router.resolve_model("gpt-smart")
|
||||
|
||||
# Should return one of the configured providers
|
||||
assert result.provider == "openai"
|
||||
assert result.model == "gpt-4-turbo"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_fallback_alias(self, db_session: AsyncSession, setup_data):
|
||||
"""Test resolving a fallback alias."""
|
||||
router = Router(db_session)
|
||||
result = await router.resolve_model("gpt-fallback")
|
||||
|
||||
assert result.provider == "openai"
|
||||
assert result.model == "gpt-4-turbo"
|
||||
assert result.fallback_chain is not None
|
||||
assert len(result.fallback_chain) == 1
|
||||
assert result.fallback_chain[0]["provider"] == "anthropic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_disabled_alias_raises_error(
|
||||
self, db_session: AsyncSession, setup_data
|
||||
):
|
||||
"""Test that resolving a disabled alias raises an error."""
|
||||
# Disable the alias
|
||||
setup_data["simple_alias"].enabled = False
|
||||
await db_session.flush()
|
||||
|
||||
router = Router(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await router.resolve_model("gpt-4")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fallback_provider(self, db_session: AsyncSession):
|
||||
"""Test getting the next fallback provider."""
|
||||
router = Router(db_session)
|
||||
|
||||
fallback_chain = [
|
||||
{"provider": "openai", "model": "gpt-4"},
|
||||
{"provider": "anthropic", "model": "claude-3"},
|
||||
{"provider": "google", "model": "gemini"},
|
||||
]
|
||||
|
||||
# Get next fallback after openai fails
|
||||
result = await router.get_fallback_provider("openai", fallback_chain)
|
||||
assert result is not None
|
||||
assert result.provider == "anthropic"
|
||||
assert result.model == "claude-3"
|
||||
|
||||
# Get next fallback after anthropic fails
|
||||
result = await router.get_fallback_provider("anthropic", result.fallback_chain)
|
||||
assert result is not None
|
||||
assert result.provider == "google"
|
||||
|
||||
# No more fallbacks after google (use the returned fallback_chain which only contains google)
|
||||
result = await router.get_fallback_provider("google", result.fallback_chain)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fallback_provider_none_chain(self, db_session: AsyncSession):
|
||||
"""Test getting fallback when chain is None."""
|
||||
router = Router(db_session)
|
||||
|
||||
result = await router.get_fallback_provider("openai", None)
|
||||
assert result is None
|
||||
Loading…
x
Reference in New Issue
Block a user