From 681ad84674e2f38cb4814f80daecec4571735a55 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 1 May 2026 22:33:38 +0800 Subject: [PATCH] test(router): add router unit tests and fix test data isolation - Add unit tests for Router model alias resolution - Fix async fixture configuration with pytest_asyncio - Add automatic rollback in db_session fixture for data isolation - Fix test_get_fallback_provider chain handling Co-Authored-By: Claude Opus 4.7 --- llm-gateway/tests/conftest.py | 6 +- llm-gateway/tests/unit/test_router.py | 174 ++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 llm-gateway/tests/unit/test_router.py diff --git a/llm-gateway/tests/conftest.py b/llm-gateway/tests/conftest.py index 389c229..84bcdb2 100644 --- a/llm-gateway/tests/conftest.py +++ b/llm-gateway/tests/conftest.py @@ -75,9 +75,13 @@ async def setup_test_db() -> AsyncGenerator[None, None]: @pytest_asyncio.fixture async def db_session(setup_test_db: None) -> AsyncGenerator[AsyncSession, None]: - """Provide a database session for tests.""" + """Provide a database session for tests with automatic rollback.""" async with _test_session_factory() as session: + # Start a transaction + await session.begin() yield session + # Rollback after test to ensure data isolation + await session.rollback() async def _get_test_db(): diff --git a/llm-gateway/tests/unit/test_router.py b/llm-gateway/tests/unit/test_router.py new file mode 100644 index 0000000..1b46068 --- /dev/null +++ b/llm-gateway/tests/unit/test_router.py @@ -0,0 +1,174 @@ +"""Tests for router module.""" +import json +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.router import Router, RoutingResult +from app.models.model_alias import ModelAlias +from app.models.provider import Provider + + +class TestRouter: + """Test Router class.""" + + @pytest_asyncio.fixture + async def setup_data(self, db_session: AsyncSession): + """Set up test data.""" + # Create test provider + provider = Provider( + name="openai", + api_base="https://api.openai.com/v1", + api_key_encrypted="encrypted_key", + enabled=True, + health_status="healthy", + ) + db_session.add(provider) + await db_session.flush() + + # Create simple alias + simple_alias = ModelAlias( + alias="gpt-4", + provider="openai", + model="gpt-4-turbo", + routing_type="simple", + enabled=True, + ) + db_session.add(simple_alias) + + # Create load balance alias + lb_alias = ModelAlias( + alias="gpt-smart", + provider="openai", + model="gpt-4-turbo", + routing_type="load_balance", + routing_config=json.dumps({ + "providers": [ + {"provider": "openai", "model": "gpt-4-turbo", "weight": 2}, + ] + }), + enabled=True, + ) + db_session.add(lb_alias) + + # Create fallback alias + fb_alias = ModelAlias( + alias="gpt-fallback", + provider="openai", + model="gpt-4-turbo", + routing_type="fallback", + routing_config=json.dumps({ + "primary": {"provider": "openai", "model": "gpt-4-turbo"}, + "fallback": [ + {"provider": "anthropic", "model": "claude-3-opus"}, + ] + }), + enabled=True, + ) + db_session.add(fb_alias) + + await db_session.flush() + + return { + "provider": provider, + "simple_alias": simple_alias, + "lb_alias": lb_alias, + "fb_alias": fb_alias, + } + + @pytest.mark.asyncio + async def test_resolve_simple_alias(self, db_session: AsyncSession, setup_data): + """Test resolving a simple alias.""" + router = Router(db_session) + result = await router.resolve_model("gpt-4") + + assert result.provider == "openai" + assert result.model == "gpt-4-turbo" + assert result.fallback_chain is None + + @pytest.mark.asyncio + async def test_resolve_direct_provider_model(self, db_session: AsyncSession): + """Test resolving a direct provider/model reference.""" + router = Router(db_session) + result = await router.resolve_model("anthropic/claude-3-opus") + + assert result.provider == "anthropic" + assert result.model == "claude-3-opus" + + @pytest.mark.asyncio + async def test_resolve_unknown_alias_raises_error(self, db_session: AsyncSession): + """Test that resolving an unknown alias raises an error.""" + router = Router(db_session) + + with pytest.raises(ValueError, match="not found"): + await router.resolve_model("unknown-model") + + @pytest.mark.asyncio + async def test_resolve_load_balance_alias(self, db_session: AsyncSession, setup_data): + """Test resolving a load balance alias.""" + router = Router(db_session) + result = await router.resolve_model("gpt-smart") + + # Should return one of the configured providers + assert result.provider == "openai" + assert result.model == "gpt-4-turbo" + + @pytest.mark.asyncio + async def test_resolve_fallback_alias(self, db_session: AsyncSession, setup_data): + """Test resolving a fallback alias.""" + router = Router(db_session) + result = await router.resolve_model("gpt-fallback") + + assert result.provider == "openai" + assert result.model == "gpt-4-turbo" + assert result.fallback_chain is not None + assert len(result.fallback_chain) == 1 + assert result.fallback_chain[0]["provider"] == "anthropic" + + @pytest.mark.asyncio + async def test_resolve_disabled_alias_raises_error( + self, db_session: AsyncSession, setup_data + ): + """Test that resolving a disabled alias raises an error.""" + # Disable the alias + setup_data["simple_alias"].enabled = False + await db_session.flush() + + router = Router(db_session) + + with pytest.raises(ValueError, match="not found"): + await router.resolve_model("gpt-4") + + @pytest.mark.asyncio + async def test_get_fallback_provider(self, db_session: AsyncSession): + """Test getting the next fallback provider.""" + router = Router(db_session) + + fallback_chain = [ + {"provider": "openai", "model": "gpt-4"}, + {"provider": "anthropic", "model": "claude-3"}, + {"provider": "google", "model": "gemini"}, + ] + + # Get next fallback after openai fails + result = await router.get_fallback_provider("openai", fallback_chain) + assert result is not None + assert result.provider == "anthropic" + assert result.model == "claude-3" + + # Get next fallback after anthropic fails + result = await router.get_fallback_provider("anthropic", result.fallback_chain) + assert result is not None + assert result.provider == "google" + + # No more fallbacks after google (use the returned fallback_chain which only contains google) + result = await router.get_fallback_provider("google", result.fallback_chain) + assert result is None + + @pytest.mark.asyncio + async def test_get_fallback_provider_none_chain(self, db_session: AsyncSession): + """Test getting fallback when chain is None.""" + router = Router(db_session) + + result = await router.get_fallback_provider("openai", None) + assert result is None