diff --git a/src/rule_engine/api.py b/src/rule_engine/api.py index cc2c266..1235da3 100644 --- a/src/rule_engine/api.py +++ b/src/rule_engine/api.py @@ -11,6 +11,7 @@ from rule_engine.matcher import RuleMatcher from rule_engine.executor import RuleExecutor from rule_engine.callback import LLMEallback, MockLLMClient from rule_engine.compiler import RuleCompiler +from rule_engine.conflict import ConflictDetector class RuleEngineAPI: @@ -20,6 +21,7 @@ class RuleEngineAPI: self.store = store self.executor = RuleExecutor() self.matcher = RuleMatcher(store, self.executor) + self.conflict_detector = ConflictDetector() # LLM 回调(可选) if enable_callback: @@ -61,6 +63,17 @@ class RuleEngineAPI: if not code: code = "def rule(facts):\n return None" + # 检查与现有规则的冲突 + existing_rules = [r.to_dict() for r in self.store.list_rules()] + new_rule_dict = { + "id": "new_rule", + "name": request.name, + "code": code, + "priority": request.priority, + "is_active": True + } + conflicts = self.conflict_detector.check_rule_with_existing(new_rule_dict, existing_rules) + rule = self.store.create_rule( name=request.name, condition_template=request.condition_template, @@ -69,7 +82,11 @@ class RuleEngineAPI: priority=request.priority ) - return 201, rule.to_dict() + result = rule.to_dict() + if conflicts: + result["conflict_warning"] = conflicts[0] if len(conflicts) == 1 else conflicts + + return 201, result def handle_get_rule(self, rule_id: str) -> tuple: """获取单个规则。"""