//! Valuation validation system.
//!
//! Provides valuation validation rules, accuracy evaluation, divergence
//! analysis, and model-optimization suggestions.

use chrono::{DateTime, Utc};
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};

use crate::{FinalValuationResult, AIProvider, AIValuationResult};

/// 验证规则
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ValidationRule {
|
||
/// 规则名称
|
||
pub name: String,
|
||
/// 规则描述
|
||
pub description: String,
|
||
/// 最小估值(XTZH)
|
||
pub min_valuation: Option<Decimal>,
|
||
/// 最大估值(XTZH)
|
||
pub max_valuation: Option<Decimal>,
|
||
/// 最小置信度
|
||
pub min_confidence: Option<f64>,
|
||
/// 最大模型差异率(%)
|
||
pub max_model_divergence: Option<f64>,
|
||
/// 是否启用
|
||
pub enabled: bool,
|
||
}
|
||
|
||
impl ValidationRule {
|
||
/// 创建默认验证规则
|
||
pub fn default_rules() -> Vec<Self> {
|
||
vec![
|
||
ValidationRule {
|
||
name: "估值范围检查".to_string(),
|
||
description: "确保估值在合理范围内".to_string(),
|
||
min_valuation: Some(Decimal::ZERO),
|
||
max_valuation: Some(Decimal::new(1_000_000_000_000, 0)), // 1万亿XTZH
|
||
min_confidence: None,
|
||
max_model_divergence: None,
|
||
enabled: true,
|
||
},
|
||
ValidationRule {
|
||
name: "置信度检查".to_string(),
|
||
description: "确保置信度达到最低要求".to_string(),
|
||
min_valuation: None,
|
||
max_valuation: None,
|
||
min_confidence: Some(0.5),
|
||
max_model_divergence: None,
|
||
enabled: true,
|
||
},
|
||
ValidationRule {
|
||
name: "模型差异检查".to_string(),
|
||
description: "确保AI模型估值差异不过大".to_string(),
|
||
min_valuation: None,
|
||
max_valuation: None,
|
||
min_confidence: None,
|
||
max_model_divergence: Some(30.0), // 30%
|
||
enabled: true,
|
||
},
|
||
]
|
||
}
|
||
|
||
/// 验证估值结果
|
||
pub fn validate(&self, result: &FinalValuationResult) -> ValidationResult {
|
||
if !self.enabled {
|
||
return ValidationResult::passed(self.name.clone());
|
||
}
|
||
|
||
let mut issues = Vec::new();
|
||
|
||
// 检查估值范围
|
||
if let Some(min_val) = self.min_valuation {
|
||
if result.valuation_xtzh < min_val {
|
||
issues.push(format!(
|
||
"估值 {} XTZH 低于最小值 {} XTZH",
|
||
result.valuation_xtzh, min_val
|
||
));
|
||
}
|
||
}
|
||
|
||
if let Some(max_val) = self.max_valuation {
|
||
if result.valuation_xtzh > max_val {
|
||
issues.push(format!(
|
||
"估值 {} XTZH 超过最大值 {} XTZH",
|
||
result.valuation_xtzh, max_val
|
||
));
|
||
}
|
||
}
|
||
|
||
// 检查置信度
|
||
if let Some(min_conf) = self.min_confidence {
|
||
if result.confidence < min_conf {
|
||
issues.push(format!(
|
||
"置信度 {:.1}% 低于最小值 {:.1}%",
|
||
result.confidence * 100.0,
|
||
min_conf * 100.0
|
||
));
|
||
}
|
||
}
|
||
|
||
// 检查模型差异
|
||
if let Some(max_div) = self.max_model_divergence {
|
||
if !result.model_results.is_empty() {
|
||
let valuations: Vec<Decimal> = result.model_results
|
||
.iter()
|
||
.map(|r| r.valuation_xtzh)
|
||
.collect();
|
||
|
||
if let Some(divergence) = Self::calculate_divergence(&valuations) {
|
||
if divergence > max_div {
|
||
issues.push(format!(
|
||
"模型差异 {:.1}% 超过最大值 {:.1}%",
|
||
divergence, max_div
|
||
));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if issues.is_empty() {
|
||
ValidationResult::passed(self.name.clone())
|
||
} else {
|
||
ValidationResult::failed(self.name.clone(), issues)
|
||
}
|
||
}
|
||
|
||
/// 计算模型差异率
|
||
fn calculate_divergence(valuations: &[Decimal]) -> Option<f64> {
|
||
if valuations.len() < 2 {
|
||
return None;
|
||
}
|
||
|
||
let min = valuations.iter().min()?;
|
||
let max = valuations.iter().max()?;
|
||
|
||
if *min == Decimal::ZERO {
|
||
return None;
|
||
}
|
||
|
||
let divergence = ((*max - *min) / *min * Decimal::new(100, 0))
|
||
.to_string()
|
||
.parse::<f64>()
|
||
.ok()?;
|
||
|
||
Some(divergence)
|
||
}
|
||
}
|
||
|
||
/// 验证结果
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ValidationResult {
|
||
/// 规则名称
|
||
pub rule_name: String,
|
||
/// 是否通过
|
||
pub passed: bool,
|
||
/// 问题列表
|
||
pub issues: Vec<String>,
|
||
/// 验证时间
|
||
pub timestamp: DateTime<Utc>,
|
||
}
|
||
|
||
impl ValidationResult {
|
||
/// 创建通过的验证结果
|
||
pub fn passed(rule_name: String) -> Self {
|
||
Self {
|
||
rule_name,
|
||
passed: true,
|
||
issues: Vec::new(),
|
||
timestamp: Utc::now(),
|
||
}
|
||
}
|
||
|
||
/// 创建失败的验证结果
|
||
pub fn failed(rule_name: String, issues: Vec<String>) -> Self {
|
||
Self {
|
||
rule_name,
|
||
passed: false,
|
||
issues,
|
||
timestamp: Utc::now(),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 估值验证器
|
||
pub struct ValuationValidator {
|
||
/// 验证规则列表
|
||
rules: Vec<ValidationRule>,
|
||
}
|
||
|
||
impl ValuationValidator {
|
||
/// 创建新的验证器
|
||
pub fn new(rules: Vec<ValidationRule>) -> Self {
|
||
Self { rules }
|
||
}
|
||
|
||
/// 使用默认规则创建验证器
|
||
pub fn with_default_rules() -> Self {
|
||
Self::new(ValidationRule::default_rules())
|
||
}
|
||
|
||
/// 验证估值结果
|
||
pub fn validate(&self, result: &FinalValuationResult) -> Vec<ValidationResult> {
|
||
self.rules
|
||
.iter()
|
||
.map(|rule| rule.validate(result))
|
||
.collect()
|
||
}
|
||
|
||
/// 检查是否所有验证都通过
|
||
pub fn validate_all(&self, result: &FinalValuationResult) -> bool {
|
||
self.validate(result).iter().all(|r| r.passed)
|
||
}
|
||
|
||
/// 添加验证规则
|
||
pub fn add_rule(&mut self, rule: ValidationRule) {
|
||
self.rules.push(rule);
|
||
}
|
||
|
||
/// 移除验证规则
|
||
pub fn remove_rule(&mut self, rule_name: &str) {
|
||
self.rules.retain(|r| r.name != rule_name);
|
||
}
|
||
|
||
/// 启用/禁用规则
|
||
pub fn set_rule_enabled(&mut self, rule_name: &str, enabled: bool) {
|
||
if let Some(rule) = self.rules.iter_mut().find(|r| r.name == rule_name) {
|
||
rule.enabled = enabled;
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 精度评估器
|
||
pub struct AccuracyEvaluator;
|
||
|
||
impl AccuracyEvaluator {
|
||
/// 评估估值精度(与实际价值比较)
|
||
pub fn evaluate(
|
||
estimated: Decimal,
|
||
actual: Decimal,
|
||
) -> AccuracyMetrics {
|
||
let absolute_error = if estimated > actual {
|
||
estimated - actual
|
||
} else {
|
||
actual - estimated
|
||
};
|
||
|
||
let relative_error = if actual != Decimal::ZERO {
|
||
(absolute_error / actual * Decimal::new(100, 0))
|
||
.to_string()
|
||
.parse::<f64>()
|
||
.unwrap_or(0.0)
|
||
} else {
|
||
0.0
|
||
};
|
||
|
||
let accuracy = 100.0 - relative_error.min(100.0);
|
||
|
||
AccuracyMetrics {
|
||
estimated,
|
||
actual,
|
||
absolute_error,
|
||
relative_error,
|
||
accuracy,
|
||
}
|
||
}
|
||
|
||
/// 批量评估精度
|
||
pub fn evaluate_batch(
|
||
pairs: Vec<(Decimal, Decimal)>,
|
||
) -> BatchAccuracyMetrics {
|
||
let metrics: Vec<AccuracyMetrics> = pairs
|
||
.iter()
|
||
.map(|(est, act)| Self::evaluate(*est, *act))
|
||
.collect();
|
||
|
||
let avg_accuracy = metrics.iter().map(|m| m.accuracy).sum::<f64>() / metrics.len() as f64;
|
||
let avg_relative_error = metrics.iter().map(|m| m.relative_error).sum::<f64>() / metrics.len() as f64;
|
||
|
||
BatchAccuracyMetrics {
|
||
total_samples: metrics.len(),
|
||
avg_accuracy,
|
||
avg_relative_error,
|
||
individual_metrics: metrics,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 精度指标
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct AccuracyMetrics {
|
||
/// 估值
|
||
pub estimated: Decimal,
|
||
/// 实际值
|
||
pub actual: Decimal,
|
||
/// 绝对误差
|
||
pub absolute_error: Decimal,
|
||
/// 相对误差(%)
|
||
pub relative_error: f64,
|
||
/// 精度(%)
|
||
pub accuracy: f64,
|
||
}
|
||
|
||
/// 批量精度指标
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct BatchAccuracyMetrics {
|
||
/// 样本总数
|
||
pub total_samples: usize,
|
||
/// 平均精度(%)
|
||
pub avg_accuracy: f64,
|
||
/// 平均相对误差(%)
|
||
pub avg_relative_error: f64,
|
||
/// 各样本指标
|
||
pub individual_metrics: Vec<AccuracyMetrics>,
|
||
}
|
||
|
||
/// 差异分析器
|
||
pub struct DivergenceAnalyzer;
|
||
|
||
impl DivergenceAnalyzer {
|
||
/// 分析AI模型估值差异
|
||
pub fn analyze(model_results: &[AIValuationResult]) -> DivergenceAnalysis {
|
||
if model_results.is_empty() {
|
||
return DivergenceAnalysis::empty();
|
||
}
|
||
|
||
let valuations: Vec<Decimal> = model_results
|
||
.iter()
|
||
.map(|r| r.valuation_xtzh)
|
||
.collect();
|
||
|
||
let min_valuation = valuations.iter().min().unwrap().clone();
|
||
let max_valuation = valuations.iter().max().unwrap().clone();
|
||
let avg_valuation = valuations.iter().sum::<Decimal>() / Decimal::new(valuations.len() as i64, 0);
|
||
|
||
let divergence_rate = if min_valuation != Decimal::ZERO {
|
||
((max_valuation - min_valuation) / min_valuation * Decimal::new(100, 0))
|
||
.to_string()
|
||
.parse::<f64>()
|
||
.unwrap_or(0.0)
|
||
} else {
|
||
0.0
|
||
};
|
||
|
||
// 识别异常值
|
||
let outliers = Self::identify_outliers(model_results, avg_valuation);
|
||
|
||
// 模型一致性评分
|
||
let consistency_score = Self::calculate_consistency_score(divergence_rate);
|
||
|
||
DivergenceAnalysis {
|
||
model_count: model_results.len(),
|
||
min_valuation,
|
||
max_valuation,
|
||
avg_valuation,
|
||
divergence_rate,
|
||
outliers,
|
||
consistency_score,
|
||
}
|
||
}
|
||
|
||
/// 识别异常值(偏离平均值超过30%)
|
||
fn identify_outliers(
|
||
model_results: &[AIValuationResult],
|
||
avg_valuation: Decimal,
|
||
) -> Vec<AIProvider> {
|
||
let threshold = Decimal::new(30, 2); // 0.30 = 30%
|
||
|
||
model_results
|
||
.iter()
|
||
.filter(|r| {
|
||
let deviation = if r.valuation_xtzh > avg_valuation {
|
||
(r.valuation_xtzh - avg_valuation) / avg_valuation
|
||
} else {
|
||
(avg_valuation - r.valuation_xtzh) / avg_valuation
|
||
};
|
||
deviation > threshold
|
||
})
|
||
.map(|r| r.provider)
|
||
.collect()
|
||
}
|
||
|
||
/// 计算一致性评分(0-100)
|
||
fn calculate_consistency_score(divergence_rate: f64) -> f64 {
|
||
if divergence_rate <= 10.0 {
|
||
100.0
|
||
} else if divergence_rate <= 20.0 {
|
||
90.0 - (divergence_rate - 10.0)
|
||
} else if divergence_rate <= 30.0 {
|
||
80.0 - (divergence_rate - 20.0) * 2.0
|
||
} else {
|
||
(60.0 - (divergence_rate - 30.0)).max(0.0)
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 差异分析结果
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct DivergenceAnalysis {
|
||
/// 模型数量
|
||
pub model_count: usize,
|
||
/// 最小估值
|
||
pub min_valuation: Decimal,
|
||
/// 最大估值
|
||
pub max_valuation: Decimal,
|
||
/// 平均估值
|
||
pub avg_valuation: Decimal,
|
||
/// 差异率(%)
|
||
pub divergence_rate: f64,
|
||
/// 异常值模型
|
||
pub outliers: Vec<AIProvider>,
|
||
/// 一致性评分(0-100)
|
||
pub consistency_score: f64,
|
||
}
|
||
|
||
impl DivergenceAnalysis {
|
||
/// 创建空的分析结果
|
||
fn empty() -> Self {
|
||
Self {
|
||
model_count: 0,
|
||
min_valuation: Decimal::ZERO,
|
||
max_valuation: Decimal::ZERO,
|
||
avg_valuation: Decimal::ZERO,
|
||
divergence_rate: 0.0,
|
||
outliers: Vec::new(),
|
||
consistency_score: 0.0,
|
||
}
|
||
}
|
||
|
||
/// 生成分析报告
|
||
pub fn generate_report(&self) -> String {
|
||
let mut report = String::new();
|
||
|
||
report.push_str("# 模型差异分析报告\n\n");
|
||
report.push_str(&format!("- **模型数量**: {}\n", self.model_count));
|
||
report.push_str(&format!("- **估值范围**: {} - {} XTZH\n", self.min_valuation, self.max_valuation));
|
||
report.push_str(&format!("- **平均估值**: {} XTZH\n", self.avg_valuation));
|
||
report.push_str(&format!("- **差异率**: {:.2}%\n", self.divergence_rate));
|
||
report.push_str(&format!("- **一致性评分**: {:.1}/100\n\n", self.consistency_score));
|
||
|
||
if !self.outliers.is_empty() {
|
||
report.push_str("## ⚠️ 异常值检测\n\n");
|
||
report.push_str("以下模型的估值偏离平均值超过30%:\n\n");
|
||
for provider in &self.outliers {
|
||
report.push_str(&format!("- {:?}\n", provider));
|
||
}
|
||
report.push_str("\n");
|
||
}
|
||
|
||
if self.divergence_rate > 30.0 {
|
||
report.push_str("## 🔴 高差异警告\n\n");
|
||
report.push_str("模型估值差异率超过30%,建议:\n");
|
||
report.push_str("1. 检查输入数据的准确性\n");
|
||
report.push_str("2. 审查异常模型的估值逻辑\n");
|
||
report.push_str("3. 考虑人工审核\n");
|
||
} else if self.divergence_rate > 20.0 {
|
||
report.push_str("## 🟡 中等差异提示\n\n");
|
||
report.push_str("模型估值差异率在20-30%之间,建议关注。\n");
|
||
} else {
|
||
report.push_str("## 🟢 差异正常\n\n");
|
||
report.push_str("模型估值差异在可接受范围内。\n");
|
||
}
|
||
|
||
report
|
||
}
|
||
}
|
||
|
||
/// 模型优化建议
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct OptimizationSuggestion {
|
||
/// 建议类型
|
||
pub suggestion_type: SuggestionType,
|
||
/// 建议描述
|
||
pub description: String,
|
||
/// 优先级
|
||
pub priority: Priority,
|
||
/// 目标模型
|
||
pub target_model: Option<AIProvider>,
|
||
}
|
||
|
||
/// 建议类型
|
||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||
pub enum SuggestionType {
|
||
/// 调整权重
|
||
AdjustWeight,
|
||
/// 更新模型
|
||
UpdateModel,
|
||
/// 增加训练数据
|
||
AddTrainingData,
|
||
/// 调整参数
|
||
TuneParameters,
|
||
/// 人工审核
|
||
HumanReview,
|
||
}
|
||
|
||
/// 优先级
|
||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||
pub enum Priority {
|
||
/// 高
|
||
High,
|
||
/// 中
|
||
Medium,
|
||
/// 低
|
||
Low,
|
||
}
|
||
|
||
/// 模型优化器
|
||
pub struct ModelOptimizer;
|
||
|
||
impl ModelOptimizer {
|
||
/// 生成优化建议
|
||
pub fn generate_suggestions(
|
||
result: &FinalValuationResult,
|
||
divergence: &DivergenceAnalysis,
|
||
) -> Vec<OptimizationSuggestion> {
|
||
let mut suggestions = Vec::new();
|
||
|
||
// 如果差异率过高
|
||
if divergence.divergence_rate > 30.0 {
|
||
suggestions.push(OptimizationSuggestion {
|
||
suggestion_type: SuggestionType::HumanReview,
|
||
description: "模型差异过大,建议人工审核".to_string(),
|
||
priority: Priority::High,
|
||
target_model: None,
|
||
});
|
||
}
|
||
|
||
// 如果有异常值
|
||
for provider in &divergence.outliers {
|
||
suggestions.push(OptimizationSuggestion {
|
||
suggestion_type: SuggestionType::AdjustWeight,
|
||
description: format!("模型 {:?} 估值异常,建议降低权重", provider),
|
||
priority: Priority::Medium,
|
||
target_model: Some(*provider),
|
||
});
|
||
}
|
||
|
||
// 如果置信度过低
|
||
if result.confidence < 0.7 {
|
||
suggestions.push(OptimizationSuggestion {
|
||
suggestion_type: SuggestionType::AddTrainingData,
|
||
description: "整体置信度偏低,建议增加训练数据".to_string(),
|
||
priority: Priority::Medium,
|
||
target_model: None,
|
||
});
|
||
}
|
||
|
||
// 如果一致性评分低
|
||
if divergence.consistency_score < 70.0 {
|
||
suggestions.push(OptimizationSuggestion {
|
||
suggestion_type: SuggestionType::TuneParameters,
|
||
description: "模型一致性不足,建议调整参数".to_string(),
|
||
priority: Priority::Low,
|
||
target_model: None,
|
||
});
|
||
}
|
||
|
||
suggestions
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::collections::HashMap;

    /// Builds a final result with the given valuation and confidence and no
    /// per-model results.
    fn create_test_result(valuation: i64, confidence: f64) -> FinalValuationResult {
        FinalValuationResult {
            valuation_xtzh: Decimal::new(valuation, 0),
            confidence,
            model_results: vec![],
            weights: HashMap::new(),
            is_anomaly: false,
            anomaly_report: None,
            divergence_report: "Test".to_string(),
            requires_human_review: false,
        }
    }

    /// Builds a single model result for the given provider.
    fn model_result(provider: AIProvider, valuation: i64, confidence: f64) -> AIValuationResult {
        AIValuationResult {
            provider,
            valuation_xtzh: Decimal::new(valuation, 0),
            confidence,
            reasoning: "Test".to_string(),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_validation_rule() {
        let rule = ValidationRule {
            name: "Test Rule".to_string(),
            description: "Test".to_string(),
            min_valuation: Some(Decimal::new(1000, 0)),
            max_valuation: Some(Decimal::new(10000, 0)),
            min_confidence: Some(0.7),
            max_model_divergence: None,
            enabled: true,
        };

        // Inside the range with sufficient confidence.
        let within_limits = create_test_result(5000, 0.8);
        assert!(rule.validate(&within_limits).passed);

        // Below the minimum valuation and below the minimum confidence.
        let out_of_limits = create_test_result(500, 0.5);
        assert!(!rule.validate(&out_of_limits).passed);
    }

    #[test]
    fn test_accuracy_evaluator() {
        let estimated = Decimal::new(100, 0);
        let actual = Decimal::new(110, 0);

        let metrics = AccuracyEvaluator::evaluate(estimated, actual);

        assert_eq!(metrics.absolute_error, Decimal::new(10, 0));
        assert!(metrics.relative_error > 9.0 && metrics.relative_error < 10.0);
        assert!(metrics.accuracy > 90.0);
    }

    #[test]
    fn test_divergence_analyzer() {
        let model_results = vec![
            model_result(AIProvider::ChatGPT, 1000, 0.85),
            model_result(AIProvider::DeepSeek, 1100, 0.88),
            model_result(AIProvider::DouBao, 1050, 0.82),
        ];

        let analysis = DivergenceAnalyzer::analyze(&model_results);

        assert_eq!(analysis.model_count, 3);
        assert!(analysis.divergence_rate < 15.0);
        assert!(analysis.consistency_score > 85.0);
    }
}