ai-rule-engine/tests/integration/test_full_flow.py
root 0e37c30ae8 docs: 补充完整测试和文档
- 添加 README.md(功能说明、快速开始、API 文档)
- 修复冲突检测字段提取正则表达式
- 添加集成测试 tests/integration/test_full_flow.py

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-11 22:45:57 +08:00

193 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""集成测试 - 完整流程测试."""
import sys
sys.path.insert(0, 'src')
sys.path.insert(0, 'tests')
import os
import tempfile
import json
from http.client import HTTPConnection
from threading import Thread
from time import sleep
from rule_engine.api import create_app
from rule_engine.store import RuleStore
from rule_engine.executor import RuleExecutor
from rule_engine.matcher import RuleMatcher
from rule_engine.conflict import ConflictDetector
def test_full_workflow():
    """End-to-end workflow: store a rule, execute it, then match it.

    A temporary ``.db`` file from ``tempfile.mkstemp`` backs the RuleStore.
    It is removed in ``finally`` so a failing assertion does not leak it.
    """
    # Create a temporary database file; only its path is handed to RuleStore.
    fd, path = tempfile.mkstemp(suffix='.db')
    os.close(fd)
    try:
        store = RuleStore(path)
        # Create the rule. Its code must be valid, properly nested Python:
        # an `if` whose body sits at the same indent raises IndentationError.
        rule = store.create_rule(
            name='premium_discount',
            condition_template='如果用户订阅 premium 返回8折',
            code=(
                'def rule(facts):\n'
                '    if facts.get("subscription") == "premium":\n'
                '        return {"action": "discount", "rate": 0.8}\n'
                '    return None'
            ),
            priority=1,
        )
        print(f'Created rule: {rule.id}')

        # Executor: running the stored code on matching facts yields the discount.
        executor = RuleExecutor()
        result = executor.execute_rule(rule.code, {"subscription": "premium"})
        assert result == {"action": "discount", "rate": 0.8}
        print('Executor test: OK')

        # Matcher: the only stored rule should be the one reported as matching.
        matcher = RuleMatcher(store, executor)
        rule_id, _ = matcher.find_matching_rule({"subscription": "premium"})
        assert rule_id == rule.id
        print('Matcher test: OK')
    finally:
        os.unlink(path)
    print('Full workflow test: PASSED')
def test_api_server():
    """API server integration test: health check, rule CRUD and evaluate.

    The server returned by ``create_app`` runs ``serve_forever`` on a daemon
    thread and is driven over HTTP. The client connection, the server loop
    and the temporary database file are released in ``finally``. A failed
    assertion therefore neither leaves the server running nor leaks the file.
    """
    fd, path = tempfile.mkstemp(suffix='.db')
    os.close(fd)
    srv = None
    serving = False
    client = None
    try:
        store = RuleStore(path)
        srv = create_app(store, host='127.0.0.1', port=18003)
        server_thread = Thread(target=srv.serve_forever, daemon=True)
        server_thread.start()
        serving = True
        # Give the server thread a moment to enter its serve loop.
        sleep(0.2)
        client = HTTPConnection('127.0.0.1', 18003)
        json_headers = {'Content-Type': 'application/json'}

        # Health check
        client.request('GET', '/health')
        resp = client.getresponse()
        assert resp.status == 200
        data = json.loads(resp.read())
        assert data['status'] == 'ok'
        print('Health check: OK')

        # Create rule; the code must be valid, properly nested Python.
        body = json.dumps({
            'name': 'test_rule',
            'condition_template': '测试规则',
            'code': (
                'def rule(facts):\n'
                '    if facts.get("active"):\n'
                '        return {"action": "ok"}\n'
                '    return None'
            ),
        })
        client.request('POST', '/api/rules', body=body, headers=json_headers)
        resp = client.getresponse()
        assert resp.status == 201
        data = json.loads(resp.read())
        rule_id = data['id']
        print(f'Created rule: {rule_id}')

        # List rules
        client.request('GET', '/api/rules')
        resp = client.getresponse()
        data = json.loads(resp.read())
        assert len(data['rules']) >= 1
        print('List rules: OK')

        # Get rule
        client.request('GET', f'/api/rules/{rule_id}')
        resp = client.getresponse()
        data = json.loads(resp.read())
        assert data['name'] == 'test_rule'
        print('Get rule: OK')

        # Evaluate: the response is only printed, not asserted
        # (original note: executor not integrated in mock).
        body = json.dumps({'facts': {'active': True}})
        client.request('POST', '/api/rules/evaluate', body=body, headers=json_headers)
        resp = client.getresponse()
        data = json.loads(resp.read())
        print(f'Evaluate response: {data}')

        # Delete rule
        client.request('DELETE', f'/api/rules/{rule_id}')
        resp = client.getresponse()
        assert resp.status == 200
        print('Delete rule: OK')
    finally:
        if client is not None:
            client.close()
        # Only stop the loop if it was actually started. This assumes srv is a
        # socketserver-style server, whose shutdown() blocks until
        # serve_forever() returns; calling it on a never-started server would
        # hang. TODO confirm create_app's return type.
        if serving:
            srv.shutdown()
        os.unlink(path)
    print('API server test: PASSED')
def test_conflict_detection():
    """Exercise ConflictDetector.detect_conflicts on three pairs of rules.

    Expected outcomes:
      1. Same field, same priority, opposite actions -> one conflict that
         requires resolution.
      2. Opposite actions guarded by different fields -> no conflict.
      3. Opposite actions at different priorities -> no conflict.
    """
    detector = ConflictDetector()

    # Case 1: opposite actions on the same condition conflict.
    # Rule code is written as valid, properly nested Python.
    rules = [
        {'id': 'r1', 'name': 'allow_vip',
         'code': ('def rule(f):\n'
                  '    if f.get("type") == "vip":\n'
                  '        return {"action": "allow"}\n'
                  '    return None'),
         'priority': 1, 'is_active': True},
        {'id': 'r2', 'name': 'deny_vip',
         'code': ('def rule(f):\n'
                  '    if f.get("type") == "vip":\n'
                  '        return {"action": "deny"}\n'
                  '    return None'),
         'priority': 1, 'is_active': True},
    ]
    conflicts = detector.detect_conflicts(rules)
    assert len(conflicts) == 1
    assert conflicts[0]['requires_resolution'] is True
    print('Conflict detection (opposite actions): OK')

    # Case 2: rules guarded by different fields do not conflict.
    rules = [
        {'id': 'r1', 'name': 'rule_a',
         'code': ('def rule(f):\n'
                  '    if f.get("a") == 1:\n'
                  '        return {"action": "allow"}\n'
                  '    return None'),
         'priority': 1, 'is_active': True},
        {'id': 'r2', 'name': 'rule_b',
         'code': ('def rule(f):\n'
                  '    if f.get("b") == 1:\n'
                  '        return {"action": "deny"}\n'
                  '    return None'),
         'priority': 1, 'is_active': True},
    ]
    conflicts = detector.detect_conflicts(rules)
    assert len(conflicts) == 0
    print('Conflict detection (different fields): OK')

    # Case 3: rules at different priorities do not conflict.
    rules = [
        {'id': 'r1', 'name': 'allow',
         'code': 'def rule(f):\n    return {"action": "allow"}',
         'priority': 2, 'is_active': True},
        {'id': 'r2', 'name': 'deny',
         'code': 'def rule(f):\n    return {"action": "deny"}',
         'priority': 1, 'is_active': True},
    ]
    conflicts = detector.detect_conflicts(rules)
    assert len(conflicts) == 0
    print('Conflict detection (different priority): OK')
    print('Conflict detection test: PASSED')
def test_executor_sandbox():
    """Sandbox executor: validation rejects blocked code, accepts valid code, and runs it."""
    executor = RuleExecutor()

    # Code using import / exec / eval must be rejected by validation.
    assert executor.validate_rule_code('def rule(f):\n import os') is False
    assert executor.validate_rule_code('def rule(f):\n exec("")') is False
    assert executor.validate_rule_code('def rule(f):\n eval("")') is False
    print('Sandbox validation (blocked keywords): OK')

    # Harmless code must pass validation.
    assert executor.validate_rule_code('def rule(facts):\n return None') is True
    print('Sandbox validation (valid code): OK')

    # Normal execution. The rule body must be properly nested Python; an `if`
    # whose `return` shares its indent would raise IndentationError.
    code = (
        'def rule(facts):\n'
        '    if facts.get("x", 0) > 10:\n'
        '        return {"result": "big"}\n'
        '    return None'
    )
    result = executor.execute_rule(code, {"x": 20})
    assert result == {"result": "big"}
    print('Sandbox execution: OK')
    print('Sandbox test: PASSED')
if __name__ == '__main__':
    # Script entry point: run every integration test in order, with a blank
    # line after each, framed by a banner header and footer.
    separator = '=' * 50
    print(separator)
    print('Running integration tests...')
    print(separator)
    for run_test in (
        test_full_workflow,
        test_api_server,
        test_conflict_detection,
        test_executor_sandbox,
    ):
        run_test()
        print()
    print(separator)
    print('ALL INTEGRATION TESTS PASSED')
    print(separator)