test(router): add router unit tests and fix test data isolation

- Add unit tests for Router model alias resolution
- Fix async fixture configuration with pytest_asyncio
- Add automatic rollback in db_session fixture for data isolation
- Fix test_get_fallback_provider chain handling

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
root 2026-05-01 22:33:38 +08:00
parent 8348520bdf
commit 681ad84674
2 changed files with 179 additions and 1 deletions

View File

@ -75,9 +75,13 @@ async def setup_test_db() -> AsyncGenerator[None, None]:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def db_session(setup_test_db: None) -> AsyncGenerator[AsyncSession, None]: async def db_session(setup_test_db: None) -> AsyncGenerator[AsyncSession, None]:
"""Provide a database session for tests.""" """Provide a database session for tests with automatic rollback."""
async with _test_session_factory() as session: async with _test_session_factory() as session:
# Start a transaction
await session.begin()
yield session yield session
# Rollback after test to ensure data isolation
await session.rollback()
async def _get_test_db(): async def _get_test_db():

View File

@ -0,0 +1,174 @@
"""Tests for router module."""
import json
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.router import Router, RoutingResult
from app.models.model_alias import ModelAlias
from app.models.provider import Provider
class TestRouter:
"""Test Router class."""
@pytest_asyncio.fixture
async def setup_data(self, db_session: AsyncSession):
"""Set up test data."""
# Create test provider
provider = Provider(
name="openai",
api_base="https://api.openai.com/v1",
api_key_encrypted="encrypted_key",
enabled=True,
health_status="healthy",
)
db_session.add(provider)
await db_session.flush()
# Create simple alias
simple_alias = ModelAlias(
alias="gpt-4",
provider="openai",
model="gpt-4-turbo",
routing_type="simple",
enabled=True,
)
db_session.add(simple_alias)
# Create load balance alias
lb_alias = ModelAlias(
alias="gpt-smart",
provider="openai",
model="gpt-4-turbo",
routing_type="load_balance",
routing_config=json.dumps({
"providers": [
{"provider": "openai", "model": "gpt-4-turbo", "weight": 2},
]
}),
enabled=True,
)
db_session.add(lb_alias)
# Create fallback alias
fb_alias = ModelAlias(
alias="gpt-fallback",
provider="openai",
model="gpt-4-turbo",
routing_type="fallback",
routing_config=json.dumps({
"primary": {"provider": "openai", "model": "gpt-4-turbo"},
"fallback": [
{"provider": "anthropic", "model": "claude-3-opus"},
]
}),
enabled=True,
)
db_session.add(fb_alias)
await db_session.flush()
return {
"provider": provider,
"simple_alias": simple_alias,
"lb_alias": lb_alias,
"fb_alias": fb_alias,
}
@pytest.mark.asyncio
async def test_resolve_simple_alias(self, db_session: AsyncSession, setup_data):
"""Test resolving a simple alias."""
router = Router(db_session)
result = await router.resolve_model("gpt-4")
assert result.provider == "openai"
assert result.model == "gpt-4-turbo"
assert result.fallback_chain is None
@pytest.mark.asyncio
async def test_resolve_direct_provider_model(self, db_session: AsyncSession):
"""Test resolving a direct provider/model reference."""
router = Router(db_session)
result = await router.resolve_model("anthropic/claude-3-opus")
assert result.provider == "anthropic"
assert result.model == "claude-3-opus"
@pytest.mark.asyncio
async def test_resolve_unknown_alias_raises_error(self, db_session: AsyncSession):
"""Test that resolving an unknown alias raises an error."""
router = Router(db_session)
with pytest.raises(ValueError, match="not found"):
await router.resolve_model("unknown-model")
@pytest.mark.asyncio
async def test_resolve_load_balance_alias(self, db_session: AsyncSession, setup_data):
"""Test resolving a load balance alias."""
router = Router(db_session)
result = await router.resolve_model("gpt-smart")
# Should return one of the configured providers
assert result.provider == "openai"
assert result.model == "gpt-4-turbo"
@pytest.mark.asyncio
async def test_resolve_fallback_alias(self, db_session: AsyncSession, setup_data):
"""Test resolving a fallback alias."""
router = Router(db_session)
result = await router.resolve_model("gpt-fallback")
assert result.provider == "openai"
assert result.model == "gpt-4-turbo"
assert result.fallback_chain is not None
assert len(result.fallback_chain) == 1
assert result.fallback_chain[0]["provider"] == "anthropic"
@pytest.mark.asyncio
async def test_resolve_disabled_alias_raises_error(
self, db_session: AsyncSession, setup_data
):
"""Test that resolving a disabled alias raises an error."""
# Disable the alias
setup_data["simple_alias"].enabled = False
await db_session.flush()
router = Router(db_session)
with pytest.raises(ValueError, match="not found"):
await router.resolve_model("gpt-4")
@pytest.mark.asyncio
async def test_get_fallback_provider(self, db_session: AsyncSession):
"""Test getting the next fallback provider."""
router = Router(db_session)
fallback_chain = [
{"provider": "openai", "model": "gpt-4"},
{"provider": "anthropic", "model": "claude-3"},
{"provider": "google", "model": "gemini"},
]
# Get next fallback after openai fails
result = await router.get_fallback_provider("openai", fallback_chain)
assert result is not None
assert result.provider == "anthropic"
assert result.model == "claude-3"
# Get next fallback after anthropic fails
result = await router.get_fallback_provider("anthropic", result.fallback_chain)
assert result is not None
assert result.provider == "google"
# No more fallbacks after google (use the returned fallback_chain which only contains google)
result = await router.get_fallback_provider("google", result.fallback_chain)
assert result is None
@pytest.mark.asyncio
async def test_get_fallback_provider_none_chain(self, db_session: AsyncSession):
"""Test getting fallback when chain is None."""
router = Router(db_session)
result = await router.get_fallback_provider("openai", None)
assert result is None