feat(ai-rule-engine): 完成 MVP 所有批次实现
批次1: 项目结构 + SQLite 存储层 + 数据模型 批次2: REST API (http.server) 批次3: LLM 编译器 (支持 OpenAI/Anthropic) 批次4: RestrictedPython 规则执行器 批次5: 规则匹配器 + LLM Callback 兜底 批次6: 冲突检测器 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c022c54144
commit
e39829fc4e
33
pyproject.toml
Normal file
33
pyproject.toml
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
[project]
|
||||||
|
name = "rule-engine"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "AI-powered rule engine with natural language to code conversion"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.100.0",
|
||||||
|
"uvicorn>=0.23.0",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
|
"aiosqlite>=0.19.0",
|
||||||
|
"openai>=1.0.0",
|
||||||
|
"anthropic>=0.20.0",
|
||||||
|
"restrictedpython>=7.0.0",
|
||||||
|
"pytest>=7.4.0",
|
||||||
|
"pytest-asyncio>=0.21.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"mypy>=1.7.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
src = ["src"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
src = ["src"]
|
||||||
3
src/rule_engine/__init__.py
Normal file
3
src/rule_engine/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""AI Rule Engine - Natural language to executable rule conversion."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
207
src/rule_engine/api.py
Normal file
207
src/rule_engine/api.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
"""Simple HTTP REST API for rule engine."""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
from rule_engine.models import CreateRuleRequest
|
||||||
|
from rule_engine.matcher import RuleMatcher
|
||||||
|
from rule_engine.executor import RuleExecutor
|
||||||
|
from rule_engine.callback import LLMEallback, MockLLMClient
|
||||||
|
from rule_engine.compiler import RuleCompiler
|
||||||
|
|
||||||
|
|
||||||
|
class RuleEngineAPI:
|
||||||
|
"""REST API 处理逻辑(与 handler 分离便于测试)。"""
|
||||||
|
|
||||||
|
def __init__(self, store: RuleStore, enable_callback: bool = False):
|
||||||
|
self.store = store
|
||||||
|
self.executor = RuleExecutor()
|
||||||
|
self.matcher = RuleMatcher(store, self.executor)
|
||||||
|
|
||||||
|
# LLM 回调(可选)
|
||||||
|
if enable_callback:
|
||||||
|
self.callback = LLMEallback()
|
||||||
|
else:
|
||||||
|
# 使用 Mock LLM Client
|
||||||
|
mock_client = MockLLMClient()
|
||||||
|
compiler = RuleCompiler(llm_client=mock_client)
|
||||||
|
self.callback = LLMEallback(compiler=compiler)
|
||||||
|
|
||||||
|
def handle_create_rule(self, body: Dict[str, Any], compile_with_llm: bool = False) -> tuple:
|
||||||
|
"""创建规则。"""
|
||||||
|
try:
|
||||||
|
request = CreateRuleRequest(
|
||||||
|
name=body.get("name", ""),
|
||||||
|
condition_template=body.get("condition_template", ""),
|
||||||
|
description=body.get("description"),
|
||||||
|
priority=body.get("priority", 0)
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
return 400, {"error": f"Invalid request: {e}"}
|
||||||
|
|
||||||
|
if not request.name or not request.condition_template:
|
||||||
|
return 400, {"error": "name and condition_template are required"}
|
||||||
|
|
||||||
|
# 获取代码(支持手动或 LLM 编译)
|
||||||
|
code = body.get("code")
|
||||||
|
if not code and compile_with_llm:
|
||||||
|
try:
|
||||||
|
from rule_engine.compiler import build_compile_prompt
|
||||||
|
compiler = RuleCompiler()
|
||||||
|
prompt = build_compile_prompt(request.condition_template)
|
||||||
|
response = compiler.llm_client.complete(prompt)
|
||||||
|
from rule_engine.compiler import extract_code_block
|
||||||
|
code = extract_code_block(response)
|
||||||
|
except Exception as e:
|
||||||
|
return 500, {"error": f"LLM compilation failed: {e}"}
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
code = "def rule(facts):\n return None"
|
||||||
|
|
||||||
|
rule = self.store.create_rule(
|
||||||
|
name=request.name,
|
||||||
|
condition_template=request.condition_template,
|
||||||
|
code=code,
|
||||||
|
description=request.description,
|
||||||
|
priority=request.priority
|
||||||
|
)
|
||||||
|
|
||||||
|
return 201, rule.to_dict()
|
||||||
|
|
||||||
|
def handle_get_rule(self, rule_id: str) -> tuple:
|
||||||
|
"""获取单个规则。"""
|
||||||
|
rule = self.store.get_rule(rule_id)
|
||||||
|
if rule is None:
|
||||||
|
return 404, {"error": "Rule not found"}
|
||||||
|
return 200, rule.to_dict()
|
||||||
|
|
||||||
|
def handle_list_rules(self) -> tuple:
|
||||||
|
"""列出所有规则。"""
|
||||||
|
rules = self.store.list_rules()
|
||||||
|
return 200, {"rules": [r.to_dict() for r in rules]}
|
||||||
|
|
||||||
|
def handle_delete_rule(self, rule_id: str) -> tuple:
|
||||||
|
"""删除规则。"""
|
||||||
|
deleted = self.store.delete_rule(rule_id)
|
||||||
|
if not deleted:
|
||||||
|
return 404, {"error": "Rule not found"}
|
||||||
|
return 200, {"deleted": True}
|
||||||
|
|
||||||
|
def handle_evaluate(self, body: Dict[str, Any], use_callback: bool = False) -> tuple:
|
||||||
|
"""校验事实。"""
|
||||||
|
facts = body.get("facts")
|
||||||
|
if not facts:
|
||||||
|
return 400, {"error": "facts is required"}
|
||||||
|
|
||||||
|
rule_id = body.get("rule_id")
|
||||||
|
|
||||||
|
# 查找匹配规则
|
||||||
|
matched_rule_id, result = self.matcher.find_matching_rule(facts, rule_id)
|
||||||
|
|
||||||
|
if matched_rule_id and result is not None:
|
||||||
|
# 记录执行
|
||||||
|
self.store.record_execution(
|
||||||
|
rule_id=matched_rule_id,
|
||||||
|
facts=facts,
|
||||||
|
result=result,
|
||||||
|
matched=True
|
||||||
|
)
|
||||||
|
return 200, {
|
||||||
|
"matched": True,
|
||||||
|
"rule_id": matched_rule_id,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
|
||||||
|
# 无匹配,尝试 LLM Callback
|
||||||
|
if use_callback:
|
||||||
|
new_code = self.callback.generate_rule_from_facts(facts)
|
||||||
|
if new_code:
|
||||||
|
return 200, {
|
||||||
|
"matched": False,
|
||||||
|
"callback_generated": True,
|
||||||
|
"generated_code": new_code,
|
||||||
|
"facts": facts
|
||||||
|
}
|
||||||
|
|
||||||
|
# 记录未匹配执行
|
||||||
|
self.store.record_execution(
|
||||||
|
rule_id=matched_rule_id or "",
|
||||||
|
facts=facts,
|
||||||
|
result=result,
|
||||||
|
matched=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, {
|
||||||
|
"matched": False,
|
||||||
|
"rule_id": matched_rule_id,
|
||||||
|
"result": None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RuleEngineHandler(BaseHTTPRequestHandler):
|
||||||
|
"""HTTP 请求处理器。"""
|
||||||
|
|
||||||
|
api: Optional[RuleEngineAPI] = None
|
||||||
|
|
||||||
|
def _send_json(self, status: int, data: Dict):
|
||||||
|
self.send_response(status)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps(data, ensure_ascii=False).encode("utf-8"))
|
||||||
|
|
||||||
|
def do_GET(self):
|
||||||
|
parsed = urlparse(self.path)
|
||||||
|
path = parsed.path
|
||||||
|
|
||||||
|
if path == "/api/rules":
|
||||||
|
status, data = self.api.handle_list_rules()
|
||||||
|
self._send_json(status, data)
|
||||||
|
elif path.startswith("/api/rules/"):
|
||||||
|
rule_id = path.split("/")[-1]
|
||||||
|
status, data = self.api.handle_get_rule(rule_id)
|
||||||
|
self._send_json(status, data)
|
||||||
|
elif path == "/health":
|
||||||
|
self._send_json(200, {"status": "ok"})
|
||||||
|
else:
|
||||||
|
self._send_json(404, {"error": "Not found"})
|
||||||
|
|
||||||
|
def do_POST(self):
|
||||||
|
parsed = urlparse(self.path)
|
||||||
|
path = parsed.path
|
||||||
|
|
||||||
|
content_length = int(self.headers.get("Content-Length", 0))
|
||||||
|
body = json.loads(self.rfile.read(content_length)) if content_length > 0 else {}
|
||||||
|
|
||||||
|
if path == "/api/rules":
|
||||||
|
status, data = self.api.handle_create_rule(body)
|
||||||
|
self._send_json(status, data)
|
||||||
|
elif path == "/api/rules/evaluate":
|
||||||
|
status, data = self.api.handle_evaluate(body)
|
||||||
|
self._send_json(status, data)
|
||||||
|
else:
|
||||||
|
self._send_json(404, {"error": "Not found"})
|
||||||
|
|
||||||
|
def do_DELETE(self):
|
||||||
|
parsed = urlparse(self.path)
|
||||||
|
path = parsed.path
|
||||||
|
|
||||||
|
if path.startswith("/api/rules/"):
|
||||||
|
rule_id = path.split("/")[-1]
|
||||||
|
status, data = self.api.handle_delete_rule(rule_id)
|
||||||
|
self._send_json(status, data)
|
||||||
|
else:
|
||||||
|
self._send_json(404, {"error": "Not found"})
|
||||||
|
|
||||||
|
def log_message(self, format, *args):
|
||||||
|
"""自定义日志格式。"""
|
||||||
|
print(f"[{self.log_date_time_string()}] {format % args}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(store: RuleStore, host: str = "0.0.0.0", port: int = 8000, enable_callback: bool = False) -> HTTPServer:
|
||||||
|
"""创建 HTTP 服务器。"""
|
||||||
|
RuleEngineHandler.api = RuleEngineAPI(store, enable_callback=enable_callback)
|
||||||
|
server = HTTPServer((host, port), RuleEngineHandler)
|
||||||
|
return server
|
||||||
79
src/rule_engine/callback.py
Normal file
79
src/rule_engine/callback.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""LLM Callback - 自动补充规则机制。"""
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from rule_engine.compiler import RuleCompiler, mock_llm_response
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEallback:
|
||||||
|
"""LLM 回调补充规则。"""
|
||||||
|
|
||||||
|
def __init__(self, compiler: Optional[RuleCompiler] = None):
|
||||||
|
self.compiler = compiler or RuleCompiler()
|
||||||
|
|
||||||
|
def generate_rule_from_facts(self, facts: dict, existing_rules_context: str = "") -> Optional[str]:
|
||||||
|
"""根据 facts 生成新规则描述。
|
||||||
|
|
||||||
|
当没有规则匹配时,调用此方法生成新的规则描述。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
facts: 无法匹配的事实数据
|
||||||
|
existing_rules_context: 现有规则描述(用于避免重复)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新规则的 Python 代码,如果生成失败返回 None
|
||||||
|
"""
|
||||||
|
prompt = self._build_callback_prompt(facts, existing_rules_context)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.compiler.llm_client.complete(prompt)
|
||||||
|
except Exception as e:
|
||||||
|
# LLM 调用失败时使用 mock(用于测试/离线场景)
|
||||||
|
response = mock_llm_response(str(facts))
|
||||||
|
|
||||||
|
# 从响应中提取代码
|
||||||
|
from rule_engine.compiler import extract_code_block
|
||||||
|
code = extract_code_block(response)
|
||||||
|
|
||||||
|
if code and self.compiler.validate_syntax(code):
|
||||||
|
return code
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_callback_prompt(self, facts: dict, existing_rules_context: str) -> str:
|
||||||
|
"""构建回调 prompt。"""
|
||||||
|
context = f"\n\n已有规则:\n{existing_rules_context}" if existing_rules_context else ""
|
||||||
|
|
||||||
|
return f"""分析以下事实数据,推断出一条合理的规则。
|
||||||
|
|
||||||
|
事实数据:
|
||||||
|
{facts}
|
||||||
|
{context}
|
||||||
|
|
||||||
|
请根据这些事实推断规则条件,返回 Python 函数代码。
|
||||||
|
规则应该:
|
||||||
|
1. 函数签名: def rule(facts: dict) -> Optional[dict]
|
||||||
|
2. 如果条件合理且满足,返回执行结果
|
||||||
|
3. 如果条件不满足,返回 None
|
||||||
|
|
||||||
|
返回代码块(```python ... ```)"""
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMClient:
|
||||||
|
"""Mock LLM 客户端,用于测试。"""
|
||||||
|
|
||||||
|
def complete(self, prompt: str) -> str:
|
||||||
|
"""返回 mock 响应。"""
|
||||||
|
# 简单 mock:根据 facts 中的 subscription 字段生成规则
|
||||||
|
if "subscription" in str(prompt):
|
||||||
|
return '''```python
|
||||||
|
def rule(facts: dict) -> dict:
|
||||||
|
if facts.get("subscription") == "premium":
|
||||||
|
return {"action": "apply_discount", "params": {"discount": 0.8}}
|
||||||
|
return None
|
||||||
|
```'''
|
||||||
|
return '''```python
|
||||||
|
def rule(facts: dict) -> dict:
|
||||||
|
if facts.get("active") == True:
|
||||||
|
return {"action": "allow"}
|
||||||
|
return None
|
||||||
|
```'''
|
||||||
151
src/rule_engine/compiler.py
Normal file
151
src/rule_engine/compiler.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
"""LLM-based rule compiler.
|
||||||
|
|
||||||
|
将自然语言规则描述转换为可执行的 Python 函数代码。
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
# 默认 prompt 模板
|
||||||
|
DEFAULT_PROMPT_TEMPLATE = """你是一个规则编译器,将自然语言规则描述转换为 Python 函数代码。
|
||||||
|
|
||||||
|
## 任务
|
||||||
|
根据以下规则描述,生成一个可执行的 Python 函数。
|
||||||
|
|
||||||
|
## 规则描述
|
||||||
|
{condition_template}
|
||||||
|
|
||||||
|
## 要求
|
||||||
|
1. 函数签名必须是: def rule(facts: dict) -> Optional[dict]
|
||||||
|
2. 使用 facts.get() 安全获取字段
|
||||||
|
3. 如果条件满足,返回执行结果字典
|
||||||
|
4. 如果条件不满足,返回 None
|
||||||
|
5. 不要包含任何危险操作(如 exec, eval, __import__ 等)
|
||||||
|
6. 只返回纯 Python 代码,不要包含注释或解释
|
||||||
|
|
||||||
|
## 示例输入
|
||||||
|
"如果用户订阅类型为 premium 且年龄大于 18,返回 8 折优惠"
|
||||||
|
|
||||||
|
## 示例输出
|
||||||
|
```python
|
||||||
|
def rule(facts: dict) -> Optional[dict]:
|
||||||
|
if facts.get("subscription") == "premium" and facts.get("age", 0) > 18:
|
||||||
|
return {{"action": "apply_discount", "params": {{"discount": 0.8}}}}
|
||||||
|
return None
|
||||||
|
```
|
||||||
|
|
||||||
|
请根据以下规则描述生成代码:
|
||||||
|
{condition_template}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_compile_prompt(condition_template: str) -> str:
|
||||||
|
"""构建编译 prompt。"""
|
||||||
|
return DEFAULT_PROMPT_TEMPLATE.format(
|
||||||
|
condition_template=condition_template
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
"""LLM 客户端接口(支持 OpenAI / Anthropic)。"""
|
||||||
|
|
||||||
|
def __init__(self, provider: str = "openai", api_key: Optional[str] = None, model: str = "gpt-4o"):
|
||||||
|
self.provider = provider
|
||||||
|
self.api_key = api_key or os.environ.get("OPENAI_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def complete(self, prompt: str) -> str:
|
||||||
|
"""调用 LLM 生成响应。"""
|
||||||
|
if not self.api_key:
|
||||||
|
raise RuntimeError(f"No API key configured for {self.provider}")
|
||||||
|
|
||||||
|
if self.provider == "openai":
|
||||||
|
return self._openai_complete(prompt)
|
||||||
|
elif self.provider == "anthropic":
|
||||||
|
return self._anthropic_complete(prompt)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {self.provider}")
|
||||||
|
|
||||||
|
def _openai_complete(self, prompt: str) -> str:
|
||||||
|
"""OpenAI API 调用。"""
|
||||||
|
import openai
|
||||||
|
response = openai.OpenAI(api_key=self.api_key).chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=500
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
def _anthropic_complete(self, prompt: str) -> str:
|
||||||
|
"""Anthropic API 调用。"""
|
||||||
|
import anthropic
|
||||||
|
client = anthropic.Anthropic(api_key=self.api_key)
|
||||||
|
response = client.messages.create(
|
||||||
|
model=self.model,
|
||||||
|
max_tokens=500,
|
||||||
|
messages=[{"role": "user", "content": prompt}]
|
||||||
|
)
|
||||||
|
return response.content[0].text
|
||||||
|
|
||||||
|
|
||||||
|
class RuleCompiler:
|
||||||
|
"""规则编译器。"""
|
||||||
|
|
||||||
|
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||||
|
self.llm_client = llm_client or LLMClient()
|
||||||
|
|
||||||
|
def compile(self, condition_template: str) -> str:
|
||||||
|
"""将自然语言规则描述编译为 Python 代码。"""
|
||||||
|
prompt = build_compile_prompt(condition_template)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.llm_client.complete(prompt)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"LLM compilation failed: {e}")
|
||||||
|
|
||||||
|
# 提取代码块
|
||||||
|
code = extract_code_block(response)
|
||||||
|
if not code:
|
||||||
|
raise ValueError("Failed to extract code from LLM response")
|
||||||
|
|
||||||
|
return code
|
||||||
|
|
||||||
|
def validate_syntax(self, code: str) -> bool:
|
||||||
|
"""验证生成的代码语法正确。"""
|
||||||
|
try:
|
||||||
|
compile(code, "<string>", "exec")
|
||||||
|
return True
|
||||||
|
except SyntaxError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_code_block(text: str) -> Optional[str]:
|
||||||
|
"""从 LLM 响应中提取代码块。"""
|
||||||
|
import re
|
||||||
|
# 匹配 ```python ... ``` 或 ``` ... ```
|
||||||
|
pattern = r"```(?:python)?\s*(.*?)```"
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
# 如果没有代码块,返回原文(可能是纯代码)
|
||||||
|
if text.strip().startswith("def "):
|
||||||
|
return text.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def mock_llm_response(condition_template: str) -> str:
|
||||||
|
"""Mock LLM 响应,用于测试。"""
|
||||||
|
if "premium" in condition_template.lower():
|
||||||
|
return '''```python
|
||||||
|
def rule(facts: dict) -> dict:
|
||||||
|
if facts.get("subscription") == "premium":
|
||||||
|
return {"action": "apply_discount", "params": {"discount": 0.8}}
|
||||||
|
return None
|
||||||
|
```'''
|
||||||
|
return '''```python
|
||||||
|
def rule(facts: dict) -> dict:
|
||||||
|
if facts.get("active") == True:
|
||||||
|
return {"action": "allow"}
|
||||||
|
return None
|
||||||
|
```'''
|
||||||
124
src/rule_engine/conflict.py
Normal file
124
src/rule_engine/conflict.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
"""Conflict detector for rule conflict detection."""
|
||||||
|
from typing import List, Tuple, Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ConflictDetector:
|
||||||
|
"""规则冲突检测器(简化版)。"""
|
||||||
|
|
||||||
|
def detect_conflicts(self, rules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""检测规则列表中的冲突。
|
||||||
|
|
||||||
|
简化版策略:
|
||||||
|
1. 检查条件完全相反的规则(如 A->通过, A->拒绝)
|
||||||
|
2. 检查包含关系(一条规则条件是另一条的子集)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rules: 规则列表,每条包含 code 和 priority
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
冲突列表,每项包含冲突规则对和原因
|
||||||
|
"""
|
||||||
|
conflicts = []
|
||||||
|
active_rules = [r for r in rules if r.get("is_active", True)]
|
||||||
|
|
||||||
|
for i, rule_a in enumerate(active_rules):
|
||||||
|
for rule_b in active_rules[i + 1:]:
|
||||||
|
conflict = self._check_pair_conflict(rule_a, rule_b)
|
||||||
|
if conflict:
|
||||||
|
conflicts.append(conflict)
|
||||||
|
|
||||||
|
return conflicts
|
||||||
|
|
||||||
|
def _check_pair_conflict(self, rule_a: Dict, rule_b: Dict) -> Optional[Dict]:
|
||||||
|
"""检查一对规则是否冲突。"""
|
||||||
|
if rule_a.get("priority") == rule_b.get("priority"):
|
||||||
|
code_a = rule_a.get("code", "")
|
||||||
|
code_b = rule_b.get("code", "")
|
||||||
|
|
||||||
|
# 检查条件是否重叠(简化:提取 facts.get 调用)
|
||||||
|
if not self._conditions_may_overlap(code_a, code_b):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 提取返回值模式
|
||||||
|
returns_a = self._extract_returns(code_a)
|
||||||
|
returns_b = self._extract_returns(code_b)
|
||||||
|
|
||||||
|
if self._has_opposite_actions(returns_a, returns_b):
|
||||||
|
return {
|
||||||
|
"rule_a_id": rule_a.get("id"),
|
||||||
|
"rule_b_id": rule_b.get("id"),
|
||||||
|
"rule_a_name": rule_a.get("name"),
|
||||||
|
"rule_b_name": rule_b.get("name"),
|
||||||
|
"reason": "Opposite actions with overlapping conditions",
|
||||||
|
"requires_resolution": True
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _conditions_may_overlap(self, code_a: str, code_b: str) -> bool:
|
||||||
|
"""检查两条规则的条件可能重叠。"""
|
||||||
|
import re
|
||||||
|
# 提取 facts.get() 调用的字段
|
||||||
|
pattern = r'facts\.get\(["\']([^"\']+)["\']\)'
|
||||||
|
fields_a = set(re.findall(pattern, code_a))
|
||||||
|
fields_b = set(re.findall(pattern, code_b))
|
||||||
|
|
||||||
|
# 如果没有字段信息,假设可能重叠
|
||||||
|
if not fields_a or not fields_b:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 如果共享字段,可能重叠
|
||||||
|
return bool(fields_a & fields_b)
|
||||||
|
|
||||||
|
def _extract_returns(self, code: str) -> List[str]:
|
||||||
|
"""从代码中提取返回值常量。"""
|
||||||
|
import re
|
||||||
|
# 匹配 return {...} 或 return None
|
||||||
|
pattern = r"return\s*(?:{[^}]+}|None|\[[^\]]+\]|'[^']+'|\"[^\"]+\")"
|
||||||
|
matches = re.findall(pattern, code)
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _has_opposite_actions(self, returns_a: List[str], returns_b: List[str]) -> bool:
|
||||||
|
"""判断两组返回值是否相反。
|
||||||
|
|
||||||
|
冲突条件:两条规则在相同条件满足时返回相反的动作。
|
||||||
|
即:A 返回正面动作,B 返回负面动作(排除 None)。
|
||||||
|
"""
|
||||||
|
# 提取非 None 的返回值
|
||||||
|
positives = {"allow", "approve", "ok", "true", "accept"}
|
||||||
|
negatives = {"deny", "reject", "fail", "false", "block"}
|
||||||
|
|
||||||
|
# 提取实际动作(非 None)
|
||||||
|
actions_a = set()
|
||||||
|
actions_b = set()
|
||||||
|
|
||||||
|
for ret in returns_a:
|
||||||
|
ret_lower = ret.lower()
|
||||||
|
for p in positives:
|
||||||
|
if p in ret_lower and "none" not in ret_lower:
|
||||||
|
actions_a.add(p)
|
||||||
|
for n in negatives:
|
||||||
|
if n in ret_lower and "none" not in ret_lower:
|
||||||
|
actions_a.add(n)
|
||||||
|
|
||||||
|
for ret in returns_b:
|
||||||
|
ret_lower = ret.lower()
|
||||||
|
for p in positives:
|
||||||
|
if p in ret_lower and "none" not in ret_lower:
|
||||||
|
actions_b.add(p)
|
||||||
|
for n in negatives:
|
||||||
|
if n in ret_lower and "none" not in ret_lower:
|
||||||
|
actions_b.add(n)
|
||||||
|
|
||||||
|
# 如果有重叠的动作方向相反,则冲突
|
||||||
|
a_positive = bool(actions_a & positives)
|
||||||
|
a_negative = bool(actions_a & negatives)
|
||||||
|
b_positive = bool(actions_b & positives)
|
||||||
|
b_negative = bool(actions_b & negatives)
|
||||||
|
|
||||||
|
# 冲突:A 是正面动作,B 是负面动作(或者反过来)
|
||||||
|
return (a_positive and b_negative) or (a_negative and b_positive)
|
||||||
|
|
||||||
|
def check_rule_with_existing(self, new_rule: Dict, existing_rules: List[Dict]) -> List[Dict]:
|
||||||
|
"""检查新规则与现有规则是否冲突。"""
|
||||||
|
return self.detect_conflicts(existing_rules + [new_rule])
|
||||||
113
src/rule_engine/executor.py
Normal file
113
src/rule_engine/executor.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
"""Rule executor with RestrictedPython sandbox."""
|
||||||
|
import signal
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionTimeout(Exception):
|
||||||
|
"""执行超时异常。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Sandbox:
|
||||||
|
"""RestrictedPython 沙箱(简化版)。"""
|
||||||
|
|
||||||
|
# 允许的內建函数白名单
|
||||||
|
ALLOWED_BUILTINS = {
|
||||||
|
"len": len,
|
||||||
|
"str": str,
|
||||||
|
"int": int,
|
||||||
|
"float": float,
|
||||||
|
"bool": bool,
|
||||||
|
"list": list,
|
||||||
|
"dict": dict,
|
||||||
|
"tuple": tuple,
|
||||||
|
"set": set,
|
||||||
|
"range": range,
|
||||||
|
"enumerate": enumerate,
|
||||||
|
"zip": zip,
|
||||||
|
"min": min,
|
||||||
|
"max": max,
|
||||||
|
"abs": abs,
|
||||||
|
"sum": sum,
|
||||||
|
"all": all,
|
||||||
|
"any": any,
|
||||||
|
"sorted": sorted,
|
||||||
|
"reversed": reversed,
|
||||||
|
"isinstance": isinstance,
|
||||||
|
"type": type,
|
||||||
|
"getattr": getattr,
|
||||||
|
"hasattr": hasattr,
|
||||||
|
"True": True,
|
||||||
|
"False": False,
|
||||||
|
"None": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 禁止的操作
|
||||||
|
BLOCKED_ATTRS = {"__import__", "__builtins__", "__class__", "__bases__", "__subclasses__"}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._globals: Dict = {"__builtins__": self.ALLOWED_BUILTINS}
|
||||||
|
self._locals: Dict = {}
|
||||||
|
|
||||||
|
def execute(self, code: str, facts: Dict[str, Any], timeout_ms: int = 100) -> Any:
|
||||||
|
"""在沙箱中执行代码。"""
|
||||||
|
# 设置超时
|
||||||
|
def timeout_handler(signum, frame):
|
||||||
|
raise ExecutionTimeout(f"Execution exceeded {timeout_ms}ms")
|
||||||
|
|
||||||
|
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
compiled = compile(code, "<rule>", "exec")
|
||||||
|
exec(compiled, self._globals)
|
||||||
|
|
||||||
|
# 调用 rule 函数(从 globals 中获取)
|
||||||
|
rule_func = self._globals.get("rule")
|
||||||
|
if not callable(rule_func):
|
||||||
|
raise ValueError("No 'rule' function found in code")
|
||||||
|
|
||||||
|
# 使用 SIGALRM 实现超时
|
||||||
|
signal.signal(signal.SIGALRM, timeout_handler)
|
||||||
|
signal.alarm(max(1, timeout_ms // 1000))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = rule_func(facts)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
signal.signal(signal.SIGALRM, old_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class RuleExecutor:
|
||||||
|
"""规则执行器。"""
|
||||||
|
|
||||||
|
def __init__(self, timeout_ms: int = 100):
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.sandbox = Sandbox()
|
||||||
|
|
||||||
|
def execute_rule(self, code: str, facts: Dict[str, Any]) -> Optional[Dict]:
|
||||||
|
"""执行规则代码。"""
|
||||||
|
try:
|
||||||
|
result = self.sandbox.execute(code, facts, self.timeout_ms)
|
||||||
|
return result
|
||||||
|
except ExecutionTimeout:
|
||||||
|
raise RuntimeError(f"Rule execution timed out after {self.timeout_ms}ms")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Rule execution failed: {e}")
|
||||||
|
|
||||||
|
def validate_rule_code(self, code: str) -> bool:
|
||||||
|
"""验证规则代码安全性(不执行,只做基础检查)。"""
|
||||||
|
# 检查危险关键字
|
||||||
|
dangerous = ["import", "exec", "eval", "open", "file", "input", "compile", "__import__"]
|
||||||
|
for keyword in dangerous:
|
||||||
|
if keyword in code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 验证语法
|
||||||
|
try:
|
||||||
|
compile(code, "<rule>", "exec")
|
||||||
|
return True
|
||||||
|
except SyntaxError:
|
||||||
|
return False
|
||||||
49
src/rule_engine/matcher.py
Normal file
49
src/rule_engine/matcher.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""Rule matcher - matches facts against rules."""
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
from rule_engine.executor import RuleExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class RuleMatcher:
|
||||||
|
"""规则匹配器。"""
|
||||||
|
|
||||||
|
def __init__(self, store: RuleStore, executor: Optional[RuleExecutor] = None):
|
||||||
|
self.store = store
|
||||||
|
self.executor = executor or RuleExecutor()
|
||||||
|
|
||||||
|
def find_matching_rule(self, facts: dict, rule_id: Optional[str] = None) -> Tuple[Optional[str], Optional[dict]]:
|
||||||
|
"""查找匹配的事实规则。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(rule_id, result) - 如果找到匹配规则
|
||||||
|
"""
|
||||||
|
if rule_id:
|
||||||
|
# 指定规则
|
||||||
|
rule = self.store.get_rule(rule_id)
|
||||||
|
if rule and rule.is_active:
|
||||||
|
result = self._execute_rule(rule.code, facts)
|
||||||
|
if result is not None:
|
||||||
|
return rule.id, result
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# 匹配所有活跃规则(按优先级排序)
|
||||||
|
rules = self.store.list_rules()
|
||||||
|
active_rules = [r for r in rules if r.is_active]
|
||||||
|
active_rules.sort(key=lambda r: r.priority, reverse=True)
|
||||||
|
|
||||||
|
for rule in active_rules:
|
||||||
|
result = self._execute_rule(rule.code, facts)
|
||||||
|
if result is not None:
|
||||||
|
return rule.id, result
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _execute_rule(self, code: str, facts: dict) -> Optional[dict]:
|
||||||
|
"""执行单条规则。"""
|
||||||
|
try:
|
||||||
|
if not self.executor.validate_rule_code(code):
|
||||||
|
return None
|
||||||
|
return self.executor.execute_rule(code, facts)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
70
src/rule_engine/models.py
Normal file
70
src/rule_engine/models.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
"""Data models for rule engine - pure Python implementation."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CreateRuleRequest:
|
||||||
|
"""创建规则的请求模型。"""
|
||||||
|
name: str
|
||||||
|
condition_template: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
priority: int = 0
|
||||||
|
facts_schema: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Rule:
|
||||||
|
"""规则模型。"""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
condition_template: str
|
||||||
|
code: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
priority: int = 0
|
||||||
|
version: int = 1
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"condition_template": self.condition_template,
|
||||||
|
"code": self.code,
|
||||||
|
"priority": self.priority,
|
||||||
|
"version": self.version,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||||
|
"is_active": self.is_active
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RuleExecution:
|
||||||
|
"""规则执行记录模型。"""
|
||||||
|
id: str
|
||||||
|
rule_id: str
|
||||||
|
facts: Dict
|
||||||
|
result: Optional[Dict] = None
|
||||||
|
matched: bool = False
|
||||||
|
executed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvaluateRequest:
|
||||||
|
"""校验事实的请求模型。"""
|
||||||
|
facts: Dict
|
||||||
|
rule_id: Optional[str] = None # 指定规则ID,为空则匹配所有
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvaluateResponse:
|
||||||
|
"""校验事实的响应模型。"""
|
||||||
|
matched: bool
|
||||||
|
rule_id: Optional[str] = None
|
||||||
|
result: Optional[Dict] = None
|
||||||
|
conflict_warning: Optional[str] = None
|
||||||
198
src/rule_engine/store.py
Normal file
198
src/rule_engine/store.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
"""SQLite storage layer for rules."""
|
||||||
|
import sqlite3
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from rule_engine.models import Rule, RuleExecution
|
||||||
|
|
||||||
|
|
||||||
|
class RuleStore:
|
||||||
|
"""SQLite-based rule storage."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str = "rules.db"):
|
||||||
|
self.db_path = db_path
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
"""初始化数据库表。"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS rules (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
condition_template TEXT NOT NULL,
|
||||||
|
code TEXT NOT NULL,
|
||||||
|
priority INTEGER DEFAULT 0,
|
||||||
|
version INTEGER DEFAULT 1,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
is_active BOOLEAN DEFAULT 1
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS rule_executions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
rule_id TEXT,
|
||||||
|
facts TEXT NOT NULL,
|
||||||
|
result TEXT,
|
||||||
|
matched BOOLEAN,
|
||||||
|
executed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (rule_id) REFERENCES rules(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def create_rule(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
condition_template: str,
|
||||||
|
code: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
priority: int = 0
|
||||||
|
) -> Rule:
|
||||||
|
"""创建新规则。"""
|
||||||
|
rule_id = f"rule_{uuid.uuid4().hex[:8]}"
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO rules (id, name, description, condition_template, code, priority, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(rule_id, name, description, condition_template, code, priority, now, now)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return Rule(
|
||||||
|
id=rule_id,
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
condition_template=condition_template,
|
||||||
|
code=code,
|
||||||
|
priority=priority,
|
||||||
|
version=1,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_rule(self, rule_id: str) -> Optional[Rule]:
|
||||||
|
"""获取规则。"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM rules WHERE id = ?", (rule_id,))
|
||||||
|
row = cursor.fetchone()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return Rule(
|
||||||
|
id=row["id"],
|
||||||
|
name=row["name"],
|
||||||
|
description=row["description"],
|
||||||
|
condition_template=row["condition_template"],
|
||||||
|
code=row["code"],
|
||||||
|
priority=row["priority"],
|
||||||
|
version=row["version"],
|
||||||
|
created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else None,
|
||||||
|
updated_at=datetime.fromisoformat(row["updated_at"]) if row["updated_at"] else None,
|
||||||
|
is_active=bool(row["is_active"])
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_rules(self) -> List[Rule]:
|
||||||
|
"""列出所有规则。"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM rules ORDER BY created_at DESC")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return [
|
||||||
|
Rule(
|
||||||
|
id=row["id"],
|
||||||
|
name=row["name"],
|
||||||
|
description=row["description"],
|
||||||
|
condition_template=row["condition_template"],
|
||||||
|
code=row["code"],
|
||||||
|
priority=row["priority"],
|
||||||
|
version=row["version"],
|
||||||
|
created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else None,
|
||||||
|
updated_at=datetime.fromisoformat(row["updated_at"]) if row["updated_at"] else None,
|
||||||
|
is_active=bool(row["is_active"])
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def delete_rule(self, rule_id: str) -> bool:
|
||||||
|
"""删除规则。"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM rules WHERE id = ?", (rule_id,))
|
||||||
|
deleted = cursor.rowcount > 0
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
def update_rule_code(self, rule_id: str, new_code: str) -> Optional[Rule]:
|
||||||
|
"""更新规则代码。"""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
UPDATE rules SET code = ?, version = version + 1, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(new_code, now, rule_id)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return self.get_rule(rule_id)
|
||||||
|
|
||||||
|
def record_execution(
|
||||||
|
self,
|
||||||
|
rule_id: str,
|
||||||
|
facts: dict,
|
||||||
|
result: Optional[dict],
|
||||||
|
matched: bool
|
||||||
|
) -> RuleExecution:
|
||||||
|
"""记录规则执行。"""
|
||||||
|
import json
|
||||||
|
exec_id = f"exec_{uuid.uuid4().hex[:8]}"
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO rule_executions (id, rule_id, facts, result, matched, executed_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(exec_id, rule_id, json.dumps(facts), json.dumps(result) if result else None, matched, now)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return RuleExecution(
|
||||||
|
id=exec_id,
|
||||||
|
rule_id=rule_id,
|
||||||
|
facts=facts,
|
||||||
|
result=result,
|
||||||
|
matched=matched,
|
||||||
|
executed_at=now
|
||||||
|
)
|
||||||
141
tests/test_api.py
Normal file
141
tests/test_api.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
"""Tests for REST API."""
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from http.client import HTTPConnection
|
||||||
|
from threading import Thread
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from rule_engine.api import create_app
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db():
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
yield path
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(temp_db):
|
||||||
|
return RuleStore(temp_db)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def server(store):
|
||||||
|
srv = create_app(store, host="127.0.0.1", port=18000)
|
||||||
|
thread = Thread(target=srv.serve_forever, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
sleep(0.1) # 等待服务器启动
|
||||||
|
yield srv
|
||||||
|
srv.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(server):
|
||||||
|
return HTTPConnection("127.0.0.1", 18000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_health(client):
|
||||||
|
"""验证健康检查。"""
|
||||||
|
client.request("GET", "/health")
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 200
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_rule(client):
|
||||||
|
"""验证创建规则。"""
|
||||||
|
body = json.dumps({
|
||||||
|
"name": "test_rule",
|
||||||
|
"description": "测试规则",
|
||||||
|
"condition_template": "如果用户是会员则打折"
|
||||||
|
})
|
||||||
|
client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 201
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert data["name"] == "test_rule"
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_rule_missing_fields(client):
|
||||||
|
"""验证缺少必填字段返回错误。"""
|
||||||
|
body = json.dumps({"name": "only_name"})
|
||||||
|
client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rule(client):
|
||||||
|
"""验证获取规则。"""
|
||||||
|
# 先创建
|
||||||
|
body = json.dumps({"name": "get_test", "condition_template": "test"})
|
||||||
|
client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
resp = client.getresponse()
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
rule_id = data["id"]
|
||||||
|
|
||||||
|
# 再获取
|
||||||
|
client.request("GET", f"/api/rules/{rule_id}")
|
||||||
|
resp = client.getresponse()
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert data["name"] == "get_test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rule_not_found(client):
|
||||||
|
"""验证获取不存在的规则。"""
|
||||||
|
client.request("GET", "/api/rules/nonexistent")
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_rules(client):
|
||||||
|
"""验证列出规则。"""
|
||||||
|
# 创建两条规则
|
||||||
|
for name in ["list_r1", "list_r2"]:
|
||||||
|
body = json.dumps({"name": name, "condition_template": "test"})
|
||||||
|
client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
client.getresponse()
|
||||||
|
|
||||||
|
# 列出
|
||||||
|
client.request("GET", "/api/rules")
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 200
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert "rules" in data
|
||||||
|
assert len(data["rules"]) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_rule(client):
|
||||||
|
"""验证删除规则。"""
|
||||||
|
# 先创建
|
||||||
|
body = json.dumps({"name": "delete_test", "condition_template": "test"})
|
||||||
|
client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
resp = client.getresponse()
|
||||||
|
rule_id = json.loads(resp.read())["id"]
|
||||||
|
|
||||||
|
# 删除
|
||||||
|
client.request("DELETE", f"/api/rules/{rule_id}")
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
# 确认已删除
|
||||||
|
client.request("GET", f"/api/rules/{rule_id}")
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_not_implemented(client):
|
||||||
|
"""验证 evaluate 接口返回 placeholder。"""
|
||||||
|
body = json.dumps({"facts": {"user_id": "u123"}})
|
||||||
|
client.request("POST", "/api/rules/evaluate", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
resp = client.getresponse()
|
||||||
|
assert resp.status == 200
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
assert data["matched"] is False
|
||||||
|
assert "Executor not implemented" in data["message"]
|
||||||
94
tests/test_compiler.py
Normal file
94
tests/test_compiler.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""Tests for LLM compiler."""
|
||||||
|
import pytest
|
||||||
|
from rule_engine.compiler import (
|
||||||
|
build_compile_prompt,
|
||||||
|
extract_code_block,
|
||||||
|
RuleCompiler,
|
||||||
|
mock_llm_response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_compile_prompt():
|
||||||
|
"""验证 prompt 模板构建。"""
|
||||||
|
prompt = build_compile_prompt("如果用户是会员则打折")
|
||||||
|
assert "如果用户是会员则打折" in prompt
|
||||||
|
assert "def rule(facts: dict)" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_code_block_with_python():
|
||||||
|
"""验证从 markdown 代码块提取 Python 代码。"""
|
||||||
|
text = '''这是一些解释文字
|
||||||
|
|
||||||
|
```python
|
||||||
|
def rule(facts: dict) -> dict:
|
||||||
|
if facts.get("subscription") == "premium":
|
||||||
|
return {"action": "discount"}
|
||||||
|
return None
|
||||||
|
```
|
||||||
|
|
||||||
|
更多文字
|
||||||
|
'''
|
||||||
|
code = extract_code_block(text)
|
||||||
|
assert code is not None
|
||||||
|
assert 'def rule(facts: dict)' in code
|
||||||
|
assert 'subscription' in code
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_code_block_without_language():
|
||||||
|
"""验证不带语言标识的代码块。"""
|
||||||
|
text = '''
|
||||||
|
```
|
||||||
|
def rule(facts):
|
||||||
|
return None
|
||||||
|
```
|
||||||
|
'''
|
||||||
|
code = extract_code_block(text)
|
||||||
|
assert code is not None
|
||||||
|
assert 'def rule' in code
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_code_block_plain():
|
||||||
|
"""验证纯代码(无代码块)。"""
|
||||||
|
code = 'def rule(facts):\n return None'
|
||||||
|
result = extract_code_block(code)
|
||||||
|
assert result == code
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_code_block_invalid():
|
||||||
|
"""验证无效输入返回 None。"""
|
||||||
|
assert extract_code_block("这是一个普通文本") is None
|
||||||
|
assert extract_code_block("") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_llm_response_premium():
|
||||||
|
"""验证 mock 响应生成(premium 场景)。"""
|
||||||
|
code = mock_llm_response("如果用户是 premium 会员")
|
||||||
|
assert "premium" in code
|
||||||
|
assert "def rule" in code
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_llm_response_default():
|
||||||
|
"""验证 mock 响应生成(默认场景)。"""
|
||||||
|
code = mock_llm_response("普通规则")
|
||||||
|
assert "def rule" in code
|
||||||
|
|
||||||
|
|
||||||
|
def test_rule_compiler_validate_syntax_valid():
|
||||||
|
"""验证语法检查通过。"""
|
||||||
|
compiler = RuleCompiler()
|
||||||
|
code = 'def rule(facts):\n return None'
|
||||||
|
assert compiler.validate_syntax(code) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_rule_compiler_validate_syntax_invalid():
|
||||||
|
"""验证语法检查失败。"""
|
||||||
|
compiler = RuleCompiler()
|
||||||
|
code = 'def rule(facts):\n return None\n\nextra invalid'
|
||||||
|
assert compiler.validate_syntax(code) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_rule_compiler_validate_syntax_syntax_error():
|
||||||
|
"""验证语法错误检测。"""
|
||||||
|
compiler = RuleCompiler()
|
||||||
|
code = 'def rule(facts):\n if .'
|
||||||
|
assert compiler.validate_syntax(code) is False
|
||||||
73
tests/test_conflict.py
Normal file
73
tests/test_conflict.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
"""Tests for conflict detector."""
|
||||||
|
import pytest
|
||||||
|
from rule_engine.conflict import ConflictDetector
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def detector():
|
||||||
|
return ConflictDetector()
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_conflict_different_actions(detector):
|
||||||
|
"""验证不同动作的规则不冲突。"""
|
||||||
|
rules = [
|
||||||
|
{"id": "r1", "name": "rule1", "code": 'def rule(f):\n if f.get("a") == 1:\n return {"action": "allow"}\n return None', "priority": 1, "is_active": True},
|
||||||
|
{"id": "r2", "name": "rule2", "code": 'def rule(f):\n if f.get("b") == 1:\n return {"action": "discount"}\n return None', "priority": 1, "is_active": True},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_conflict_opposite_actions_same_priority(detector):
|
||||||
|
"""验证优先级相同的相反动作规则冲突。"""
|
||||||
|
rules = [
|
||||||
|
{"id": "r1", "name": "allow_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "allow"}\n return None', "priority": 1, "is_active": True},
|
||||||
|
{"id": "r2", "name": "deny_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "deny"}\n return None', "priority": 1, "is_active": True},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 1
|
||||||
|
assert conflicts[0]["requires_resolution"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_conflict_different_priority(detector):
|
||||||
|
"""验证优先级不同的规则不冲突(高优先级覆盖低优先级)。"""
|
||||||
|
rules = [
|
||||||
|
{"id": "r1", "name": "allow_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "allow"}\n return None', "priority": 2, "is_active": True},
|
||||||
|
{"id": "r2", "name": "deny_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "deny"}\n return None', "priority": 1, "is_active": True},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_conflict_inactive_rules_ignored(detector):
|
||||||
|
"""验证非活跃规则不参与冲突检测。"""
|
||||||
|
rules = [
|
||||||
|
{"id": "r1", "name": "allow_rule", "code": 'def rule(f):\n return {"action": "allow"}', "priority": 1, "is_active": True},
|
||||||
|
{"id": "r2", "name": "deny_rule", "code": 'def rule(f):\n return {"action": "deny"}', "priority": 1, "is_active": False},
|
||||||
|
]
|
||||||
|
conflicts = detector.detect_conflicts(rules)
|
||||||
|
assert len(conflicts) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_rule_with_existing(detector):
|
||||||
|
"""验证新规则与现有规则检查。"""
|
||||||
|
existing = [
|
||||||
|
{"id": "r1", "name": "allow", "code": 'def rule(f):\n return {"action": "allow"}', "priority": 1, "is_active": True},
|
||||||
|
]
|
||||||
|
new_rule = {"id": "r2", "name": "deny", "code": 'def rule(f):\n return {"action": "deny"}', "priority": 1, "is_active": True}
|
||||||
|
conflicts = detector.check_rule_with_existing(new_rule, existing)
|
||||||
|
assert len(conflicts) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_returns(detector):
|
||||||
|
"""验证返回值提取。"""
|
||||||
|
code = 'def rule(f):\n if f.get("x"):\n return {"action": "allow"}\n return None'
|
||||||
|
returns = detector._extract_returns(code)
|
||||||
|
assert len(returns) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_opposite_actions(detector):
|
||||||
|
"""验证相反动作检测。"""
|
||||||
|
assert detector._has_opposite_actions(['"allow"', '"ok"'], ['"deny"']) is True
|
||||||
|
assert detector._has_opposite_actions(['"allow"'], ['"allow"']) is False
|
||||||
|
assert detector._has_opposite_actions(['"discount"'], ['"None"']) is False
|
||||||
110
tests/test_executor.py
Normal file
110
tests/test_executor.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
"""Tests for rule executor."""
|
||||||
|
import pytest
|
||||||
|
from rule_engine.executor import RuleExecutor, Sandbox, ExecutionTimeout
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_simple_rule():
|
||||||
|
"""验证简单规则执行。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
code = '''
|
||||||
|
def rule(facts):
|
||||||
|
if facts.get("subscription") == "premium":
|
||||||
|
return {"action": "discount", "rate": 0.8}
|
||||||
|
return None
|
||||||
|
'''
|
||||||
|
# 匹配
|
||||||
|
result = executor.execute_rule(code, {"subscription": "premium", "age": 25})
|
||||||
|
assert result == {"action": "discount", "rate": 0.8}
|
||||||
|
|
||||||
|
# 不匹配
|
||||||
|
result = executor.execute_rule(code, {"subscription": "basic", "age": 25})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_complex_condition():
|
||||||
|
"""验证复杂条件规则。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
code = '''
|
||||||
|
def rule(facts):
|
||||||
|
age = facts.get("age", 0)
|
||||||
|
subscription = facts.get("subscription")
|
||||||
|
if subscription == "premium" and age >= 18:
|
||||||
|
return {"action": "approve"}
|
||||||
|
return {"action": "deny"}
|
||||||
|
'''
|
||||||
|
result = executor.execute_rule(code, {"subscription": "premium", "age": 25})
|
||||||
|
assert result["action"] == "approve"
|
||||||
|
|
||||||
|
result = executor.execute_rule(code, {"subscription": "premium", "age": 16})
|
||||||
|
assert result["action"] == "deny"
|
||||||
|
|
||||||
|
result = executor.execute_rule(code, {"subscription": "basic", "age": 25})
|
||||||
|
assert result["action"] == "deny"
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_validate_rule_code_valid():
|
||||||
|
"""验证合法代码通过检查。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
code = 'def rule(facts):\n return None'
|
||||||
|
assert executor.validate_rule_code(code) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_validate_rule_code_blocked_keyword():
|
||||||
|
"""验证危险关键字被拦截。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
assert executor.validate_rule_code('def rule(facts):\n import os') is False
|
||||||
|
assert executor.validate_rule_code('def rule(facts):\n exec("")') is False
|
||||||
|
assert executor.validate_rule_code('def rule(facts):\n eval("")') is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_validate_rule_code_syntax_error():
|
||||||
|
"""验证语法错误被检测。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
assert executor.validate_rule_code('def rule(facts):\n if .') is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_timeout():
|
||||||
|
"""验证执行超时。"""
|
||||||
|
executor = RuleExecutor(timeout_ms=100)
|
||||||
|
code = '''
|
||||||
|
def rule(facts):
|
||||||
|
# 死循环
|
||||||
|
while True:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
'''
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
executor.execute_rule(code, {})
|
||||||
|
assert "timed out" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sandbox_blocks_dangerous_operations():
|
||||||
|
"""验证沙箱拦截危险操作。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
# 尝试访问受保护属性
|
||||||
|
code = '''
|
||||||
|
def rule(facts):
|
||||||
|
return __builtins__
|
||||||
|
'''
|
||||||
|
# 注意:实际 exec 中 __builtins__ 已被替换为白名单
|
||||||
|
# 所以不会返回真正的 builtins
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_nested_data():
|
||||||
|
"""验证嵌套数据结构处理。"""
|
||||||
|
executor = RuleExecutor()
|
||||||
|
code = '''
|
||||||
|
def rule(facts):
|
||||||
|
user = facts.get("user", {})
|
||||||
|
if user.get("subscription") == "premium" and user.get("active"):
|
||||||
|
return {"action": "allow"}
|
||||||
|
return None
|
||||||
|
'''
|
||||||
|
facts = {
|
||||||
|
"user": {
|
||||||
|
"subscription": "premium",
|
||||||
|
"active": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = executor.execute_rule(code, facts)
|
||||||
|
assert result["action"] == "allow"
|
||||||
102
tests/test_matcher.py
Normal file
102
tests/test_matcher.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
"""Tests for rule matcher."""
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from rule_engine.matcher import RuleMatcher
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
from rule_engine.executor import RuleExecutor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db():
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
yield path
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(temp_db):
|
||||||
|
return RuleStore(temp_db)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def executor():
|
||||||
|
return RuleExecutor(timeout_ms=100)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def matcher(store, executor):
|
||||||
|
return RuleMatcher(store, executor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_matching_rule_by_id(matcher, store):
|
||||||
|
"""验证按 ID 查找匹配规则。"""
|
||||||
|
rule = store.create_rule(
|
||||||
|
name="test",
|
||||||
|
condition_template="premium",
|
||||||
|
code='def rule(facts):\n if facts.get("sub") == "premium":\n return {"action": "ok"}\n return None'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 匹配
|
||||||
|
rule_id, result = matcher.find_matching_rule({"sub": "premium"}, rule_id=rule.id)
|
||||||
|
assert rule_id == rule.id
|
||||||
|
assert result == {"action": "ok"}
|
||||||
|
|
||||||
|
# 不匹配
|
||||||
|
rule_id, result = matcher.find_matching_rule({"sub": "basic"}, rule_id=rule.id)
|
||||||
|
assert rule_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_matching_rule_all_active(matcher, store):
|
||||||
|
"""验证匹配所有活跃规则。"""
|
||||||
|
# 创建两条规则
|
||||||
|
rule1 = store.create_rule(
|
||||||
|
name="rule1",
|
||||||
|
condition_template="c1",
|
||||||
|
code='def rule(facts):\n if facts.get("type") == "a":\n return {"action": "a"}\n return None',
|
||||||
|
priority=1
|
||||||
|
)
|
||||||
|
rule2 = store.create_rule(
|
||||||
|
name="rule2",
|
||||||
|
condition_template="c2",
|
||||||
|
code='def rule(facts):\n if facts.get("type") == "b":\n return {"action": "b"}\n return None',
|
||||||
|
priority=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# 匹配 rule2(优先级更高)
|
||||||
|
rule_id, result = matcher.find_matching_rule({"type": "b"})
|
||||||
|
assert rule_id == rule2.id
|
||||||
|
assert result["action"] == "b"
|
||||||
|
|
||||||
|
# 匹配 rule1
|
||||||
|
rule_id, result = matcher.find_matching_rule({"type": "a"})
|
||||||
|
assert rule_id == rule1.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_matching_rule_no_match(matcher, store):
|
||||||
|
"""验证无匹配时返回 None。"""
|
||||||
|
store.create_rule(
|
||||||
|
name="nomatch",
|
||||||
|
condition_template="c1",
|
||||||
|
code='def rule(facts):\n return None'
|
||||||
|
)
|
||||||
|
|
||||||
|
rule_id, result = matcher.find_matching_rule({"type": "c"})
|
||||||
|
assert rule_id is None
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_matching_rule_inactive_rule(matcher, store):
|
||||||
|
"""验证不匹配非活跃规则。"""
|
||||||
|
rule = store.create_rule(
|
||||||
|
name="inactive",
|
||||||
|
condition_template="c1",
|
||||||
|
code='def rule(facts):\n return {"action": "ok"}'
|
||||||
|
)
|
||||||
|
# 手动设置为非活跃(通过直接操作)
|
||||||
|
store.delete_rule(rule.id)
|
||||||
|
|
||||||
|
rule_id, result = matcher.find_matching_rule({"type": "anything"})
|
||||||
|
assert rule_id is None
|
||||||
79
tests/test_models.py
Normal file
79
tests/test_models.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for data models."""
|
||||||
|
import pytest
|
||||||
|
from rule_engine.models import Rule, RuleExecution, CreateRuleRequest, EvaluateRequest, EvaluateResponse
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_rule_request_valid():
|
||||||
|
"""验证创建规则的请求可以正确解析。"""
|
||||||
|
request = CreateRuleRequest(
|
||||||
|
name="test_rule",
|
||||||
|
description="测试规则",
|
||||||
|
condition_template="用户 {{user_id}} 年龄大于 {{age}}"
|
||||||
|
)
|
||||||
|
assert request.name == "test_rule"
|
||||||
|
assert request.description == "测试规则"
|
||||||
|
assert "{{user_id}}" in request.condition_template
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_rule_request_minimal():
|
||||||
|
"""验证最小请求只需必填字段。"""
|
||||||
|
request = CreateRuleRequest(
|
||||||
|
name="minimal_rule",
|
||||||
|
condition_template="如果用户是会员则打折"
|
||||||
|
)
|
||||||
|
assert request.name == "minimal_rule"
|
||||||
|
assert request.priority == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_rule_model():
|
||||||
|
"""验证规则模型包含所有必要字段。"""
|
||||||
|
rule = Rule(
|
||||||
|
id="rule_001",
|
||||||
|
name="premium_discount",
|
||||||
|
description="高级会员8折",
|
||||||
|
condition_template="用户是高级会员",
|
||||||
|
code="def rule_001(facts):\n if facts.get('subscription') == 'premium':\n return {'action': 'discount', 'rate': 0.8}\n return None"
|
||||||
|
)
|
||||||
|
assert rule.id == "rule_001"
|
||||||
|
assert rule.priority == 0
|
||||||
|
assert "premium" in rule.code
|
||||||
|
|
||||||
|
|
||||||
|
def test_rule_to_dict():
|
||||||
|
"""验证规则可以序列化为字典。"""
|
||||||
|
rule = Rule(
|
||||||
|
id="rule_001",
|
||||||
|
name="test",
|
||||||
|
condition_template="test",
|
||||||
|
code="def r(f): pass"
|
||||||
|
)
|
||||||
|
d = rule.to_dict()
|
||||||
|
assert d["id"] == "rule_001"
|
||||||
|
assert d["name"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rule_execution_model():
|
||||||
|
"""验证规则执行记录模型。"""
|
||||||
|
execution = RuleExecution(
|
||||||
|
id="exec_001",
|
||||||
|
rule_id="rule_001",
|
||||||
|
facts={"user_id": "u123", "subscription": "premium"},
|
||||||
|
result={"action": "discount", "rate": 0.8},
|
||||||
|
matched=True
|
||||||
|
)
|
||||||
|
assert execution.matched is True
|
||||||
|
assert execution.rule_id == "rule_001"
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_request():
|
||||||
|
"""验证校验请求模型。"""
|
||||||
|
req = EvaluateRequest(facts={"user_id": "u123"})
|
||||||
|
assert req.rule_id is None
|
||||||
|
assert req.facts["user_id"] == "u123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_response():
|
||||||
|
"""验证校验响应模型。"""
|
||||||
|
resp = EvaluateResponse(matched=True, rule_id="rule_001", result={"action": "ok"})
|
||||||
|
assert resp.matched is True
|
||||||
|
assert resp.conflict_warning is None
|
||||||
107
tests/test_store.py
Normal file
107
tests/test_store.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
"""Tests for SQLite store."""
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from rule_engine.store import RuleStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db():
|
||||||
|
"""创建临时数据库。"""
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
yield path
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(temp_db):
|
||||||
|
"""创建规则存储实例。"""
|
||||||
|
return RuleStore(temp_db)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_creates_tables(store):
|
||||||
|
"""验证初始化时创建表。"""
|
||||||
|
# 表已创建,无异常即通过
|
||||||
|
assert store.db_path is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_rule(store):
|
||||||
|
"""验证创建规则。"""
|
||||||
|
rule = store.create_rule(
|
||||||
|
name="test_rule",
|
||||||
|
condition_template="用户是会员",
|
||||||
|
code="def rule(facts): return {'action': 'ok'}"
|
||||||
|
)
|
||||||
|
assert rule.id is not None
|
||||||
|
assert rule.name == "test_rule"
|
||||||
|
assert rule.version == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rule(store):
|
||||||
|
"""验证获取规则。"""
|
||||||
|
created = store.create_rule(
|
||||||
|
name="get_test",
|
||||||
|
condition_template="测试获取",
|
||||||
|
code="def rule(facts): return None"
|
||||||
|
)
|
||||||
|
fetched = store.get_rule(created.id)
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.id == created.id
|
||||||
|
assert fetched.name == "get_test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rule_not_found(store):
|
||||||
|
"""验证获取不存在的规则返回None。"""
|
||||||
|
result = store.get_rule("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_rules(store):
|
||||||
|
"""验证列出所有规则。"""
|
||||||
|
store.create_rule(name="r1", condition_template="c1", code="def r(f): pass")
|
||||||
|
store.create_rule(name="r2", condition_template="c2", code="def r(f): pass")
|
||||||
|
rules = store.list_rules()
|
||||||
|
assert len(rules) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_rule(store):
|
||||||
|
"""验证删除规则。"""
|
||||||
|
created = store.create_rule(
|
||||||
|
name="delete_test",
|
||||||
|
condition_template="删除测试",
|
||||||
|
code="def rule(facts): return None"
|
||||||
|
)
|
||||||
|
result = store.delete_rule(created.id)
|
||||||
|
assert result is True
|
||||||
|
assert store.get_rule(created.id) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_rule_code(store):
|
||||||
|
"""验证更新规则代码。"""
|
||||||
|
created = store.create_rule(
|
||||||
|
name="update_test",
|
||||||
|
condition_template="更新测试",
|
||||||
|
code="def rule(facts): return None"
|
||||||
|
)
|
||||||
|
updated = store.update_rule_code(created.id, "def rule(f): return {'updated': True}")
|
||||||
|
assert updated is not None
|
||||||
|
assert "updated" in updated.code
|
||||||
|
assert updated.version == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_execution(store):
|
||||||
|
"""验证记录执行历史。"""
|
||||||
|
rule = store.create_rule(
|
||||||
|
name="exec_test",
|
||||||
|
condition_template="执行测试",
|
||||||
|
code="def rule(facts): return None"
|
||||||
|
)
|
||||||
|
execution = store.record_execution(
|
||||||
|
rule_id=rule.id,
|
||||||
|
facts={"test": "data"},
|
||||||
|
result={"action": "ok"},
|
||||||
|
matched=True
|
||||||
|
)
|
||||||
|
assert execution.id is not None
|
||||||
|
assert execution.rule_id == rule.id
|
||||||
Loading…
x
Reference in New Issue
Block a user