NAC_Blockchain/nac-ai-compliance/src/rule_engine.rs

448 lines
12 KiB
Rust

//! 规则引擎模块
//!
//! 实现规则定义DSL、执行引擎、更新机制和冲突检测
use crate::compliance_layer::*;
use crate::ai_validator::ComplianceData;
use crate::error::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Rule engine: stores the registered rules together with the executor
/// that evaluates their conditions and applies their actions.
pub struct RuleEngine {
    /// Registered rules, stored in a map from rule id to rule.
    rules: HashMap<String, Rule>,
    /// Executor that runs a single rule against a compliance result.
    executor: RuleExecutor,
}
impl RuleEngine {
/// 创建新的规则引擎
pub fn new() -> Self {
Self {
rules: HashMap::new(),
executor: RuleExecutor::new(),
}
}
/// 添加规则
pub fn add_rule(&mut self, rule: Rule) -> Result<()> {
// 检查规则冲突
self.check_conflicts(&rule)?;
self.rules.insert(rule.id.clone(), rule);
Ok(())
}
/// 移除规则
pub fn remove_rule(&mut self, rule_id: &str) -> Option<Rule> {
self.rules.remove(rule_id)
}
/// 更新规则
pub fn update_rule(&mut self, rule: Rule) -> Result<()> {
if !self.rules.contains_key(&rule.id) {
return Err(Error::RuleError(format!("规则不存在: {}", rule.id)));
}
// 检查规则冲突
self.check_conflicts(&rule)?;
self.rules.insert(rule.id.clone(), rule);
Ok(())
}
/// 获取规则
pub fn get_rule(&self, rule_id: &str) -> Option<&Rule> {
self.rules.get(rule_id)
}
/// 获取所有规则
pub fn get_all_rules(&self) -> Vec<&Rule> {
self.rules.values().collect()
}
/// 应用规则
pub fn apply(&self, result: &mut ComplianceResult, data: &ComplianceData) -> Result<()> {
// 获取适用于当前层级的规则
let applicable_rules: Vec<&Rule> = self.rules.values()
.filter(|r| r.layer == result.layer && r.enabled)
.collect();
// 执行规则
for rule in applicable_rules {
self.executor.execute(rule, result, data)?;
}
Ok(())
}
/// 检查规则冲突
fn check_conflicts(&self, new_rule: &Rule) -> Result<()> {
for existing_rule in self.rules.values() {
if existing_rule.id == new_rule.id {
continue;
}
// 检查是否有冲突
if existing_rule.layer == new_rule.layer &&
existing_rule.priority == new_rule.priority &&
existing_rule.condition.conflicts_with(&new_rule.condition) {
return Err(Error::RuleError(format!(
"规则冲突: {}{}",
existing_rule.id,
new_rule.id
)));
}
}
Ok(())
}
}
impl Default for RuleEngine {
fn default() -> Self {
Self::new()
}
}
/// A single compliance rule: a condition plus the action to run when it holds.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Rule {
    /// Rule id.
    pub id: String,
    /// Rule name.
    pub name: String,
    /// Rule description.
    pub description: String,
    /// Compliance layer the rule applies to.
    pub layer: ComplianceLayer,
    /// Condition of the rule.
    pub condition: RuleCondition,
    /// Action of the rule.
    pub action: RuleAction,
    /// Priority (a larger number means a higher priority).
    pub priority: i32,
    /// Whether the rule is enabled.
    pub enabled: bool,
    /// Rule version.
    pub version: u32,
}
impl Rule {
    /// Builds a rule with the given id, name and layer.
    ///
    /// The remaining fields start at their defaults: empty description,
    /// `RuleCondition::Always`, `RuleAction::Pass`, priority `0`, enabled,
    /// version `1`.
    pub fn new(id: String, name: String, layer: ComplianceLayer) -> Self {
        Self {
            id,
            name,
            layer,
            description: String::new(),
            condition: RuleCondition::Always,
            action: RuleAction::Pass,
            priority: 0,
            enabled: true,
            version: 1,
        }
    }

    /// Builder step: replaces the description.
    pub fn with_description(self, description: String) -> Self {
        Self { description, ..self }
    }

    /// Builder step: replaces the condition.
    pub fn with_condition(self, condition: RuleCondition) -> Self {
        Self { condition, ..self }
    }

    /// Builder step: replaces the action.
    pub fn with_action(self, action: RuleAction) -> Self {
        Self { action, ..self }
    }

    /// Builder step: replaces the priority.
    pub fn with_priority(self, priority: i32) -> Self {
        Self { priority, ..self }
    }
}
/// Condition that decides whether a rule's action runs.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RuleCondition {
    /// Always satisfied.
    Always,
    /// Never satisfied.
    Never,
    /// Condition on the confidence: `operator` and the `value` to compare with.
    Confidence {
        operator: ComparisonOperator,
        value: f64,
    },
    /// Condition on the risk level: `operator` and the `value` to compare with.
    RiskLevel {
        operator: ComparisonOperator,
        value: RiskLevel,
    },
    /// Condition on the compliance status.
    Status {
        value: ComplianceStatus,
    },
    /// Condition on a named field: field name, `operator` and a JSON `value`.
    Field {
        field: String,
        operator: ComparisonOperator,
        value: serde_json::Value,
    },
    /// Logical AND of the nested conditions.
    And(Vec<RuleCondition>),
    /// Logical OR of the nested conditions.
    Or(Vec<RuleCondition>),
    /// Logical NOT of the nested condition.
    Not(Box<RuleCondition>),
}
impl RuleCondition {
    /// Reports whether this condition conflicts with `_other`.
    ///
    /// Placeholder implementation: no logical analysis is performed and the
    /// answer is always `false` (no conflict). A real implementation needs
    /// to reason about how the two conditions overlap.
    pub fn conflicts_with(&self, _other: &RuleCondition) -> bool {
        // Simplified: every pair of conditions is assumed to be compatible.
        false
    }
}
/// Comparison operator used by rule conditions.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub enum ComparisonOperator {
    /// Equal to.
    Equal,
    /// Not equal to.
    NotEqual,
    /// Strictly greater than.
    GreaterThan,
    /// Greater than or equal to.
    GreaterThanOrEqual,
    /// Strictly less than.
    LessThan,
    /// Less than or equal to.
    LessThanOrEqual,
}
/// Action performed on the compliance result when a rule's condition holds.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RuleAction {
    /// Pass: the result is left unchanged.
    Pass,
    /// Reject, carrying the rejection `reason`.
    Reject {
        reason: String,
    },
    /// Set the compliance status.
    SetStatus {
        status: ComplianceStatus,
    },
    /// Set the risk level.
    SetRiskLevel {
        level: RiskLevel,
    },
    /// Add an issue.
    AddIssue {
        issue: ComplianceIssue,
    },
    /// Add a recommendation.
    AddRecommendation {
        recommendation: String,
    },
    /// Adjust the confidence by `adjustment`.
    AdjustConfidence {
        adjustment: f64,
    },
}
/// Rule executor.
///
/// A stateless unit type; it is cheap to copy and can be printed for
/// diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct RuleExecutor;
impl RuleExecutor {
/// 创建新的规则执行器
pub fn new() -> Self {
Self
}
/// 执行规则
pub fn execute(&self, rule: &Rule, result: &mut ComplianceResult, data: &ComplianceData) -> Result<()> {
// 检查条件是否满足
if !self.evaluate_condition(&rule.condition, result, data)? {
return Ok(());
}
// 执行动作
self.execute_action(&rule.action, result)?;
Ok(())
}
/// 评估条件
fn evaluate_condition(&self, condition: &RuleCondition, result: &ComplianceResult, data: &ComplianceData) -> Result<bool> {
match condition {
RuleCondition::Always => Ok(true),
RuleCondition::Never => Ok(false),
RuleCondition::Confidence { operator, value } => {
Ok(self.compare_f64(result.confidence, *value, *operator))
}
RuleCondition::RiskLevel { operator, value } => {
Ok(self.compare_risk_level(result.risk_level, *value, *operator))
}
RuleCondition::Status { value } => {
Ok(result.status == *value)
}
RuleCondition::Field { field, operator, value } => {
let field_value = data.fields.get(field);
match field_value {
Some(v) => Ok(self.compare_json(v, value, *operator)),
None => Ok(false),
}
}
RuleCondition::And(conditions) => {
for cond in conditions {
if !self.evaluate_condition(cond, result, data)? {
return Ok(false);
}
}
Ok(true)
}
RuleCondition::Or(conditions) => {
for cond in conditions {
if self.evaluate_condition(cond, result, data)? {
return Ok(true);
}
}
Ok(false)
}
RuleCondition::Not(condition) => {
Ok(!self.evaluate_condition(condition, result, data)?)
}
}
}
/// 执行动作
fn execute_action(&self, action: &RuleAction, result: &mut ComplianceResult) -> Result<()> {
match action {
RuleAction::Pass => {
// 不做任何修改
}
RuleAction::Reject { reason } => {
result.status = ComplianceStatus::Failed;
result.details = reason.clone();
}
RuleAction::SetStatus { status } => {
result.status = *status;
}
RuleAction::SetRiskLevel { level } => {
result.risk_level = *level;
}
RuleAction::AddIssue { issue } => {
result.issues.push(issue.clone());
}
RuleAction::AddRecommendation { recommendation } => {
result.recommendations.push(recommendation.clone());
}
RuleAction::AdjustConfidence { adjustment } => {
result.confidence = (result.confidence + adjustment).clamp(0.0, 1.0);
}
}
Ok(())
}
/// 比较浮点数
fn compare_f64(&self, left: f64, right: f64, operator: ComparisonOperator) -> bool {
match operator {
ComparisonOperator::Equal => (left - right).abs() < f64::EPSILON,
ComparisonOperator::NotEqual => (left - right).abs() >= f64::EPSILON,
ComparisonOperator::GreaterThan => left > right,
ComparisonOperator::GreaterThanOrEqual => left >= right,
ComparisonOperator::LessThan => left < right,
ComparisonOperator::LessThanOrEqual => left <= right,
}
}
/// 比较风险等级
fn compare_risk_level(&self, left: RiskLevel, right: RiskLevel, operator: ComparisonOperator) -> bool {
match operator {
ComparisonOperator::Equal => left == right,
ComparisonOperator::NotEqual => left != right,
ComparisonOperator::GreaterThan => left > right,
ComparisonOperator::GreaterThanOrEqual => left >= right,
ComparisonOperator::LessThan => left < right,
ComparisonOperator::LessThanOrEqual => left <= right,
}
}
/// 比较JSON值
fn compare_json(&self, _left: &serde_json::Value, _right: &serde_json::Value, _operator: ComparisonOperator) -> bool {
// 简化实现
true
}
}
impl Default for RuleExecutor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the identity-verification rule shared by the tests below.
    fn sample_rule() -> Rule {
        Rule::new(
            "rule1".to_string(),
            "Test Rule".to_string(),
            ComplianceLayer::IdentityVerification,
        )
    }

    /// A freshly created rule keeps its id and starts at version 1.
    #[test]
    fn test_rule_creation() {
        let created = sample_rule();
        assert_eq!(created.id, "rule1");
        assert_eq!(created.version, 1);
    }

    /// A rule added to the engine can be looked up by id afterwards.
    #[test]
    fn test_rule_engine() {
        let mut engine = RuleEngine::new();
        let outcome = engine.add_rule(sample_rule());
        assert!(outcome.is_ok());
        assert!(engine.get_rule("rule1").is_some());
    }

    /// A `confidence > 0.8` condition holds for a result with confidence 0.9.
    #[test]
    fn test_rule_condition() {
        let snapshot = ComplianceResult {
            layer: ComplianceLayer::IdentityVerification,
            status: ComplianceStatus::Passed,
            confidence: 0.9,
            risk_level: RiskLevel::Low,
            details: "Test".to_string(),
            issues: vec![],
            recommendations: vec![],
            timestamp: chrono::Utc::now(),
        };
        let input = ComplianceData::new("user123".to_string());
        let above_threshold = RuleCondition::Confidence {
            operator: ComparisonOperator::GreaterThan,
            value: 0.8,
        };

        let holds = RuleExecutor::new()
            .evaluate_condition(&above_threshold, &snapshot, &input)
            .unwrap();
        assert!(holds);
    }
}