NAC_Blockchain/nac-ai-valuation/src/validation.rs

646 lines
19 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

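//! Valuation validation system.
//!
//! Provides valuation validation rules, accuracy evaluation, divergence analysis and model optimization suggestions.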
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
use crate::{FinalValuationResult, AIProvider, AIValuationResult};
/// A single validation rule for a final valuation result.
///
/// Every threshold is optional; `validate` only checks the thresholds that are `Some`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationRule {
/// Rule name; also used as the rule name of the produced `ValidationResult`.
pub name: String,
/// Human-readable description of the rule.
pub description: String,
/// Minimum valuation, in XTZH.
pub min_valuation: Option<Decimal>,
/// Maximum valuation, in XTZH.
pub max_valuation: Option<Decimal>,
/// Minimum confidence, as a fraction (for example `0.5` is reported as 50%).
pub min_confidence: Option<f64>,
/// Maximum divergence between model valuations, in percent.
pub max_model_divergence: Option<f64>,
/// Whether the rule is enabled; a disabled rule always passes.
pub enabled: bool,
}
impl ValidationRule {
/// Builds the built-in rule set: a valuation range check, a minimum
/// confidence check and a model divergence check, all enabled.
pub fn default_rules() -> Vec<Self> {
    // Template with every threshold unset; each rule overrides only its own.
    let unset = |name: &str, description: &str| ValidationRule {
        name: name.to_string(),
        description: description.to_string(),
        min_valuation: None,
        max_valuation: None,
        min_confidence: None,
        max_model_divergence: None,
        enabled: true,
    };

    let range_check = ValidationRule {
        min_valuation: Some(Decimal::ZERO),
        // Upper bound: 1 trillion XTZH.
        max_valuation: Some(Decimal::new(1_000_000_000_000, 0)),
        ..unset("估值范围检查", "确保估值在合理范围内")
    };
    let confidence_check = ValidationRule {
        min_confidence: Some(0.5),
        ..unset("置信度检查", "确保置信度达到最低要求")
    };
    let divergence_check = ValidationRule {
        // Divergence is expressed in percent, so this is 30%.
        max_model_divergence: Some(30.0),
        ..unset("模型差异检查", "确保AI模型估值差异不过大")
    };

    vec![range_check, confidence_check, divergence_check]
}
/// Checks `result` against every threshold configured on this rule.
///
/// A disabled rule always passes. Otherwise each threshold that is `Some`
/// is evaluated and all violations are collected, so a failed result lists
/// every problem rather than only the first one.
pub fn validate(&self, result: &FinalValuationResult) -> ValidationResult {
    if !self.enabled {
        return ValidationResult::passed(self.name.clone());
    }

    let mut issues: Vec<String> = Vec::new();

    // Valuation range: lower bound, then upper bound.
    if let Some(floor) = self
        .min_valuation
        .filter(|floor| result.valuation_xtzh < *floor)
    {
        issues.push(format!(
            "估值 {} XTZH 低于最小值 {} XTZH",
            result.valuation_xtzh, floor
        ));
    }
    if let Some(ceiling) = self
        .max_valuation
        .filter(|ceiling| result.valuation_xtzh > *ceiling)
    {
        issues.push(format!(
            "估值 {} XTZH 超过最大值 {} XTZH",
            result.valuation_xtzh, ceiling
        ));
    }

    // Confidence is stored as a fraction and reported as a percentage.
    if let Some(required) = self
        .min_confidence
        .filter(|required| result.confidence < *required)
    {
        issues.push(format!(
            "置信度 {:.1}% 低于最小值 {:.1}%",
            result.confidence * 100.0,
            required * 100.0
        ));
    }

    // Spread between the per-model valuations; skipped when there are no
    // model results or when no divergence can be computed.
    if let Some(limit) = self.max_model_divergence {
        if !result.model_results.is_empty() {
            let valuations: Vec<Decimal> = result
                .model_results
                .iter()
                .map(|model| model.valuation_xtzh)
                .collect();
            match Self::calculate_divergence(&valuations) {
                Some(spread) if spread > limit => issues.push(format!(
                    "模型差异 {:.1}% 超过最大值 {:.1}%",
                    spread, limit
                )),
                _ => {}
            }
        }
    }

    if !issues.is_empty() {
        return ValidationResult::failed(self.name.clone(), issues);
    }
    ValidationResult::passed(self.name.clone())
}
/// 计算模型差异率
fn calculate_divergence(valuations: &[Decimal]) -> Option<f64> {
if valuations.len() < 2 {
return None;
}
let min = valuations.iter().min()?;
let max = valuations.iter().max()?;
if *min == Decimal::ZERO {
return None;
}
let divergence = ((*max - *min) / *min * Decimal::new(100, 0))
.to_string()
.parse::<f64>()
.ok()?;
Some(divergence)
}
}
/// Outcome of applying one validation rule to a valuation result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
/// Name of the rule that produced this result.
pub rule_name: String,
/// Whether the rule passed.
pub passed: bool,
/// Descriptions of the problems found; empty for results built with `passed`.
pub issues: Vec<String>,
/// UTC time at which the result was created.
pub timestamp: DateTime<Utc>,
}
impl ValidationResult {
    /// Builds a successful result for `rule_name` with an empty issue list,
    /// stamped with the current UTC time.
    pub fn passed(rule_name: String) -> Self {
        Self::stamped(rule_name, true, Vec::new())
    }

    /// Builds a failed result for `rule_name` carrying `issues`, stamped
    /// with the current UTC time.
    pub fn failed(rule_name: String, issues: Vec<String>) -> Self {
        Self::stamped(rule_name, false, issues)
    }

    /// Shared constructor: records the outcome together with `Utc::now()`.
    fn stamped(rule_name: String, passed: bool, issues: Vec<String>) -> Self {
        Self {
            rule_name,
            passed,
            issues,
            timestamp: Utc::now(),
        }
    }
}
/// Applies a list of [`ValidationRule`]s to valuation results.
pub struct ValuationValidator {
    /// Rules to apply, in the order they were supplied or added.
    rules: Vec<ValidationRule>,
}

impl ValuationValidator {
    /// Creates a validator that applies exactly the given rules.
    pub fn new(rules: Vec<ValidationRule>) -> Self {
        ValuationValidator { rules }
    }

    /// Creates a validator pre-loaded with `ValidationRule::default_rules()`.
    pub fn with_default_rules() -> Self {
        let rules = ValidationRule::default_rules();
        Self::new(rules)
    }

    /// Runs every rule against `result` and returns one outcome per rule,
    /// in rule order.
    pub fn validate(&self, result: &FinalValuationResult) -> Vec<ValidationResult> {
        let mut outcomes = Vec::with_capacity(self.rules.len());
        for rule in &self.rules {
            outcomes.push(rule.validate(result));
        }
        outcomes
    }

    /// Returns `true` when no rule reports a failure. All rules are
    /// evaluated before the verdict is computed.
    pub fn validate_all(&self, result: &FinalValuationResult) -> bool {
        !self.validate(result).iter().any(|outcome| !outcome.passed)
    }

    /// Appends a rule to the end of the rule list.
    pub fn add_rule(&mut self, rule: ValidationRule) {
        self.rules.push(rule);
    }

    /// Removes every rule whose name equals `rule_name`.
    pub fn remove_rule(&mut self, rule_name: &str) {
        self.rules.retain(|rule| rule.name.as_str() != rule_name);
    }

    /// Enables or disables the first rule whose name equals `rule_name`;
    /// does nothing when no rule matches.
    pub fn set_rule_enabled(&mut self, rule_name: &str, enabled: bool) {
        for rule in self.rules.iter_mut() {
            if rule.name == rule_name {
                rule.enabled = enabled;
                break;
            }
        }
    }
}
/// Accuracy evaluator: compares estimated valuations with actual values.
pub struct AccuracyEvaluator;
impl AccuracyEvaluator {
/// 评估估值精度(与实际价值比较)
pub fn evaluate(
estimated: Decimal,
actual: Decimal,
) -> AccuracyMetrics {
let absolute_error = if estimated > actual {
estimated - actual
} else {
actual - estimated
};
let relative_error = if actual != Decimal::ZERO {
(absolute_error / actual * Decimal::new(100, 0))
.to_string()
.parse::<f64>()
.unwrap_or(0.0)
} else {
0.0
};
let accuracy = 100.0 - relative_error.min(100.0);
AccuracyMetrics {
estimated,
actual,
absolute_error,
relative_error,
accuracy,
}
}
/// Evaluates a batch of `(estimated, actual)` pairs.
///
/// Each pair is scored with `evaluate`; the batch result carries the
/// per-sample metrics plus the arithmetic means of accuracy and relative
/// error. For an empty batch both means are `0.0` rather than the NaN that
/// dividing by a zero sample count would produce.
pub fn evaluate_batch(
    pairs: Vec<(Decimal, Decimal)>,
) -> BatchAccuracyMetrics {
    let metrics: Vec<AccuracyMetrics> = pairs
        .iter()
        .map(|&(est, act)| Self::evaluate(est, act))
        .collect();

    let total_samples = metrics.len();
    let (avg_accuracy, avg_relative_error) = if total_samples == 0 {
        (0.0, 0.0)
    } else {
        let count = total_samples as f64;
        let accuracy_sum: f64 = metrics.iter().map(|m| m.accuracy).sum();
        let error_sum: f64 = metrics.iter().map(|m| m.relative_error).sum();
        (accuracy_sum / count, error_sum / count)
    };

    BatchAccuracyMetrics {
        total_samples,
        avg_accuracy,
        avg_relative_error,
        individual_metrics: metrics,
    }
}
}
/// Accuracy metrics for a single estimated/actual pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccuracyMetrics {
/// Estimated valuation.
pub estimated: Decimal,
/// Actual value.
pub actual: Decimal,
/// Absolute difference between the estimated and the actual value.
pub absolute_error: Decimal,
/// Relative error, in percent.
pub relative_error: f64,
/// Accuracy, in percent.
pub accuracy: f64,
}
/// Aggregated accuracy metrics for a batch of samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchAccuracyMetrics {
/// Total number of samples.
pub total_samples: usize,
/// Mean accuracy across all samples, in percent.
pub avg_accuracy: f64,
/// Mean relative error across all samples, in percent.
pub avg_relative_error: f64,
/// Metrics of each individual sample.
pub individual_metrics: Vec<AccuracyMetrics>,
}
/// Divergence analyzer: measures how far the AI model valuations are apart.
pub struct DivergenceAnalyzer;
impl DivergenceAnalyzer {
/// Analyzes how much the per-model valuations disagree.
///
/// Returns `DivergenceAnalysis::empty()` for an empty slice. Otherwise it
/// reports the minimum, maximum and mean valuation, the divergence rate
/// `(max - min) / min * 100`, the providers flagged as outliers and a
/// consistency score derived from the divergence rate.
///
/// NOTE(review): when the minimum valuation is zero the divergence rate is
/// reported as `0.0`, which then maps to the top consistency score even if
/// other models returned non-zero valuations — confirm this is intended.
pub fn analyze(model_results: &[AIValuationResult]) -> DivergenceAnalysis {
    /// Iterates over the valuation of every model result, in order.
    fn valuations(models: &[AIValuationResult]) -> impl Iterator<Item = Decimal> + '_ {
        models.iter().map(|model| model.valuation_xtzh)
    }

    let model_count = model_results.len();
    if model_count == 0 {
        return DivergenceAnalysis::empty();
    }

    // The slice is non-empty here, so both extremes exist.
    let min_valuation = valuations(model_results).min().unwrap();
    let max_valuation = valuations(model_results).max().unwrap();

    let total: Decimal = valuations(model_results).sum();
    let avg_valuation = total / Decimal::new(model_count as i64, 0);

    let divergence_rate = if min_valuation == Decimal::ZERO {
        0.0
    } else {
        let percent = (max_valuation - min_valuation) / min_valuation * Decimal::new(100, 0);
        percent.to_string().parse::<f64>().unwrap_or(0.0)
    };

    DivergenceAnalysis {
        model_count,
        min_valuation,
        max_valuation,
        avg_valuation,
        divergence_rate,
        // Providers that deviate too far from the mean.
        outliers: Self::identify_outliers(model_results, avg_valuation),
        // Score derived purely from the divergence rate.
        consistency_score: Self::calculate_consistency_score(divergence_rate),
    }
}
/// Returns the providers whose valuation deviates from `avg_valuation` by
/// more than 30% of the average's magnitude.
///
/// When the average is zero the relative deviation is undefined and
/// dividing by it would panic; in that case every model whose valuation
/// differs from the average at all is reported as an outlier.
fn identify_outliers(
    model_results: &[AIValuationResult],
    avg_valuation: Decimal,
) -> Vec<AIProvider> {
    let threshold = Decimal::new(30, 2); // 0.30 = 30%
    // Use the magnitude so a negative average cannot make every relative
    // deviation negative and hide all outliers.
    let baseline = avg_valuation.abs();

    model_results
        .iter()
        .filter(|r| {
            let deviation = (r.valuation_xtzh - avg_valuation).abs();
            if baseline.is_zero() {
                !deviation.is_zero()
            } else {
                deviation / baseline > threshold
            }
        })
        .map(|r| r.provider)
        .collect()
}
/// Maps a divergence rate (in percent) to a consistency score between 0 and 100.
///
/// * up to 10%: a flat 100
/// * above 10% up to 20%: falls linearly from 90 to 80
/// * above 20% up to 30%: falls linearly from 80 to 60
/// * above 30%: loses one point per percent, floored at 0
///
/// Note the step at 10%: the score drops from 100 to just under 90 as soon
/// as the divergence rate exceeds 10.
fn calculate_consistency_score(divergence_rate: f64) -> f64 {
    match divergence_rate {
        rate if rate <= 10.0 => 100.0,
        rate if rate <= 20.0 => 90.0 - (rate - 10.0),
        rate if rate <= 30.0 => 80.0 - (rate - 20.0) * 2.0,
        rate => (60.0 - (rate - 30.0)).max(0.0),
    }
}
}
/// Result of a divergence analysis across AI model valuations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DivergenceAnalysis {
/// Number of model results analyzed.
pub model_count: usize,
/// Lowest model valuation.
pub min_valuation: Decimal,
/// Highest model valuation.
pub max_valuation: Decimal,
/// Mean model valuation.
pub avg_valuation: Decimal,
/// Divergence rate, in percent.
pub divergence_rate: f64,
/// Providers whose valuation was flagged as an outlier.
pub outliers: Vec<AIProvider>,
/// Consistency score, from 0 to 100.
pub consistency_score: f64,
}
impl DivergenceAnalysis {
/// Returns the placeholder analysis used when there are no model results:
/// zero models, zero valuations, zero divergence, no outliers and a
/// consistency score of zero.
fn empty() -> Self {
    let zero = Decimal::ZERO;
    Self {
        model_count: 0,
        outliers: Vec::new(),
        divergence_rate: 0.0,
        consistency_score: 0.0,
        min_valuation: zero,
        max_valuation: zero,
        avg_valuation: zero,
    }
}
/// Renders the analysis as a Markdown report.
///
/// The report contains a summary bullet list, an outlier section when at
/// least one provider was flagged, and a closing section chosen by the
/// divergence rate: above 30% (high), above 20% (medium), otherwise normal.
pub fn generate_report(&self) -> String {
    // Title and summary bullets.
    let mut sections: Vec<String> = vec![
        "# 模型差异分析报告\n\n".to_string(),
        format!("- **模型数量**: {}\n", self.model_count),
        format!("- **估值范围**: {} - {} XTZH\n", self.min_valuation, self.max_valuation),
        format!("- **平均估值**: {} XTZH\n", self.avg_valuation),
        format!("- **差异率**: {:.2}%\n", self.divergence_rate),
        format!("- **一致性评分**: {:.1}/100\n\n", self.consistency_score),
    ];

    // Optional outlier listing, one bullet per flagged provider.
    if !self.outliers.is_empty() {
        sections.push("## ⚠️ 异常值检测\n\n".to_string());
        sections.push("以下模型的估值偏离平均值超过30%\n\n".to_string());
        sections.extend(
            self.outliers
                .iter()
                .map(|provider| format!("- {:?}\n", provider)),
        );
        sections.push("\n".to_string());
    }

    // Closing verdict, selected by divergence band.
    let verdict: &[&str] = if self.divergence_rate > 30.0 {
        &[
            "## 🔴 高差异警告\n\n",
            "模型估值差异率超过30%,建议:\n",
            "1. 检查输入数据的准确性\n",
            "2. 审查异常模型的估值逻辑\n",
            "3. 考虑人工审核\n",
        ]
    } else if self.divergence_rate > 20.0 {
        &[
            "## 🟡 中等差异提示\n\n",
            "模型估值差异率在20-30%之间,建议关注。\n",
        ]
    } else {
        &[
            "## 🟢 差异正常\n\n",
            "模型估值差异在可接受范围内。\n",
        ]
    };
    sections.extend(verdict.iter().map(|line| line.to_string()));

    sections.concat()
}
}
/// A suggestion for improving the valuation models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSuggestion {
/// Kind of action being suggested.
pub suggestion_type: SuggestionType,
/// Human-readable description of the suggestion.
pub description: String,
/// Priority of the suggestion.
pub priority: Priority,
/// Model the suggestion targets; `None` when it is not tied to a single model.
pub target_model: Option<AIProvider>,
}
/// Kind of optimization suggestion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SuggestionType {
/// Adjust a model's weight.
AdjustWeight,
/// Update the model.
UpdateModel,
/// Add more training data.
AddTrainingData,
/// Tune parameters.
TuneParameters,
/// Request a human review.
HumanReview,
}
/// Priority of an optimization suggestion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Priority {
/// High priority.
High,
/// Medium priority.
Medium,
/// Low priority.
Low,
}
/// 模型优化器
pub struct ModelOptimizer;
impl ModelOptimizer {
/// 生成优化建议
pub fn generate_suggestions(
result: &FinalValuationResult,
divergence: &DivergenceAnalysis,
) -> Vec<OptimizationSuggestion> {
let mut suggestions = Vec::new();
// 如果差异率过高
if divergence.divergence_rate > 30.0 {
suggestions.push(OptimizationSuggestion {
suggestion_type: SuggestionType::HumanReview,
description: "模型差异过大,建议人工审核".to_string(),
priority: Priority::High,
target_model: None,
});
}
// 如果有异常值
for provider in &divergence.outliers {
suggestions.push(OptimizationSuggestion {
suggestion_type: SuggestionType::AdjustWeight,
description: format!("模型 {:?} 估值异常,建议降低权重", provider),
priority: Priority::Medium,
target_model: Some(*provider),
});
}
// 如果置信度过低
if result.confidence < 0.7 {
suggestions.push(OptimizationSuggestion {
suggestion_type: SuggestionType::AddTrainingData,
description: "整体置信度偏低,建议增加训练数据".to_string(),
priority: Priority::Medium,
target_model: None,
});
}
// 如果一致性评分低
if divergence.consistency_score < 70.0 {
suggestions.push(OptimizationSuggestion {
suggestion_type: SuggestionType::TuneParameters,
description: "模型一致性不足,建议调整参数".to_string(),
priority: Priority::Low,
target_model: None,
});
}
suggestions
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::collections::HashMap;

    /// Builds a final valuation result with a whole-number valuation and no
    /// per-model results attached.
    fn create_test_result(valuation: i64, confidence: f64) -> FinalValuationResult {
        FinalValuationResult {
            valuation_xtzh: Decimal::new(valuation, 0),
            confidence,
            model_results: Vec::new(),
            weights: HashMap::new(),
            is_anomaly: false,
            anomaly_report: None,
            divergence_report: String::from("Test"),
            requires_human_review: false,
        }
    }

    /// Builds a single model result with a whole-number valuation.
    fn model_result(provider: AIProvider, valuation: i64, confidence: f64) -> AIValuationResult {
        AIValuationResult {
            provider,
            valuation_xtzh: Decimal::new(valuation, 0),
            confidence,
            reasoning: String::from("Test"),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_validation_rule() {
        let rule = ValidationRule {
            name: String::from("Test Rule"),
            description: String::from("Test"),
            min_valuation: Some(Decimal::new(1000, 0)),
            max_valuation: Some(Decimal::new(10000, 0)),
            min_confidence: Some(0.7),
            max_model_divergence: None,
            enabled: true,
        };

        // Inside the valuation range and above the confidence floor.
        let within_limits = create_test_result(5000, 0.8);
        assert!(rule.validate(&within_limits).passed);

        // Below the minimum valuation and below the confidence floor.
        let out_of_limits = create_test_result(500, 0.5);
        assert!(!rule.validate(&out_of_limits).passed);
    }

    #[test]
    fn test_accuracy_evaluator() {
        let estimated = Decimal::new(100, 0);
        let actual = Decimal::new(110, 0);
        let metrics = AccuracyEvaluator::evaluate(estimated, actual);

        assert_eq!(metrics.absolute_error, Decimal::new(10, 0));
        // 10 / 110 is a little over 9%.
        assert!(metrics.relative_error > 9.0);
        assert!(metrics.relative_error < 10.0);
        assert!(metrics.accuracy > 90.0);
    }

    #[test]
    fn test_divergence_analyzer() {
        let model_results = vec![
            model_result(AIProvider::ChatGPT, 1000, 0.85),
            model_result(AIProvider::DeepSeek, 1100, 0.88),
            model_result(AIProvider::DouBao, 1050, 0.82),
        ];

        let analysis = DivergenceAnalyzer::analyze(&model_results);

        assert_eq!(analysis.model_count, 3);
        assert!(analysis.divergence_rate < 15.0);
        assert!(analysis.consistency_score > 85.0);
    }
}