NAC_Blockchain/xtzh-ai/src/training/loss.rs

266 lines
6.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

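//! Loss functions
//!
//! This module implements the multi-task loss of the XTZH AI model:
//! - exchange-rate prediction loss (MSE)
//! - dynamic-weight loss (squared error of the normalized layer weights)
//! - commodity-deviation loss (MSE)
//! - constraint loss (gold-layer weight bounds, weight-sum constraint)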
use crate::constants::*;
use crate::error::{Error, Result};
use crate::model::{DynamicWeights, ModelOutput};
use ndarray::Array1;
use serde::{Deserialize, Serialize};
// ============================================================================
// 损失函数配置
// ============================================================================
/// Loss-function configuration: the weight applied to each loss term when
/// they are combined into the total loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossConfig {
/// Weight of the exchange-rate loss
pub rate_weight: f32,
/// Weight of the dynamic-weight loss
pub weights_weight: f32,
/// Weight of the commodity-deviation loss
pub deltas_weight: f32,
/// Weight of the constraint loss
pub constraint_weight: f32,
}
impl Default for LossConfig {
fn default() -> Self {
Self {
rate_weight: 1.0,
weights_weight: 0.5,
deltas_weight: 0.3,
constraint_weight: 10.0, // 高权重确保约束满足
}
}
}
// ============================================================================
// 损失计算
// ============================================================================
/// Exchange-rate prediction loss: the squared error between the predicted
/// and target rate.
///
/// # Arguments
///
/// * `pred` - predicted exchange rate
/// * `target` - target exchange rate
///
/// # Returns
///
/// `(pred - target)^2`, computed in `f64` and narrowed to `f32` at the end.
pub fn rate_loss(pred: f64, target: f64) -> f32 {
    // Square in f64 so precision is only lost once, at the final narrowing.
    let error = target - pred;
    let squared = error * error;
    squared as f32
}
/// Dynamic-weight loss.
///
/// Each layer weight (`w_fx`, `w_au`, `w_com`) is normalised by
/// `WEIGHT_SUM`. The squared prediction/target differences are then added
/// together. This is a *sum* over the three layers, not a mean.
///
/// # Arguments
///
/// * `pred` - predicted weights
/// * `target` - target weights
///
/// # Returns
///
/// Sum of the squared normalised weight errors.
pub fn weights_loss(pred: &DynamicWeights, target: &DynamicWeights) -> f32 {
    let scale = WEIGHT_SUM as f32;
    let layer_pairs = [
        (pred.w_fx, target.w_fx),
        (pred.w_au, target.w_au),
        (pred.w_com, target.w_com),
    ];
    layer_pairs
        .iter()
        .map(|&(p, t)| {
            let err = (p as f32 - t as f32) / scale;
            err * err
        })
        .sum()
}
/// Commodity-deviation loss: the mean squared error of the deviation
/// coefficients.
///
/// Each element-wise difference is divided by 30 before squaring. The
/// squared terms are then averaged over the number of commodities.
///
/// # Arguments
///
/// * `pred` - predicted deviation coefficients
/// * `target` - target deviation coefficients
///
/// # Returns
///
/// The MSE loss. Two empty slices yield `0.0` (nothing to penalise) rather
/// than the `NaN` that dividing by a zero length would produce and that would
/// poison the total loss.
///
/// # Errors
///
/// Returns [`Error::DimensionMismatch`] when `pred` and `target` differ in
/// length.
pub fn deltas_loss(pred: &[i8], target: &[i8]) -> Result<f32> {
    if pred.len() != target.len() {
        return Err(Error::DimensionMismatch {
            expected: target.len(),
            actual: pred.len(),
        });
    }
    if pred.is_empty() {
        return Ok(0.0);
    }

    // NOTE(review): 30 is the normalisation scale inherited from the original
    // code. It presumably reflects the expected deviation range (the old
    // comment said "normalise to [-1, 1]"), but i8 inputs are not bounded by
    // it — confirm against the model.
    const DELTA_SCALE: f32 = 30.0;

    let sum_sq: f32 = pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| {
            let diff = (f32::from(p) - f32::from(t)) / DELTA_SCALE;
            diff * diff
        })
        .sum();
    Ok(sum_sq / pred.len() as f32)
}
/// Computes the constraint penalty on a set of predicted weights.
///
/// Two soft constraints are penalised. Each penalty is a squared deviation
/// normalised by `WEIGHT_SUM`:
/// - the three layer weights must sum to `WEIGHT_SUM`;
/// - the gold-layer weight `w_au` must lie in `[W_AU_MIN, W_AU_MAX]`
///   (documented as 5%-20%). Only the distance to the violated bound counts.
///
/// # Arguments
///
/// * `weights` - predicted weights
///
/// # Returns
///
/// The sum of both penalties; `0.0` when both constraints hold.
pub fn constraint_loss(weights: &DynamicWeights) -> f32 {
    let scale = WEIGHT_SUM as f32;

    // Weight-sum constraint. The sum is accumulated in f64 rather than the
    // field type. Predicted weights can be arbitrarily out of range, and if
    // the fields are narrow integers an integer sum could overflow: a panic
    // in debug builds, or a silent wrap in release that can even make an
    // invalid sum look valid. A zero deviation contributes exactly 0.
    let sum = weights.w_fx as f64 + weights.w_au as f64 + weights.w_com as f64;
    let sum_dev = ((sum - WEIGHT_SUM as f64) / WEIGHT_SUM as f64) as f32;
    let mut loss = sum_dev * sum_dev;

    // Gold-layer bound constraint (5%-20%): penalise distance to the violated bound.
    if weights.w_au < W_AU_MIN {
        let dev = (weights.w_au as f32 - W_AU_MIN as f32) / scale;
        loss += dev * dev;
    } else if weights.w_au > W_AU_MAX {
        let dev = (weights.w_au as f32 - W_AU_MAX as f32) / scale;
        loss += dev * dev;
    }
    loss
}
/// Computes the weighted multi-task loss.
///
/// Every term is computed by its dedicated function and combined as
/// `rate_weight * rate + weights_weight * weights + deltas_weight * deltas
/// + constraint_weight * constraint`. The constraint term is evaluated on the
/// *predicted* weights only.
///
/// # Arguments
///
/// * `pred` - model prediction
/// * `target` - ground-truth output
/// * `config` - per-term weights
///
/// # Returns
///
/// `(total, components)`. `components` holds the unweighted terms and
/// repeats `total`.
///
/// # Errors
///
/// Propagates the error from [`deltas_loss`] when the commodity-deviation
/// vectors differ in length.
pub fn total_loss(
    pred: &ModelOutput,
    target: &ModelOutput,
    config: &LossConfig,
) -> Result<(f32, LossComponents)> {
    let rate = rate_loss(pred.rate, target.rate);
    let weights = weights_loss(&pred.weights, &target.weights);
    let deltas = deltas_loss(&pred.commodity_deltas, &target.commodity_deltas)?;
    let constraint = constraint_loss(&pred.weights);

    let total = config.rate_weight * rate
        + config.weights_weight * weights
        + config.deltas_weight * deltas
        + config.constraint_weight * constraint;

    Ok((
        total,
        LossComponents {
            rate,
            weights,
            deltas,
            constraint,
            total,
        },
    ))
}
/// Per-term breakdown of the multi-task loss, as produced by `total_loss`.
///
/// The individual terms are stored unweighted; only `total` has the
/// `LossConfig` weights applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossComponents {
/// Exchange-rate loss (unweighted)
pub rate: f32,
/// Dynamic-weight loss (unweighted)
pub weights: f32,
/// Commodity-deviation loss (unweighted)
pub deltas: f32,
/// Constraint loss (unweighted)
pub constraint: f32,
/// Total loss: the weighted sum of the terms above
pub total: f32,
}
// ============================================================================
// 测试
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_loss() {
        let loss = rate_loss(1.0, 1.1);
        assert!((loss - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_rate_loss_is_symmetric() {
        assert_eq!(rate_loss(3.0, 1.0), rate_loss(1.0, 3.0));
        assert_eq!(rate_loss(3.0, 1.0), 4.0);
    }

    #[test]
    fn test_weights_loss() {
        let pred = DynamicWeights {
            w_fx: 4000,
            w_au: 1000,
            w_com: 5000,
        };
        let target = DynamicWeights {
            w_fx: 4100,
            w_au: 1000,
            w_com: 4900,
        };
        let loss = weights_loss(&pred, &target);
        assert!(loss > 0.0);
        // Identical weights cost nothing; swapping arguments does not matter.
        assert_eq!(weights_loss(&pred, &pred), 0.0);
        assert_eq!(loss, weights_loss(&target, &pred));
    }

    #[test]
    fn test_deltas_loss() {
        let pred = [5, -3, 2, 0, -1, 4, -2, 1, -4, 3, 0, -5, 2, 1, -3, 0, 4, -1];
        let target = [5, -3, 2, 0, -1, 4, -2, 1, -4, 3, 0, -5, 2, 1, -3, 0, 4, -1];
        let loss = deltas_loss(&pred, &target).unwrap();
        assert!((loss - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_deltas_loss_length_mismatch() {
        assert!(deltas_loss(&[1, 2], &[1]).is_err());
    }

    #[test]
    fn test_constraint_loss_valid() {
        let weights = DynamicWeights {
            w_fx: 4000,
            w_au: 1000,
            w_com: 5000,
        };
        let loss = constraint_loss(&weights);
        assert!((loss - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_constraint_loss_invalid_sum() {
        let weights = DynamicWeights {
            w_fx: 4000,
            w_au: 1000,
            w_com: 4900, // sum is not 10000
        };
        let loss = constraint_loss(&weights);
        assert!(loss > 0.0);
    }

    #[test]
    fn test_constraint_loss_invalid_gold() {
        let weights = DynamicWeights {
            w_fx: 4000,
            w_au: 300, // below 5%
            w_com: 5700,
        };
        let loss = constraint_loss(&weights);
        assert!(loss > 0.0);
    }

    #[test]
    fn test_constraint_loss_gold_above_max() {
        let weights = DynamicWeights {
            w_fx: 2500,
            w_au: 5000, // 50%, far above the 20% upper bound
            w_com: 2500,
        };
        let loss = constraint_loss(&weights);
        assert!(loss > 0.0);
    }
}