docs: 补充完整测试和文档
- 添加 README.md(功能说明、快速开始、API 文档) - 修复冲突检测字段提取正则表达式 - 添加集成测试 tests/integration/test_full_flow.py Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
220f6db288
commit
0e37c30ae8
185
README.md
Normal file
185
README.md
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
# AI Rule Engine
|
||||||
|
|
||||||
|
将自然语言规则描述通过 AI 转换为可执行 Python 函数的规则引擎。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- **自然语言 → 代码**:使用 LLM 将规则描述转换为 Python 函数
|
||||||
|
- **安全执行**:RestrictedPython 沙箱隔离
|
||||||
|
- **快速校验**:规则编译后存储,后续判断无需 LLM 调用
|
||||||
|
- **冲突检测**:自动检测矛盾规则
|
||||||
|
- **LLM 兜底**:无匹配时自动调用 LLM 补充规则
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
- Python 3.10+
|
||||||
|
- SQLite(规则存储)
|
||||||
|
- LLM API(OpenAI GPT-4o / Anthropic Claude)
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
### 启动服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m rule_engine.api
|
||||||
|
# 默认 http://0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
### API 接口
|
||||||
|
|
||||||
|
#### 创建规则
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/api/rules \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"name": "premium_discount",
|
||||||
|
"description": "高级会员8折",
|
||||||
|
"condition_template": "如果用户订阅类型为 premium 且年龄大于 18",
|
||||||
|
"priority": 1
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 校验事实
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/api/rules/evaluate \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"facts": {
|
||||||
|
"subscription": "premium",
|
||||||
|
"age": 25
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 列出规则
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8000/api/rules
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 获取规则
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8000/api/rules/{rule_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 删除规则
|
||||||
|
```bash
|
||||||
|
curl -X DELETE http://localhost:8000/api/rules/{rule_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/rule_engine/
|
||||||
|
├── __init__.py
|
||||||
|
├── api.py # REST API
|
||||||
|
├── callback.py # LLM 兜底
|
||||||
|
├── compiler.py # LLM 编译器
|
||||||
|
├── conflict.py # 冲突检测
|
||||||
|
├── executor.py # 沙箱执行器
|
||||||
|
├── matcher.py # 规则匹配
|
||||||
|
├── models.py # 数据模型
|
||||||
|
└── store.py # SQLite 存储
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
| 环境变量 | 说明 |
|
||||||
|
|----------|------|
|
||||||
|
| `OPENAI_API_KEY` | OpenAI API Key |
|
||||||
|
| `ANTHROPIC_API_KEY` | Anthropic API Key |
|
||||||
|
|
||||||
|
## 开发
|
||||||
|
|
||||||
|
### 测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行所有测试
|
||||||
|
python3 -c "
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, 'src')
|
||||||
|
sys.path.insert(0, 'tests')
|
||||||
|
|
||||||
|
# 模型测试
|
||||||
|
from rule_engine.models import Rule, CreateRuleRequest
|
||||||
|
r = Rule(id='t1', name='test', condition_template='t', code='def r(f): pass')
|
||||||
|
assert r.id == 't1'
|
||||||
|
print('models: OK')
|
||||||
|
|
||||||
|
# 存储测试
|
||||||
|
import tempfile, os
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
fd, path = tempfile.mkstemp(suffix='.db')
|
||||||
|
os.close(fd)
|
||||||
|
store = RuleStore(path)
|
||||||
|
rule = store.create_rule(name='t', condition_template='c', code='def r(f): pass')
|
||||||
|
assert store.get_rule(rule.id).name == 't'
|
||||||
|
os.unlink(path)
|
||||||
|
print('store: OK')
|
||||||
|
|
||||||
|
# 执行器测试
|
||||||
|
from rule_engine.executor import RuleExecutor
|
||||||
|
ex = RuleExecutor()
|
||||||
|
code = 'def rule(f): return {\"action\": \"ok\"}' if False else 'def rule(facts):\\n if facts.get(\"x\") == 1: return {\"a\": 1}\\n return None'
|
||||||
|
result = ex.execute_rule(code, {'x': 1})
|
||||||
|
assert result == {'a': 1}
|
||||||
|
print('executor: OK')
|
||||||
|
|
||||||
|
# 冲突检测测试
|
||||||
|
from rule_engine.conflict import ConflictDetector
|
||||||
|
det = ConflictDetector()
|
||||||
|
rules = [
|
||||||
|
{'id': 'r1', 'name': 'a', 'code': 'def rule(f): return {\"action\": \"allow\"}', 'priority': 1, 'is_active': True},
|
||||||
|
{'id': 'r2', 'name': 'b', 'code': 'def rule(f): return {\"action\": \"deny\"}', 'priority': 1, 'is_active': True},
|
||||||
|
]
|
||||||
|
conflicts = det.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 1
|
||||||
|
print('conflict: OK')
|
||||||
|
|
||||||
|
print('\\nAll tests passed!')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 启动服务(带 LLM 回调)
|
||||||
|
```bash
|
||||||
|
python3 -c "
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, 'src')
|
||||||
|
from rule_engine.api import create_app
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
|
||||||
|
store = RuleStore('rules.db')
|
||||||
|
srv = create_app(store, enable_callback=True)
|
||||||
|
print('Server started on http://0.0.0.0:8000')
|
||||||
|
srv.serve_forever()
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
## 规则格式
|
||||||
|
|
||||||
|
### 规则描述模板
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "rule_name",
|
||||||
|
"condition_template": "如果用户是 premium 会员且年龄大于 18",
|
||||||
|
"code": "def rule(facts):\n if facts.get('subscription') == 'premium' and facts.get('age', 0) > 18:\n return {'action': 'discount', 'rate': 0.8}\n return None"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Facts 格式
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"user_id": "u123",
|
||||||
|
"subscription": "premium",
|
||||||
|
"age": 25
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
MIT
|
||||||
@ -58,12 +58,16 @@ class ConflictDetector:
|
|||||||
def _conditions_may_overlap(self, code_a: str, code_b: str) -> bool:
|
def _conditions_may_overlap(self, code_a: str, code_b: str) -> bool:
|
||||||
"""检查两条规则的条件可能重叠。"""
|
"""检查两条规则的条件可能重叠。"""
|
||||||
import re
|
import re
|
||||||
# 提取 facts.get() 调用的字段
|
# 提取 .get() 调用的字段(支持 f.get 和 facts.get)
|
||||||
pattern = r'facts\.get\(["\']([^"\']+)["\']\)'
|
pattern = r'\.\s*get\s*\(["\']([^"\']+)["\']\)'
|
||||||
fields_a = set(re.findall(pattern, code_a))
|
fields_a = set(re.findall(pattern, code_a))
|
||||||
fields_b = set(re.findall(pattern, code_b))
|
fields_b = set(re.findall(pattern, code_b))
|
||||||
|
|
||||||
# 如果没有字段信息,假设可能重叠
|
# 如果没有字段信息,假设可能重叠
|
||||||
|
if not fields_a and not fields_b:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 如果任一规则没有字段,假设可能重叠
|
||||||
if not fields_a or not fields_b:
|
if not fields_a or not fields_b:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
192
tests/integration/test_full_flow.py
Normal file
192
tests/integration/test_full_flow.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""集成测试 - 完整流程测试."""
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, 'src')
|
||||||
|
sys.path.insert(0, 'tests')
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import json
|
||||||
|
from http.client import HTTPConnection
|
||||||
|
from threading import Thread
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from rule_engine.api import create_app
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
from rule_engine.executor import RuleExecutor
|
||||||
|
from rule_engine.matcher import RuleMatcher
|
||||||
|
from rule_engine.conflict import ConflictDetector
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_workflow():
|
||||||
|
"""完整工作流测试:创建规则 -> 校验 -> 执行."""
|
||||||
|
# 创建临时数据库
|
||||||
|
fd, path = tempfile.mkstemp(suffix='.db')
|
||||||
|
os.close(fd)
|
||||||
|
store = RuleStore(path)
|
||||||
|
|
||||||
|
# 创建规则
|
||||||
|
rule = store.create_rule(
|
||||||
|
name='premium_discount',
|
||||||
|
condition_template='如果用户订阅 premium 返回8折',
|
||||||
|
code='def rule(facts):\n if facts.get("subscription") == "premium":\n return {"action": "discount", "rate": 0.8}\n return None',
|
||||||
|
priority=1
|
||||||
|
)
|
||||||
|
print(f'Created rule: {rule.id}')
|
||||||
|
|
||||||
|
# 执行器测试
|
||||||
|
executor = RuleExecutor()
|
||||||
|
result = executor.execute_rule(rule.code, {"subscription": "premium"})
|
||||||
|
assert result == {"action": "discount", "rate": 0.8}
|
||||||
|
print('Executor test: OK')
|
||||||
|
|
||||||
|
# 匹配器测试
|
||||||
|
matcher = RuleMatcher(store, executor)
|
||||||
|
rule_id, result = matcher.find_matching_rule({"subscription": "premium"})
|
||||||
|
assert rule_id == rule.id
|
||||||
|
print('Matcher test: OK')
|
||||||
|
|
||||||
|
os.unlink(path)
|
||||||
|
print('Full workflow test: PASSED')
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_server():
|
||||||
|
"""API 服务器集成测试."""
|
||||||
|
fd, path = tempfile.mkstemp(suffix='.db')
|
||||||
|
os.close(fd)
|
||||||
|
store = RuleStore(path)
|
||||||
|
|
||||||
|
srv = create_app(store, host='127.0.0.1', port=18003)
|
||||||
|
thread = Thread(target=srv.serve_forever, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
sleep(0.2)
|
||||||
|
|
||||||
|
client = HTTPConnection('127.0.0.1', 18003)
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
client.request('GET', '/health')
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 200
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert data['status'] == 'ok'
|
||||||
|
print('Health check: OK')
|
||||||
|
|
||||||
|
# Create rule
|
||||||
|
body = json.dumps({
|
||||||
|
'name': 'test_rule',
|
||||||
|
'condition_template': '测试规则',
|
||||||
|
'code': 'def rule(facts):\n if facts.get("active"):\n return {"action": "ok"}\n return None'
|
||||||
|
})
|
||||||
|
client.request('POST', '/api/rules', body=body, headers={'Content-Type': 'application/json'})
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 201
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
rule_id = data['id']
|
||||||
|
print(f'Created rule: {rule_id}')
|
||||||
|
|
||||||
|
# List rules
|
||||||
|
client.request('GET', '/api/rules')
|
||||||
|
resp = client.getresponse()
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert len(data['rules']) >= 1
|
||||||
|
print('List rules: OK')
|
||||||
|
|
||||||
|
# Get rule
|
||||||
|
client.request('GET', f'/api/rules/{rule_id}')
|
||||||
|
resp = client.getresponse()
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert data['name'] == 'test_rule'
|
||||||
|
print('Get rule: OK')
|
||||||
|
|
||||||
|
# Evaluate (no match yet - executor not integrated in mock)
|
||||||
|
body = json.dumps({'facts': {'active': True}})
|
||||||
|
client.request('POST', '/api/rules/evaluate', body=body, headers={'Content-Type': 'application/json'})
|
||||||
|
resp = client.getresponse()
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
print(f'Evaluate response: {data}')
|
||||||
|
|
||||||
|
# Delete rule
|
||||||
|
client.request('DELETE', f'/api/rules/{rule_id}')
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 200
|
||||||
|
print('Delete rule: OK')
|
||||||
|
|
||||||
|
srv.shutdown()
|
||||||
|
os.unlink(path)
|
||||||
|
print('API server test: PASSED')
|
||||||
|
|
||||||
|
|
||||||
|
def test_conflict_detection():
|
||||||
|
"""冲突检测测试."""
|
||||||
|
detector = ConflictDetector()
|
||||||
|
|
||||||
|
# 测试1:相反动作冲突
|
||||||
|
rules = [
|
||||||
|
{'id': 'r1', 'name': 'allow_vip', 'code': 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "allow"}\n return None', 'priority': 1, 'is_active': True},
|
||||||
|
{'id': 'r2', 'name': 'deny_vip', 'code': 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "deny"}\n return None', 'priority': 1, 'is_active': True},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 1
|
||||||
|
assert conflicts[0]['requires_resolution'] is True
|
||||||
|
print('Conflict detection (opposite actions): OK')
|
||||||
|
|
||||||
|
# 测试2:不同字段不冲突
|
||||||
|
rules = [
|
||||||
|
{'id': 'r1', 'name': 'rule_a', 'code': 'def rule(f):\n if f.get("a") == 1:\n return {"action": "allow"}\n return None', 'priority': 1, 'is_active': True},
|
||||||
|
{'id': 'r2', 'name': 'rule_b', 'code': 'def rule(f):\n if f.get("b") == 1:\n return {"action": "deny"}\n return None', 'priority': 1, 'is_active': True},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 0
|
||||||
|
print('Conflict detection (different fields): OK')
|
||||||
|
|
||||||
|
# 测试3:不同优先级不冲突
|
||||||
|
rules = [
|
||||||
|
{'id': 'r1', 'name': 'allow', 'code': 'def rule(f):\n return {"action": "allow"}', 'priority': 2, 'is_active': True},
|
||||||
|
{'id': 'r2', 'name': 'deny', 'code': 'def rule(f):\n return {"action": "deny"}', 'priority': 1, 'is_active': True},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 0
|
||||||
|
print('Conflict detection (different priority): OK')
|
||||||
|
|
||||||
|
print('Conflict detection test: PASSED')
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_sandbox():
|
||||||
|
"""沙箱执行器安全测试."""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
|
||||||
|
# 危险操作应被验证拦截
|
||||||
|
assert executor.validate_rule_code('def rule(f):\n import os') is False
|
||||||
|
assert executor.validate_rule_code('def rule(f):\n exec("")') is False
|
||||||
|
assert executor.validate_rule_code('def rule(f):\n eval("")') is False
|
||||||
|
print('Sandbox validation (blocked keywords): OK')
|
||||||
|
|
||||||
|
# 合法代码应通过
|
||||||
|
assert executor.validate_rule_code('def rule(facts):\n return None') is True
|
||||||
|
print('Sandbox validation (valid code): OK')
|
||||||
|
|
||||||
|
# 正常执行
|
||||||
|
code = 'def rule(facts):\n if facts.get("x", 0) > 10:\n return {"result": "big"}\n return None'
|
||||||
|
result = executor.execute_rule(code, {"x": 20})
|
||||||
|
assert result == {"result": "big"}
|
||||||
|
print('Sandbox execution: OK')
|
||||||
|
|
||||||
|
print('Sandbox test: PASSED')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('=' * 50)
|
||||||
|
print('Running integration tests...')
|
||||||
|
print('=' * 50)
|
||||||
|
|
||||||
|
test_full_workflow()
|
||||||
|
print()
|
||||||
|
test_api_server()
|
||||||
|
print()
|
||||||
|
test_conflict_detection()
|
||||||
|
print()
|
||||||
|
test_executor_sandbox()
|
||||||
|
|
||||||
|
print()
|
||||||
|
print('=' * 50)
|
||||||
|
print('ALL INTEGRATION TESTS PASSED')
|
||||||
|
print('=' * 50)
|
||||||
Loading…
x
Reference in New Issue
Block a user