//! 损失函数 //! //! 本模块实现XTZH AI模型的多任务损失函数: //! - 汇率预测损失(MSE) //! - 动态权重损失(交叉熵) //! - 商品偏离损失(MSE) //! - 约束损失(黄金层权重、权重和) use crate::constants::*; use crate::error::{Error, Result}; use crate::model::{DynamicWeights, ModelOutput}; use ndarray::Array1; use serde::{Deserialize, Serialize}; // ============================================================================ // 损失函数配置 // ============================================================================ /// 损失函数配置 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LossConfig { /// 汇率损失权重 pub rate_weight: f32, /// 动态权重损失权重 pub weights_weight: f32, /// 商品偏离损失权重 pub deltas_weight: f32, /// 约束损失权重 pub constraint_weight: f32, } impl Default for LossConfig { fn default() -> Self { Self { rate_weight: 1.0, weights_weight: 0.5, deltas_weight: 0.3, constraint_weight: 10.0, // 高权重确保约束满足 } } } // ============================================================================ // 损失计算 // ============================================================================ /// 计算汇率预测损失(MSE) /// /// # 参数 /// /// * `pred` - 预测汇率 /// * `target` - 目标汇率 /// /// # 返回 /// /// MSE损失 pub fn rate_loss(pred: f64, target: f64) -> f32 { let diff = pred - target; (diff * diff) as f32 } /// 计算动态权重损失(MSE) /// /// # 参数 /// /// * `pred` - 预测权重 /// * `target` - 目标权重 /// /// # 返回 /// /// MSE损失 pub fn weights_loss(pred: &DynamicWeights, target: &DynamicWeights) -> f32 { let diff_fx = (pred.w_fx as f32 - target.w_fx as f32) / WEIGHT_SUM as f32; let diff_au = (pred.w_au as f32 - target.w_au as f32) / WEIGHT_SUM as f32; let diff_com = (pred.w_com as f32 - target.w_com as f32) / WEIGHT_SUM as f32; diff_fx * diff_fx + diff_au * diff_au + diff_com * diff_com } /// 计算商品偏离损失(MSE) /// /// # 参数 /// /// * `pred` - 预测偏离系数 /// * `target` - 目标偏离系数 /// /// # 返回 /// /// MSE损失 pub fn deltas_loss(pred: &[i8], target: &[i8]) -> Result { if pred.len() != target.len() { return Err(Error::DimensionMismatch { expected: target.len(), actual: pred.len(), }); } let mse: f32 = pred .iter() .zip(target.iter()) .map(|(&p, &t)| { let diff = (p as f32 - t as f32) / 30.0; // 归一化到[-1, 1] diff * diff }) .sum::() / pred.len() as f32; Ok(mse) } /// 计算约束损失 /// /// # 参数 /// /// * `weights` - 预测权重 /// /// # 返回 /// /// 约束损失(权重和约束 + 黄金层约束) pub fn constraint_loss(weights: &DynamicWeights) -> f32 { let mut loss = 0.0; // 权重和约束 let sum = weights.w_fx + weights.w_au + weights.w_com; if sum != WEIGHT_SUM { let diff = (sum as f32 - WEIGHT_SUM as f32) / WEIGHT_SUM as f32; loss += diff * diff; } // 黄金层权重约束(5%-20%) if weights.w_au < W_AU_MIN { let diff = (weights.w_au as f32 - W_AU_MIN as f32) / WEIGHT_SUM as f32; loss += diff * diff; } else if weights.w_au > W_AU_MAX { let diff = (weights.w_au as f32 - W_AU_MAX as f32) / WEIGHT_SUM as f32; loss += diff * diff; } loss } /// 计算总损失 /// /// # 参数 /// /// * `pred` - 预测输出 /// * `target` - 目标输出 /// * `config` - 损失函数配置 /// /// # 返回 /// /// (总损失, 各项损失) pub fn total_loss( pred: &ModelOutput, target: &ModelOutput, config: &LossConfig, ) -> Result<(f32, LossComponents)> { let rate_l = rate_loss(pred.rate, target.rate); let weights_l = weights_loss(&pred.weights, &target.weights); let deltas_l = deltas_loss(&pred.commodity_deltas, &target.commodity_deltas)?; let constraint_l = constraint_loss(&pred.weights); let total = config.rate_weight * rate_l + config.weights_weight * weights_l + config.deltas_weight * deltas_l + config.constraint_weight * constraint_l; let components = LossComponents { rate: rate_l, weights: weights_l, deltas: deltas_l, constraint: constraint_l, total, }; Ok((total, components)) } /// 损失组件 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LossComponents { /// 汇率损失 pub rate: f32, /// 权重损失 pub weights: f32, /// 偏离损失 pub deltas: f32, /// 约束损失 pub constraint: f32, /// 总损失 pub total: f32, } // ============================================================================ // 测试 // ============================================================================ #[cfg(test)] mod tests { use super::*; use chrono::Utc; #[test] fn test_rate_loss() { let loss = rate_loss(1.0, 1.1); assert!((loss - 0.01).abs() < 1e-6); } #[test] fn test_weights_loss() { let pred = DynamicWeights { w_fx: 4000, w_au: 1000, w_com: 5000, }; let target = DynamicWeights { w_fx: 4100, w_au: 1000, w_com: 4900, }; let loss = weights_loss(&pred, &target); assert!(loss > 0.0); } #[test] fn test_deltas_loss() { let pred = vec![5, -3, 2, 0, -1, 4, -2, 1, -4, 3, 0, -5, 2, 1, -3, 0, 4, -1]; let target = vec![5, -3, 2, 0, -1, 4, -2, 1, -4, 3, 0, -5, 2, 1, -3, 0, 4, -1]; let loss = deltas_loss(&pred, &target).unwrap(); assert!((loss - 0.0).abs() < 1e-6); } #[test] fn test_constraint_loss_valid() { let weights = DynamicWeights { w_fx: 4000, w_au: 1000, w_com: 5000, }; let loss = constraint_loss(&weights); assert!((loss - 0.0).abs() < 1e-6); } #[test] fn test_constraint_loss_invalid_sum() { let weights = DynamicWeights { w_fx: 4000, w_au: 1000, w_com: 4900, // 和不等于10000 }; let loss = constraint_loss(&weights); assert!(loss > 0.0); } #[test] fn test_constraint_loss_invalid_gold() { let weights = DynamicWeights { w_fx: 4000, w_au: 300, // 低于5% w_com: 5700, }; let loss = constraint_loss(&weights); assert!(loss > 0.0); } }