From e39829fc4e5685d08b92b5230f67158400eca31c Mon Sep 17 00:00:00 2001 From: root Date: Mon, 11 May 2026 22:19:42 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai-rule-engine):=20=E5=AE=8C=E6=88=90=20MV?= =?UTF-8?q?P=20=E6=89=80=E6=9C=89=E6=89=B9=E6=AC=A1=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 批次1: 项目结构 + SQLite 存储层 + 数据模型 批次2: REST API (http.server) 批次3: LLM 编译器 (支持 OpenAI/Anthropic) 批次4: RestrictedPython 规则执行器 批次5: 规则匹配器 + LLM Callback 兜底 批次6: 冲突检测器 Co-Authored-By: Claude Opus 4.7 --- pyproject.toml | 33 ++++++ src/rule_engine/__init__.py | 3 + src/rule_engine/api.py | 207 ++++++++++++++++++++++++++++++++++++ src/rule_engine/callback.py | 79 ++++++++++++++ src/rule_engine/compiler.py | 151 ++++++++++++++++++++++++++ src/rule_engine/conflict.py | 124 +++++++++++++++++++++ src/rule_engine/executor.py | 113 ++++++++++++++++++++ src/rule_engine/matcher.py | 49 +++++++++ src/rule_engine/models.py | 70 ++++++++++++ src/rule_engine/store.py | 198 ++++++++++++++++++++++++++++++++++ tests/test_api.py | 141 ++++++++++++++++++++++++ tests/test_compiler.py | 94 ++++++++++++++++ tests/test_conflict.py | 73 +++++++++++++ tests/test_executor.py | 110 +++++++++++++++++++ tests/test_matcher.py | 102 ++++++++++++++++++ tests/test_models.py | 79 ++++++++++++++ tests/test_store.py | 107 +++++++++++++++++++ 17 files changed, 1733 insertions(+) create mode 100644 pyproject.toml create mode 100644 src/rule_engine/__init__.py create mode 100644 src/rule_engine/api.py create mode 100644 src/rule_engine/callback.py create mode 100644 src/rule_engine/compiler.py create mode 100644 src/rule_engine/conflict.py create mode 100644 src/rule_engine/executor.py create mode 100644 src/rule_engine/matcher.py create mode 100644 src/rule_engine/models.py create mode 100644 src/rule_engine/store.py create mode 100644 tests/test_api.py create mode 100644 tests/test_compiler.py create mode 100644 tests/test_conflict.py create mode 100644 tests/test_executor.py create mode 100644 tests/test_matcher.py create mode 100644 tests/test_models.py create mode 100644 tests/test_store.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c68c7c9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[project] +name = "rule-engine" +version = "0.1.0" +description = "AI-powered rule engine with natural language to code conversion" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.100.0", + "uvicorn>=0.23.0", + "pydantic>=2.0.0", + "aiosqlite>=0.19.0", + "openai>=1.0.0", + "anthropic>=0.20.0", + "restrictedpython>=7.0.0", + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", + "httpx>=0.25.0", + "ruff>=0.1.0", + "mypy>=1.7.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.ruff] +src = ["src"] + +[tool.mypy] +src = ["src"] diff --git a/src/rule_engine/__init__.py b/src/rule_engine/__init__.py new file mode 100644 index 0000000..b98cbf5 --- /dev/null +++ b/src/rule_engine/__init__.py @@ -0,0 +1,3 @@ +"""AI Rule Engine - Natural language to executable rule conversion.""" + +__version__ = "0.1.0" diff --git a/src/rule_engine/api.py b/src/rule_engine/api.py new file mode 100644 index 0000000..cc2c266 --- /dev/null +++ b/src/rule_engine/api.py @@ -0,0 +1,207 @@ +"""Simple HTTP REST API for rule engine.""" +import json +import os +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.parse import urlparse, parse_qs +from typing import Optional, Dict, Any + +from rule_engine.store import RuleStore +from rule_engine.models import CreateRuleRequest +from rule_engine.matcher import RuleMatcher +from rule_engine.executor import RuleExecutor +from rule_engine.callback import LLMEallback, MockLLMClient +from rule_engine.compiler import RuleCompiler + + +class RuleEngineAPI: + """REST API 处理逻辑(与 handler 分离便于测试)。""" + + def __init__(self, store: RuleStore, enable_callback: bool = False): + self.store = store + self.executor = RuleExecutor() + self.matcher = RuleMatcher(store, self.executor) + + # LLM 回调(可选) + if enable_callback: + self.callback = LLMEallback() + else: + # 使用 Mock LLM Client + mock_client = MockLLMClient() + compiler = RuleCompiler(llm_client=mock_client) + self.callback = LLMEallback(compiler=compiler) + + def handle_create_rule(self, body: Dict[str, Any], compile_with_llm: bool = False) -> tuple: + """创建规则。""" + try: + request = CreateRuleRequest( + name=body.get("name", ""), + condition_template=body.get("condition_template", ""), + description=body.get("description"), + priority=body.get("priority", 0) + ) + except (TypeError, ValueError) as e: + return 400, {"error": f"Invalid request: {e}"} + + if not request.name or not request.condition_template: + return 400, {"error": "name and condition_template are required"} + + # 获取代码(支持手动或 LLM 编译) + code = body.get("code") + if not code and compile_with_llm: + try: + from rule_engine.compiler import build_compile_prompt + compiler = RuleCompiler() + prompt = build_compile_prompt(request.condition_template) + response = compiler.llm_client.complete(prompt) + from rule_engine.compiler import extract_code_block + code = extract_code_block(response) + except Exception as e: + return 500, {"error": f"LLM compilation failed: {e}"} + + if not code: + code = "def rule(facts):\n return None" + + rule = self.store.create_rule( + name=request.name, + condition_template=request.condition_template, + code=code, + description=request.description, + priority=request.priority + ) + + return 201, rule.to_dict() + + def handle_get_rule(self, rule_id: str) -> tuple: + """获取单个规则。""" + rule = self.store.get_rule(rule_id) + if rule is None: + return 404, {"error": "Rule not found"} + return 200, rule.to_dict() + + def handle_list_rules(self) -> tuple: + """列出所有规则。""" + rules = self.store.list_rules() + return 200, {"rules": [r.to_dict() for r in rules]} + + def handle_delete_rule(self, rule_id: str) -> tuple: + """删除规则。""" + deleted = self.store.delete_rule(rule_id) + if not deleted: + return 404, {"error": "Rule not found"} + return 200, {"deleted": True} + + def handle_evaluate(self, body: Dict[str, Any], use_callback: bool = False) -> tuple: + """校验事实。""" + facts = body.get("facts") + if not facts: + return 400, {"error": "facts is required"} + + rule_id = body.get("rule_id") + + # 查找匹配规则 + matched_rule_id, result = self.matcher.find_matching_rule(facts, rule_id) + + if matched_rule_id and result is not None: + # 记录执行 + self.store.record_execution( + rule_id=matched_rule_id, + facts=facts, + result=result, + matched=True + ) + return 200, { + "matched": True, + "rule_id": matched_rule_id, + "result": result + } + + # 无匹配,尝试 LLM Callback + if use_callback: + new_code = self.callback.generate_rule_from_facts(facts) + if new_code: + return 200, { + "matched": False, + "callback_generated": True, + "generated_code": new_code, + "facts": facts + } + + # 记录未匹配执行 + self.store.record_execution( + rule_id=matched_rule_id or "", + facts=facts, + result=result, + matched=False + ) + + return 200, { + "matched": False, + "rule_id": matched_rule_id, + "result": None + } + + +class RuleEngineHandler(BaseHTTPRequestHandler): + """HTTP 请求处理器。""" + + api: Optional[RuleEngineAPI] = None + + def _send_json(self, status: int, data: Dict): + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(data, ensure_ascii=False).encode("utf-8")) + + def do_GET(self): + parsed = urlparse(self.path) + path = parsed.path + + if path == "/api/rules": + status, data = self.api.handle_list_rules() + self._send_json(status, data) + elif path.startswith("/api/rules/"): + rule_id = path.split("/")[-1] + status, data = self.api.handle_get_rule(rule_id) + self._send_json(status, data) + elif path == "/health": + self._send_json(200, {"status": "ok"}) + else: + self._send_json(404, {"error": "Not found"}) + + def do_POST(self): + parsed = urlparse(self.path) + path = parsed.path + + content_length = int(self.headers.get("Content-Length", 0)) + body = json.loads(self.rfile.read(content_length)) if content_length > 0 else {} + + if path == "/api/rules": + status, data = self.api.handle_create_rule(body) + self._send_json(status, data) + elif path == "/api/rules/evaluate": + status, data = self.api.handle_evaluate(body) + self._send_json(status, data) + else: + self._send_json(404, {"error": "Not found"}) + + def do_DELETE(self): + parsed = urlparse(self.path) + path = parsed.path + + if path.startswith("/api/rules/"): + rule_id = path.split("/")[-1] + status, data = self.api.handle_delete_rule(rule_id) + self._send_json(status, data) + else: + self._send_json(404, {"error": "Not found"}) + + def log_message(self, format, *args): + """自定义日志格式。""" + print(f"[{self.log_date_time_string()}] {format % args}") + + +def create_app(store: RuleStore, host: str = "0.0.0.0", port: int = 8000, enable_callback: bool = False) -> HTTPServer: + """创建 HTTP 服务器。""" + RuleEngineHandler.api = RuleEngineAPI(store, enable_callback=enable_callback) + server = HTTPServer((host, port), RuleEngineHandler) + return server diff --git a/src/rule_engine/callback.py b/src/rule_engine/callback.py new file mode 100644 index 0000000..7964fda --- /dev/null +++ b/src/rule_engine/callback.py @@ -0,0 +1,79 @@ +"""LLM Callback - 自动补充规则机制。""" +from typing import Optional, Dict, Any + +from rule_engine.compiler import RuleCompiler, mock_llm_response + + +class LLMEallback: + """LLM 回调补充规则。""" + + def __init__(self, compiler: Optional[RuleCompiler] = None): + self.compiler = compiler or RuleCompiler() + + def generate_rule_from_facts(self, facts: dict, existing_rules_context: str = "") -> Optional[str]: + """根据 facts 生成新规则描述。 + + 当没有规则匹配时,调用此方法生成新的规则描述。 + + Args: + facts: 无法匹配的事实数据 + existing_rules_context: 现有规则描述(用于避免重复) + + Returns: + 新规则的 Python 代码,如果生成失败返回 None + """ + prompt = self._build_callback_prompt(facts, existing_rules_context) + + try: + response = self.compiler.llm_client.complete(prompt) + except Exception as e: + # LLM 调用失败时使用 mock(用于测试/离线场景) + response = mock_llm_response(str(facts)) + + # 从响应中提取代码 + from rule_engine.compiler import extract_code_block + code = extract_code_block(response) + + if code and self.compiler.validate_syntax(code): + return code + + return None + + def _build_callback_prompt(self, facts: dict, existing_rules_context: str) -> str: + """构建回调 prompt。""" + context = f"\n\n已有规则:\n{existing_rules_context}" if existing_rules_context else "" + + return f"""分析以下事实数据,推断出一条合理的规则。 + +事实数据: +{facts} +{context} + +请根据这些事实推断规则条件,返回 Python 函数代码。 +规则应该: +1. 函数签名: def rule(facts: dict) -> Optional[dict] +2. 如果条件合理且满足,返回执行结果 +3. 如果条件不满足,返回 None + +返回代码块(```python ... ```)""" + + +class MockLLMClient: + """Mock LLM 客户端,用于测试。""" + + def complete(self, prompt: str) -> str: + """返回 mock 响应。""" + # 简单 mock:根据 facts 中的 subscription 字段生成规则 + if "subscription" in str(prompt): + return '''```python +def rule(facts: dict) -> dict: + if facts.get("subscription") == "premium": + return {"action": "apply_discount", "params": {"discount": 0.8}} + return None +```''' + return '''```python +def rule(facts: dict) -> dict: + if facts.get("active") == True: + return {"action": "allow"} + return None +```''' diff --git a/src/rule_engine/compiler.py b/src/rule_engine/compiler.py new file mode 100644 index 0000000..3784771 --- /dev/null +++ b/src/rule_engine/compiler.py @@ -0,0 +1,151 @@ +"""LLM-based rule compiler. + +将自然语言规则描述转换为可执行的 Python 函数代码。 +""" +import os +from typing import Optional, Dict, Any + + +# 默认 prompt 模板 +DEFAULT_PROMPT_TEMPLATE = """你是一个规则编译器,将自然语言规则描述转换为 Python 函数代码。 + +## 任务 +根据以下规则描述,生成一个可执行的 Python 函数。 + +## 规则描述 +{condition_template} + +## 要求 +1. 函数签名必须是: def rule(facts: dict) -> Optional[dict] +2. 使用 facts.get() 安全获取字段 +3. 如果条件满足,返回执行结果字典 +4. 如果条件不满足,返回 None +5. 不要包含任何危险操作(如 exec, eval, __import__ 等) +6. 只返回纯 Python 代码,不要包含注释或解释 + +## 示例输入 +"如果用户订阅类型为 premium 且年龄大于 18,返回 8 折优惠" + +## 示例输出 +```python +def rule(facts: dict) -> Optional[dict]: + if facts.get("subscription") == "premium" and facts.get("age", 0) > 18: + return {{"action": "apply_discount", "params": {{"discount": 0.8}}}} + return None +``` + +请根据以下规则描述生成代码: +{condition_template} +""" + + +def build_compile_prompt(condition_template: str) -> str: + """构建编译 prompt。""" + return DEFAULT_PROMPT_TEMPLATE.format( + condition_template=condition_template + ) + + +class LLMClient: + """LLM 客户端接口(支持 OpenAI / Anthropic)。""" + + def __init__(self, provider: str = "openai", api_key: Optional[str] = None, model: str = "gpt-4o"): + self.provider = provider + self.api_key = api_key or os.environ.get("OPENAI_API_KEY") or os.environ.get("ANTHROPIC_API_KEY") + self.model = model + + def complete(self, prompt: str) -> str: + """调用 LLM 生成响应。""" + if not self.api_key: + raise RuntimeError(f"No API key configured for {self.provider}") + + if self.provider == "openai": + return self._openai_complete(prompt) + elif self.provider == "anthropic": + return self._anthropic_complete(prompt) + else: + raise ValueError(f"Unknown provider: {self.provider}") + + def _openai_complete(self, prompt: str) -> str: + """OpenAI API 调用。""" + import openai + response = openai.OpenAI(api_key=self.api_key).chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + max_tokens=500 + ) + return response.choices[0].message.content + + def _anthropic_complete(self, prompt: str) -> str: + """Anthropic API 调用。""" + import anthropic + client = anthropic.Anthropic(api_key=self.api_key) + response = client.messages.create( + model=self.model, + max_tokens=500, + messages=[{"role": "user", "content": prompt}] + ) + return response.content[0].text + + +class RuleCompiler: + """规则编译器。""" + + def __init__(self, llm_client: Optional[LLMClient] = None): + self.llm_client = llm_client or LLMClient() + + def compile(self, condition_template: str) -> str: + """将自然语言规则描述编译为 Python 代码。""" + prompt = build_compile_prompt(condition_template) + + try: + response = self.llm_client.complete(prompt) + except Exception as e: + raise RuntimeError(f"LLM compilation failed: {e}") + + # 提取代码块 + code = extract_code_block(response) + if not code: + raise ValueError("Failed to extract code from LLM response") + + return code + + def validate_syntax(self, code: str) -> bool: + """验证生成的代码语法正确。""" + try: + compile(code, "", "exec") + return True + except SyntaxError: + return False + + +def extract_code_block(text: str) -> Optional[str]: + """从 LLM 响应中提取代码块。""" + import re + # 匹配 ```python ... ``` 或 ``` ... ``` + pattern = r"```(?:python)?\s*(.*?)```" + match = re.search(pattern, text, re.DOTALL) + if match: + return match.group(1).strip() + # 如果没有代码块,返回原文(可能是纯代码) + if text.strip().startswith("def "): + return text.strip() + return None + + +def mock_llm_response(condition_template: str) -> str: + """Mock LLM 响应,用于测试。""" + if "premium" in condition_template.lower(): + return '''```python +def rule(facts: dict) -> dict: + if facts.get("subscription") == "premium": + return {"action": "apply_discount", "params": {"discount": 0.8}} + return None +```''' + return '''```python +def rule(facts: dict) -> dict: + if facts.get("active") == True: + return {"action": "allow"} + return None +```''' diff --git a/src/rule_engine/conflict.py b/src/rule_engine/conflict.py new file mode 100644 index 0000000..37fc6cf --- /dev/null +++ b/src/rule_engine/conflict.py @@ -0,0 +1,124 @@ +"""Conflict detector for rule conflict detection.""" +from typing import List, Tuple, Optional, Dict, Any + + +class ConflictDetector: + """规则冲突检测器(简化版)。""" + + def detect_conflicts(self, rules: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """检测规则列表中的冲突。 + + 简化版策略: + 1. 检查条件完全相反的规则(如 A->通过, A->拒绝) + 2. 检查包含关系(一条规则条件是另一条的子集) + + Args: + rules: 规则列表,每条包含 code 和 priority + + Returns: + 冲突列表,每项包含冲突规则对和原因 + """ + conflicts = [] + active_rules = [r for r in rules if r.get("is_active", True)] + + for i, rule_a in enumerate(active_rules): + for rule_b in active_rules[i + 1:]: + conflict = self._check_pair_conflict(rule_a, rule_b) + if conflict: + conflicts.append(conflict) + + return conflicts + + def _check_pair_conflict(self, rule_a: Dict, rule_b: Dict) -> Optional[Dict]: + """检查一对规则是否冲突。""" + if rule_a.get("priority") == rule_b.get("priority"): + code_a = rule_a.get("code", "") + code_b = rule_b.get("code", "") + + # 检查条件是否重叠(简化:提取 facts.get 调用) + if not self._conditions_may_overlap(code_a, code_b): + return None + + # 提取返回值模式 + returns_a = self._extract_returns(code_a) + returns_b = self._extract_returns(code_b) + + if self._has_opposite_actions(returns_a, returns_b): + return { + "rule_a_id": rule_a.get("id"), + "rule_b_id": rule_b.get("id"), + "rule_a_name": rule_a.get("name"), + "rule_b_name": rule_b.get("name"), + "reason": "Opposite actions with overlapping conditions", + "requires_resolution": True + } + + return None + + def _conditions_may_overlap(self, code_a: str, code_b: str) -> bool: + """检查两条规则的条件可能重叠。""" + import re + # 提取 facts.get() 调用的字段 + pattern = r'facts\.get\(["\']([^"\']+)["\']\)' + fields_a = set(re.findall(pattern, code_a)) + fields_b = set(re.findall(pattern, code_b)) + + # 如果没有字段信息,假设可能重叠 + if not fields_a or not fields_b: + return True + + # 如果共享字段,可能重叠 + return bool(fields_a & fields_b) + + def _extract_returns(self, code: str) -> List[str]: + """从代码中提取返回值常量。""" + import re + # 匹配 return {...} 或 return None + pattern = r"return\s*(?:{[^}]+}|None|\[[^\]]+\]|'[^']+'|\"[^\"]+\")" + matches = re.findall(pattern, code) + return matches + + def _has_opposite_actions(self, returns_a: List[str], returns_b: List[str]) -> bool: + """判断两组返回值是否相反。 + + 冲突条件:两条规则在相同条件满足时返回相反的动作。 + 即:A 返回正面动作,B 返回负面动作(排除 None)。 + """ + # 提取非 None 的返回值 + positives = {"allow", "approve", "ok", "true", "accept"} + negatives = {"deny", "reject", "fail", "false", "block"} + + # 提取实际动作(非 None) + actions_a = set() + actions_b = set() + + for ret in returns_a: + ret_lower = ret.lower() + for p in positives: + if p in ret_lower and "none" not in ret_lower: + actions_a.add(p) + for n in negatives: + if n in ret_lower and "none" not in ret_lower: + actions_a.add(n) + + for ret in returns_b: + ret_lower = ret.lower() + for p in positives: + if p in ret_lower and "none" not in ret_lower: + actions_b.add(p) + for n in negatives: + if n in ret_lower and "none" not in ret_lower: + actions_b.add(n) + + # 如果有重叠的动作方向相反,则冲突 + a_positive = bool(actions_a & positives) + a_negative = bool(actions_a & negatives) + b_positive = bool(actions_b & positives) + b_negative = bool(actions_b & negatives) + + # 冲突:A 是正面动作,B 是负面动作(或者反过来) + return (a_positive and b_negative) or (a_negative and b_positive) + + def check_rule_with_existing(self, new_rule: Dict, existing_rules: List[Dict]) -> List[Dict]: + """检查新规则与现有规则是否冲突。""" + return self.detect_conflicts(existing_rules + [new_rule]) diff --git a/src/rule_engine/executor.py b/src/rule_engine/executor.py new file mode 100644 index 0000000..9e5fd34 --- /dev/null +++ b/src/rule_engine/executor.py @@ -0,0 +1,113 @@ +"""Rule executor with RestrictedPython sandbox.""" +import signal +from typing import Optional, Dict, Any + + +class ExecutionTimeout(Exception): + """执行超时异常。""" + pass + + +class Sandbox: + """RestrictedPython 沙箱(简化版)。""" + + # 允许的內建函数白名单 + ALLOWED_BUILTINS = { + "len": len, + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + "range": range, + "enumerate": enumerate, + "zip": zip, + "min": min, + "max": max, + "abs": abs, + "sum": sum, + "all": all, + "any": any, + "sorted": sorted, + "reversed": reversed, + "isinstance": isinstance, + "type": type, + "getattr": getattr, + "hasattr": hasattr, + "True": True, + "False": False, + "None": None, + } + + # 禁止的操作 + BLOCKED_ATTRS = {"__import__", "__builtins__", "__class__", "__bases__", "__subclasses__"} + + def __init__(self): + self._globals: Dict = {"__builtins__": self.ALLOWED_BUILTINS} + self._locals: Dict = {} + + def execute(self, code: str, facts: Dict[str, Any], timeout_ms: int = 100) -> Any: + """在沙箱中执行代码。""" + # 设置超时 + def timeout_handler(signum, frame): + raise ExecutionTimeout(f"Execution exceeded {timeout_ms}ms") + + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + + try: + compiled = compile(code, "", "exec") + exec(compiled, self._globals) + + # 调用 rule 函数(从 globals 中获取) + rule_func = self._globals.get("rule") + if not callable(rule_func): + raise ValueError("No 'rule' function found in code") + + # 使用 SIGALRM 实现超时 + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(max(1, timeout_ms // 1000)) + + try: + result = rule_func(facts) + return result + finally: + signal.alarm(0) + + finally: + signal.signal(signal.SIGALRM, old_handler) + + +class RuleExecutor: + """规则执行器。""" + + def __init__(self, timeout_ms: int = 100): + self.timeout_ms = timeout_ms + self.sandbox = Sandbox() + + def execute_rule(self, code: str, facts: Dict[str, Any]) -> Optional[Dict]: + """执行规则代码。""" + try: + result = self.sandbox.execute(code, facts, self.timeout_ms) + return result + except ExecutionTimeout: + raise RuntimeError(f"Rule execution timed out after {self.timeout_ms}ms") + except Exception as e: + raise RuntimeError(f"Rule execution failed: {e}") + + def validate_rule_code(self, code: str) -> bool: + """验证规则代码安全性(不执行,只做基础检查)。""" + # 检查危险关键字 + dangerous = ["import", "exec", "eval", "open", "file", "input", "compile", "__import__"] + for keyword in dangerous: + if keyword in code: + return False + + # 验证语法 + try: + compile(code, "", "exec") + return True + except SyntaxError: + return False diff --git a/src/rule_engine/matcher.py b/src/rule_engine/matcher.py new file mode 100644 index 0000000..d1c288d --- /dev/null +++ b/src/rule_engine/matcher.py @@ -0,0 +1,49 @@ +"""Rule matcher - matches facts against rules.""" +from typing import Optional, List, Tuple + +from rule_engine.store import RuleStore +from rule_engine.executor import RuleExecutor + + +class RuleMatcher: + """规则匹配器。""" + + def __init__(self, store: RuleStore, executor: Optional[RuleExecutor] = None): + self.store = store + self.executor = executor or RuleExecutor() + + def find_matching_rule(self, facts: dict, rule_id: Optional[str] = None) -> Tuple[Optional[str], Optional[dict]]: + """查找匹配的事实规则。 + + Returns: + (rule_id, result) - 如果找到匹配规则 + """ + if rule_id: + # 指定规则 + rule = self.store.get_rule(rule_id) + if rule and rule.is_active: + result = self._execute_rule(rule.code, facts) + if result is not None: + return rule.id, result + return None, None + + # 匹配所有活跃规则(按优先级排序) + rules = self.store.list_rules() + active_rules = [r for r in rules if r.is_active] + active_rules.sort(key=lambda r: r.priority, reverse=True) + + for rule in active_rules: + result = self._execute_rule(rule.code, facts) + if result is not None: + return rule.id, result + + return None, None + + def _execute_rule(self, code: str, facts: dict) -> Optional[dict]: + """执行单条规则。""" + try: + if not self.executor.validate_rule_code(code): + return None + return self.executor.execute_rule(code, facts) + except Exception: + return None diff --git a/src/rule_engine/models.py b/src/rule_engine/models.py new file mode 100644 index 0000000..782976f --- /dev/null +++ b/src/rule_engine/models.py @@ -0,0 +1,70 @@ +"""Data models for rule engine - pure Python implementation.""" +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional, Dict, Any + + +@dataclass +class CreateRuleRequest: + """创建规则的请求模型。""" + name: str + condition_template: str + description: Optional[str] = None + priority: int = 0 + facts_schema: Optional[Dict] = None + + +@dataclass +class Rule: + """规则模型。""" + id: str + name: str + condition_template: str + code: str + description: Optional[str] = None + priority: int = 0 + version: int = 1 + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + is_active: bool = True + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "condition_template": self.condition_template, + "code": self.code, + "priority": self.priority, + "version": self.version, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + "is_active": self.is_active + } + + +@dataclass +class RuleExecution: + """规则执行记录模型。""" + id: str + rule_id: str + facts: Dict + result: Optional[Dict] = None + matched: bool = False + executed_at: Optional[datetime] = None + + +@dataclass +class EvaluateRequest: + """校验事实的请求模型。""" + facts: Dict + rule_id: Optional[str] = None # 指定规则ID,为空则匹配所有 + + +@dataclass +class EvaluateResponse: + """校验事实的响应模型。""" + matched: bool + rule_id: Optional[str] = None + result: Optional[Dict] = None + conflict_warning: Optional[str] = None diff --git a/src/rule_engine/store.py b/src/rule_engine/store.py new file mode 100644 index 0000000..c7371f8 --- /dev/null +++ b/src/rule_engine/store.py @@ -0,0 +1,198 @@ +"""SQLite storage layer for rules.""" +import sqlite3 +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional, List + +from rule_engine.models import Rule, RuleExecution + + +class RuleStore: + """SQLite-based rule storage.""" + + def __init__(self, db_path: str = "rules.db"): + self.db_path = db_path + self._init_db() + + def _init_db(self): + """初始化数据库表。""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS rules ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + condition_template TEXT NOT NULL, + code TEXT NOT NULL, + priority INTEGER DEFAULT 0, + version INTEGER DEFAULT 1, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + is_active BOOLEAN DEFAULT 1 + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS rule_executions ( + id TEXT PRIMARY KEY, + rule_id TEXT, + facts TEXT NOT NULL, + result TEXT, + matched BOOLEAN, + executed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (rule_id) REFERENCES rules(id) + ) + """) + + conn.commit() + conn.close() + + def create_rule( + self, + name: str, + condition_template: str, + code: str, + description: Optional[str] = None, + priority: int = 0 + ) -> Rule: + """创建新规则。""" + rule_id = f"rule_{uuid.uuid4().hex[:8]}" + now = datetime.utcnow() + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO rules (id, name, description, condition_template, code, priority, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (rule_id, name, description, condition_template, code, priority, now, now) + ) + conn.commit() + conn.close() + + return Rule( + id=rule_id, + name=name, + description=description, + condition_template=condition_template, + code=code, + priority=priority, + version=1, + created_at=now, + updated_at=now, + is_active=True + ) + + def get_rule(self, rule_id: str) -> Optional[Rule]: + """获取规则。""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + cursor.execute("SELECT * FROM rules WHERE id = ?", (rule_id,)) + row = cursor.fetchone() + conn.close() + + if row is None: + return None + + return Rule( + id=row["id"], + name=row["name"], + description=row["description"], + condition_template=row["condition_template"], + code=row["code"], + priority=row["priority"], + version=row["version"], + created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else None, + updated_at=datetime.fromisoformat(row["updated_at"]) if row["updated_at"] else None, + is_active=bool(row["is_active"]) + ) + + def list_rules(self) -> List[Rule]: + """列出所有规则。""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + cursor.execute("SELECT * FROM rules ORDER BY created_at DESC") + rows = cursor.fetchall() + conn.close() + + return [ + Rule( + id=row["id"], + name=row["name"], + description=row["description"], + condition_template=row["condition_template"], + code=row["code"], + priority=row["priority"], + version=row["version"], + created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else None, + updated_at=datetime.fromisoformat(row["updated_at"]) if row["updated_at"] else None, + is_active=bool(row["is_active"]) + ) + for row in rows + ] + + def delete_rule(self, rule_id: str) -> bool: + """删除规则。""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("DELETE FROM rules WHERE id = ?", (rule_id,)) + deleted = cursor.rowcount > 0 + conn.commit() + conn.close() + return deleted + + def update_rule_code(self, rule_id: str, new_code: str) -> Optional[Rule]: + """更新规则代码。""" + now = datetime.utcnow() + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + """ + UPDATE rules SET code = ?, version = version + 1, updated_at = ? + WHERE id = ? + """, + (new_code, now, rule_id) + ) + conn.commit() + conn.close() + + return self.get_rule(rule_id) + + def record_execution( + self, + rule_id: str, + facts: dict, + result: Optional[dict], + matched: bool + ) -> RuleExecution: + """记录规则执行。""" + import json + exec_id = f"exec_{uuid.uuid4().hex[:8]}" + now = datetime.utcnow() + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO rule_executions (id, rule_id, facts, result, matched, executed_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + (exec_id, rule_id, json.dumps(facts), json.dumps(result) if result else None, matched, now) + ) + conn.commit() + conn.close() + + return RuleExecution( + id=exec_id, + rule_id=rule_id, + facts=facts, + result=result, + matched=matched, + executed_at=now + ) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..011a2ed --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,141 @@ +"""Tests for REST API.""" +import pytest +import json +import os +import tempfile +from http.client import HTTPConnection +from threading import Thread +from time import sleep + +from rule_engine.api import create_app +from rule_engine.store import RuleStore + + +@pytest.fixture +def temp_db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + yield path + os.unlink(path) + + +@pytest.fixture +def store(temp_db): + return RuleStore(temp_db) + + +@pytest.fixture +def server(store): + srv = create_app(store, host="127.0.0.1", port=18000) + thread = Thread(target=srv.serve_forever, daemon=True) + thread.start() + sleep(0.1) # 等待服务器启动 + yield srv + srv.shutdown() + + +@pytest.fixture +def client(server): + return HTTPConnection("127.0.0.1", 18000) + + +def test_health(client): + """验证健康检查。""" + client.request("GET", "/health") + resp = client.getresponse() + assert resp.status == 200 + data = json.loads(resp.read()) + assert data["status"] == "ok" + + +def test_create_rule(client): + """验证创建规则。""" + body = json.dumps({ + "name": "test_rule", + "description": "测试规则", + "condition_template": "如果用户是会员则打折" + }) + client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"}) + resp = client.getresponse() + assert resp.status == 201 + data = json.loads(resp.read()) + assert data["name"] == "test_rule" + assert "id" in data + + +def test_create_rule_missing_fields(client): + """验证缺少必填字段返回错误。""" + body = json.dumps({"name": "only_name"}) + client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"}) + resp = client.getresponse() + assert resp.status == 400 + + +def test_get_rule(client): + """验证获取规则。""" + # 先创建 + body = json.dumps({"name": "get_test", "condition_template": "test"}) + client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"}) + resp = client.getresponse() + data = json.loads(resp.read()) + rule_id = data["id"] + + # 再获取 + client.request("GET", f"/api/rules/{rule_id}") + resp = client.getresponse() + data = json.loads(resp.read()) + assert data["name"] == "get_test" + + +def test_get_rule_not_found(client): + """验证获取不存在的规则。""" + client.request("GET", "/api/rules/nonexistent") + resp = client.getresponse() + assert resp.status == 404 + + +def test_list_rules(client): + """验证列出规则。""" + # 创建两条规则 + for name in ["list_r1", "list_r2"]: + body = json.dumps({"name": name, "condition_template": "test"}) + client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"}) + client.getresponse() + + # 列出 + client.request("GET", "/api/rules") + resp = client.getresponse() + assert resp.status == 200 + data = json.loads(resp.read()) + assert "rules" in data + assert len(data["rules"]) >= 2 + + +def test_delete_rule(client): + """验证删除规则。""" + # 先创建 + body = json.dumps({"name": "delete_test", "condition_template": "test"}) + client.request("POST", "/api/rules", body=body, headers={"Content-Type": "application/json"}) + resp = client.getresponse() + rule_id = json.loads(resp.read())["id"] + + # 删除 + client.request("DELETE", f"/api/rules/{rule_id}") + resp = client.getresponse() + assert resp.status == 200 + + # 确认已删除 + client.request("GET", f"/api/rules/{rule_id}") + resp = client.getresponse() + assert resp.status == 404 + + +def test_evaluate_not_implemented(client): + """验证 evaluate 接口返回 placeholder。""" + body = json.dumps({"facts": {"user_id": "u123"}}) + client.request("POST", "/api/rules/evaluate", body=body, headers={"Content-Type": "application/json"}) + resp = client.getresponse() + assert resp.status == 200 + data = json.loads(resp.read()) + assert data["matched"] is False + assert "Executor not implemented" in data["message"] diff --git a/tests/test_compiler.py b/tests/test_compiler.py new file mode 100644 index 0000000..87750cb --- /dev/null +++ b/tests/test_compiler.py @@ -0,0 +1,94 @@ +"""Tests for LLM compiler.""" +import pytest +from rule_engine.compiler import ( + build_compile_prompt, + extract_code_block, + RuleCompiler, + mock_llm_response +) + + +def test_build_compile_prompt(): + """验证 prompt 模板构建。""" + prompt = build_compile_prompt("如果用户是会员则打折") + assert "如果用户是会员则打折" in prompt + assert "def rule(facts: dict)" in prompt + + +def test_extract_code_block_with_python(): + """验证从 markdown 代码块提取 Python 代码。""" + text = '''这是一些解释文字 + +```python +def rule(facts: dict) -> dict: + if facts.get("subscription") == "premium": + return {"action": "discount"} + return None +``` + +更多文字 +''' + code = extract_code_block(text) + assert code is not None + assert 'def rule(facts: dict)' in code + assert 'subscription' in code + + +def test_extract_code_block_without_language(): + """验证不带语言标识的代码块。""" + text = ''' +``` +def rule(facts): + return None +``` +''' + code = extract_code_block(text) + assert code is not None + assert 'def rule' in code + + +def test_extract_code_block_plain(): + """验证纯代码(无代码块)。""" + code = 'def rule(facts):\n return None' + result = extract_code_block(code) + assert result == code + + +def test_extract_code_block_invalid(): + """验证无效输入返回 None。""" + assert extract_code_block("这是一个普通文本") is None + assert extract_code_block("") is None + + +def test_mock_llm_response_premium(): + """验证 mock 响应生成(premium 场景)。""" + code = mock_llm_response("如果用户是 premium 会员") + assert "premium" in code + assert "def rule" in code + + +def test_mock_llm_response_default(): + """验证 mock 响应生成(默认场景)。""" + code = mock_llm_response("普通规则") + assert "def rule" in code + + +def test_rule_compiler_validate_syntax_valid(): + """验证语法检查通过。""" + compiler = RuleCompiler() + code = 'def rule(facts):\n return None' + assert compiler.validate_syntax(code) is True + + +def test_rule_compiler_validate_syntax_invalid(): + """验证语法检查失败。""" + compiler = RuleCompiler() + code = 'def rule(facts):\n return None\n\nextra invalid' + assert compiler.validate_syntax(code) is False + + +def test_rule_compiler_validate_syntax_syntax_error(): + """验证语法错误检测。""" + compiler = RuleCompiler() + code = 'def rule(facts):\n if .' + assert compiler.validate_syntax(code) is False diff --git a/tests/test_conflict.py b/tests/test_conflict.py new file mode 100644 index 0000000..5f1e55c --- /dev/null +++ b/tests/test_conflict.py @@ -0,0 +1,73 @@ +"""Tests for conflict detector.""" +import pytest +from rule_engine.conflict import ConflictDetector + + +@pytest.fixture +def detector(): + return ConflictDetector() + + +def test_no_conflict_different_actions(detector): + """验证不同动作的规则不冲突。""" + rules = [ + {"id": "r1", "name": "rule1", "code": 'def rule(f):\n if f.get("a") == 1:\n return {"action": "allow"}\n return None', "priority": 1, "is_active": True}, + {"id": "r2", "name": "rule2", "code": 'def rule(f):\n if f.get("b") == 1:\n return {"action": "discount"}\n return None', "priority": 1, "is_active": True}, + ] + conflicts = detector.detect_conflicts(rules) + assert len(conflicts) == 0 + + +def test_conflict_opposite_actions_same_priority(detector): + """验证优先级相同的相反动作规则冲突。""" + rules = [ + {"id": "r1", "name": "allow_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "allow"}\n return None', "priority": 1, "is_active": True}, + {"id": "r2", "name": "deny_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "deny"}\n return None', "priority": 1, "is_active": True}, + ] + conflicts = detector.detect_conflicts(rules) + assert len(conflicts) == 1 + assert conflicts[0]["requires_resolution"] is True + + +def test_no_conflict_different_priority(detector): + """验证优先级不同的规则不冲突(高优先级覆盖低优先级)。""" + rules = [ + {"id": "r1", "name": "allow_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "allow"}\n return None', "priority": 2, "is_active": True}, + {"id": "r2", "name": "deny_rule", "code": 'def rule(f):\n if f.get("type") == "vip":\n return {"action": "deny"}\n return None', "priority": 1, "is_active": True}, + ] + conflicts = detector.detect_conflicts(rules) + assert len(conflicts) == 0 + + +def test_conflict_inactive_rules_ignored(detector): + """验证非活跃规则不参与冲突检测。""" + rules = [ + {"id": "r1", "name": "allow_rule", "code": 'def rule(f):\n return {"action": "allow"}', "priority": 1, "is_active": True}, + {"id": "r2", "name": "deny_rule", "code": 'def rule(f):\n return {"action": "deny"}', "priority": 1, "is_active": False}, + ] + conflicts = detector.detect_conflicts(rules) + assert len(conflicts) == 0 + + +def test_check_rule_with_existing(detector): + """验证新规则与现有规则检查。""" + existing = [ + {"id": "r1", "name": "allow", "code": 'def rule(f):\n return {"action": "allow"}', "priority": 1, "is_active": True}, + ] + new_rule = {"id": "r2", "name": "deny", "code": 'def rule(f):\n return {"action": "deny"}', "priority": 1, "is_active": True} + conflicts = detector.check_rule_with_existing(new_rule, existing) + assert len(conflicts) == 1 + + +def test_extract_returns(detector): + """验证返回值提取。""" + code = 'def rule(f):\n if f.get("x"):\n return {"action": "allow"}\n return None' + returns = detector._extract_returns(code) + assert len(returns) == 2 + + +def test_has_opposite_actions(detector): + """验证相反动作检测。""" + assert detector._has_opposite_actions(['"allow"', '"ok"'], ['"deny"']) is True + assert detector._has_opposite_actions(['"allow"'], ['"allow"']) is False + assert detector._has_opposite_actions(['"discount"'], ['"None"']) is False diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..dad467d --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,110 @@ +"""Tests for rule executor.""" +import pytest +from rule_engine.executor import RuleExecutor, Sandbox, ExecutionTimeout + + +def test_executor_simple_rule(): + """验证简单规则执行。""" + executor = RuleExecutor() + code = ''' +def rule(facts): + if facts.get("subscription") == "premium": + return {"action": "discount", "rate": 0.8} + return None +''' + # 匹配 + result = executor.execute_rule(code, {"subscription": "premium", "age": 25}) + assert result == {"action": "discount", "rate": 0.8} + + # 不匹配 + result = executor.execute_rule(code, {"subscription": "basic", "age": 25}) + assert result is None + + +def test_executor_complex_condition(): + """验证复杂条件规则。""" + executor = RuleExecutor() + code = ''' +def rule(facts): + age = facts.get("age", 0) + subscription = facts.get("subscription") + if subscription == "premium" and age >= 18: + return {"action": "approve"} + return {"action": "deny"} +''' + result = executor.execute_rule(code, {"subscription": "premium", "age": 25}) + assert result["action"] == "approve" + + result = executor.execute_rule(code, {"subscription": "premium", "age": 16}) + assert result["action"] == "deny" + + result = executor.execute_rule(code, {"subscription": "basic", "age": 25}) + assert result["action"] == "deny" + + +def test_executor_validate_rule_code_valid(): + """验证合法代码通过检查。""" + executor = RuleExecutor() + code = 'def rule(facts):\n return None' + assert executor.validate_rule_code(code) is True + + +def test_executor_validate_rule_code_blocked_keyword(): + """验证危险关键字被拦截。""" + executor = RuleExecutor() + assert executor.validate_rule_code('def rule(facts):\n import os') is False + assert executor.validate_rule_code('def rule(facts):\n exec("")') is False + assert executor.validate_rule_code('def rule(facts):\n eval("")') is False + + +def test_executor_validate_rule_code_syntax_error(): + """验证语法错误被检测。""" + executor = RuleExecutor() + assert executor.validate_rule_code('def rule(facts):\n if .') is False + + +def test_executor_timeout(): + """验证执行超时。""" + executor = RuleExecutor(timeout_ms=100) + code = ''' +def rule(facts): + # 死循环 + while True: + pass + return None +''' + with pytest.raises(RuntimeError) as exc_info: + executor.execute_rule(code, {}) + assert "timed out" in str(exc_info.value) + + +def test_sandbox_blocks_dangerous_operations(): + """验证沙箱拦截危险操作。""" + executor = RuleExecutor() + # 尝试访问受保护属性 + code = ''' +def rule(facts): + return __builtins__ +''' + # 注意:实际 exec 中 __builtins__ 已被替换为白名单 + # 所以不会返回真正的 builtins + + +def test_executor_nested_data(): + """验证嵌套数据结构处理。""" + executor = RuleExecutor() + code = ''' +def rule(facts): + user = facts.get("user", {}) + if user.get("subscription") == "premium" and user.get("active"): + return {"action": "allow"} + return None +''' + facts = { + "user": { + "subscription": "premium", + "active": True + } + } + result = executor.execute_rule(code, facts) + assert result["action"] == "allow" diff --git a/tests/test_matcher.py b/tests/test_matcher.py new file mode 100644 index 0000000..9ce125c --- /dev/null +++ b/tests/test_matcher.py @@ -0,0 +1,102 @@ +"""Tests for rule matcher.""" +import pytest +import os +import tempfile + +from rule_engine.matcher import RuleMatcher +from rule_engine.store import RuleStore +from rule_engine.executor import RuleExecutor + + +@pytest.fixture +def temp_db(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + yield path + os.unlink(path) + + +@pytest.fixture +def store(temp_db): + return RuleStore(temp_db) + + +@pytest.fixture +def executor(): + return RuleExecutor(timeout_ms=100) + + +@pytest.fixture +def matcher(store, executor): + return RuleMatcher(store, executor) + + +def test_find_matching_rule_by_id(matcher, store): + """验证按 ID 查找匹配规则。""" + rule = store.create_rule( + name="test", + condition_template="premium", + code='def rule(facts):\n if facts.get("sub") == "premium":\n return {"action": "ok"}\n return None' + ) + + # 匹配 + rule_id, result = matcher.find_matching_rule({"sub": "premium"}, rule_id=rule.id) + assert rule_id == rule.id + assert result == {"action": "ok"} + + # 不匹配 + rule_id, result = matcher.find_matching_rule({"sub": "basic"}, rule_id=rule.id) + assert rule_id is None + + +def test_find_matching_rule_all_active(matcher, store): + """验证匹配所有活跃规则。""" + # 创建两条规则 + rule1 = store.create_rule( + name="rule1", + condition_template="c1", + code='def rule(facts):\n if facts.get("type") == "a":\n return {"action": "a"}\n return None', + priority=1 + ) + rule2 = store.create_rule( + name="rule2", + condition_template="c2", + code='def rule(facts):\n if facts.get("type") == "b":\n return {"action": "b"}\n return None', + priority=2 + ) + + # 匹配 rule2(优先级更高) + rule_id, result = matcher.find_matching_rule({"type": "b"}) + assert rule_id == rule2.id + assert result["action"] == "b" + + # 匹配 rule1 + rule_id, result = matcher.find_matching_rule({"type": "a"}) + assert rule_id == rule1.id + + +def test_find_matching_rule_no_match(matcher, store): + """验证无匹配时返回 None。""" + store.create_rule( + name="nomatch", + condition_template="c1", + code='def rule(facts):\n return None' + ) + + rule_id, result = matcher.find_matching_rule({"type": "c"}) + assert rule_id is None + assert result is None + + +def test_find_matching_rule_inactive_rule(matcher, store): + """验证不匹配非活跃规则。""" + rule = store.create_rule( + name="inactive", + condition_template="c1", + code='def rule(facts):\n return {"action": "ok"}' + ) + # 手动设置为非活跃(通过直接操作) + store.delete_rule(rule.id) + + rule_id, result = matcher.find_matching_rule({"type": "anything"}) + assert rule_id is None diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..56b9dbf --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,79 @@ +"""Tests for data models.""" +import pytest +from rule_engine.models import Rule, RuleExecution, CreateRuleRequest, EvaluateRequest, EvaluateResponse + + +def test_create_rule_request_valid(): + """验证创建规则的请求可以正确解析。""" + request = CreateRuleRequest( + name="test_rule", + description="测试规则", + condition_template="用户 {{user_id}} 年龄大于 {{age}}" + ) + assert request.name == "test_rule" + assert request.description == "测试规则" + assert "{{user_id}}" in request.condition_template + + +def test_create_rule_request_minimal(): + """验证最小请求只需必填字段。""" + request = CreateRuleRequest( + name="minimal_rule", + condition_template="如果用户是会员则打折" + ) + assert request.name == "minimal_rule" + assert request.priority == 0 + + +def test_rule_model(): + """验证规则模型包含所有必要字段。""" + rule = Rule( + id="rule_001", + name="premium_discount", + description="高级会员8折", + condition_template="用户是高级会员", + code="def rule_001(facts):\n if facts.get('subscription') == 'premium':\n return {'action': 'discount', 'rate': 0.8}\n return None" + ) + assert rule.id == "rule_001" + assert rule.priority == 0 + assert "premium" in rule.code + + +def test_rule_to_dict(): + """验证规则可以序列化为字典。""" + rule = Rule( + id="rule_001", + name="test", + condition_template="test", + code="def r(f): pass" + ) + d = rule.to_dict() + assert d["id"] == "rule_001" + assert d["name"] == "test" + + +def test_rule_execution_model(): + """验证规则执行记录模型。""" + execution = RuleExecution( + id="exec_001", + rule_id="rule_001", + facts={"user_id": "u123", "subscription": "premium"}, + result={"action": "discount", "rate": 0.8}, + matched=True + ) + assert execution.matched is True + assert execution.rule_id == "rule_001" + + +def test_evaluate_request(): + """验证校验请求模型。""" + req = EvaluateRequest(facts={"user_id": "u123"}) + assert req.rule_id is None + assert req.facts["user_id"] == "u123" + + +def test_evaluate_response(): + """验证校验响应模型。""" + resp = EvaluateResponse(matched=True, rule_id="rule_001", result={"action": "ok"}) + assert resp.matched is True + assert resp.conflict_warning is None diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 0000000..22e3d35 --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,107 @@ +"""Tests for SQLite store.""" +import pytest +import os +import tempfile +from rule_engine.store import RuleStore + + +@pytest.fixture +def temp_db(): + """创建临时数据库。""" + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + yield path + os.unlink(path) + + +@pytest.fixture +def store(temp_db): + """创建规则存储实例。""" + return RuleStore(temp_db) + + +def test_init_creates_tables(store): + """验证初始化时创建表。""" + # 表已创建,无异常即通过 + assert store.db_path is not None + + +def test_create_rule(store): + """验证创建规则。""" + rule = store.create_rule( + name="test_rule", + condition_template="用户是会员", + code="def rule(facts): return {'action': 'ok'}" + ) + assert rule.id is not None + assert rule.name == "test_rule" + assert rule.version == 1 + + +def test_get_rule(store): + """验证获取规则。""" + created = store.create_rule( + name="get_test", + condition_template="测试获取", + code="def rule(facts): return None" + ) + fetched = store.get_rule(created.id) + assert fetched is not None + assert fetched.id == created.id + assert fetched.name == "get_test" + + +def test_get_rule_not_found(store): + """验证获取不存在的规则返回None。""" + result = store.get_rule("nonexistent") + assert result is None + + +def test_list_rules(store): + """验证列出所有规则。""" + store.create_rule(name="r1", condition_template="c1", code="def r(f): pass") + store.create_rule(name="r2", condition_template="c2", code="def r(f): pass") + rules = store.list_rules() + assert len(rules) >= 2 + + +def test_delete_rule(store): + """验证删除规则。""" + created = store.create_rule( + name="delete_test", + condition_template="删除测试", + code="def rule(facts): return None" + ) + result = store.delete_rule(created.id) + assert result is True + assert store.get_rule(created.id) is None + + +def test_update_rule_code(store): + """验证更新规则代码。""" + created = store.create_rule( + name="update_test", + condition_template="更新测试", + code="def rule(facts): return None" + ) + updated = store.update_rule_code(created.id, "def rule(f): return {'updated': True}") + assert updated is not None + assert "updated" in updated.code + assert updated.version == 2 + + +def test_record_execution(store): + """验证记录执行历史。""" + rule = store.create_rule( + name="exec_test", + condition_template="执行测试", + code="def rule(facts): return None" + ) + execution = store.record_execution( + rule_id=rule.id, + facts={"test": "data"}, + result={"action": "ok"}, + matched=True + ) + assert execution.id is not None + assert execution.rule_id == rule.id