266 lines
6.6 KiB
Rust
266 lines
6.6 KiB
Rust
//! 损失函数
|
||
//!
|
||
//! 本模块实现XTZH AI模型的多任务损失函数:
|
||
//! - 汇率预测损失(MSE)
|
||
//! - 动态权重损失(交叉熵)
|
||
//! - 商品偏离损失(MSE)
|
||
//! - 约束损失(黄金层权重、权重和)
|
||
|
||
use crate::constants::*;
|
||
use crate::error::{Error, Result};
|
||
use crate::model::{DynamicWeights, ModelOutput};
|
||
use ndarray::Array1;
|
||
use serde::{Deserialize, Serialize};
|
||
|
||
// ============================================================================
|
||
// 损失函数配置
|
||
// ============================================================================
|
||
|
||
/// 损失函数配置
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct LossConfig {
|
||
/// 汇率损失权重
|
||
pub rate_weight: f32,
|
||
/// 动态权重损失权重
|
||
pub weights_weight: f32,
|
||
/// 商品偏离损失权重
|
||
pub deltas_weight: f32,
|
||
/// 约束损失权重
|
||
pub constraint_weight: f32,
|
||
}
|
||
|
||
impl Default for LossConfig {
|
||
fn default() -> Self {
|
||
Self {
|
||
rate_weight: 1.0,
|
||
weights_weight: 0.5,
|
||
deltas_weight: 0.3,
|
||
constraint_weight: 10.0, // 高权重确保约束满足
|
||
}
|
||
}
|
||
}
|
||
|
||
// ============================================================================
|
||
// 损失计算
|
||
// ============================================================================
|
||
|
||
/// 计算汇率预测损失(MSE)
|
||
///
|
||
/// # 参数
|
||
///
|
||
/// * `pred` - 预测汇率
|
||
/// * `target` - 目标汇率
|
||
///
|
||
/// # 返回
|
||
///
|
||
/// MSE损失
|
||
pub fn rate_loss(pred: f64, target: f64) -> f32 {
|
||
let diff = pred - target;
|
||
(diff * diff) as f32
|
||
}
|
||
|
||
/// 计算动态权重损失(MSE)
|
||
///
|
||
/// # 参数
|
||
///
|
||
/// * `pred` - 预测权重
|
||
/// * `target` - 目标权重
|
||
///
|
||
/// # 返回
|
||
///
|
||
/// MSE损失
|
||
pub fn weights_loss(pred: &DynamicWeights, target: &DynamicWeights) -> f32 {
|
||
let diff_fx = (pred.w_fx as f32 - target.w_fx as f32) / WEIGHT_SUM as f32;
|
||
let diff_au = (pred.w_au as f32 - target.w_au as f32) / WEIGHT_SUM as f32;
|
||
let diff_com = (pred.w_com as f32 - target.w_com as f32) / WEIGHT_SUM as f32;
|
||
|
||
diff_fx * diff_fx + diff_au * diff_au + diff_com * diff_com
|
||
}
|
||
|
||
/// 计算商品偏离损失(MSE)
|
||
///
|
||
/// # 参数
|
||
///
|
||
/// * `pred` - 预测偏离系数
|
||
/// * `target` - 目标偏离系数
|
||
///
|
||
/// # 返回
|
||
///
|
||
/// MSE损失
|
||
pub fn deltas_loss(pred: &[i8], target: &[i8]) -> Result<f32> {
|
||
if pred.len() != target.len() {
|
||
return Err(Error::DimensionMismatch {
|
||
expected: target.len(),
|
||
actual: pred.len(),
|
||
});
|
||
}
|
||
|
||
let mse: f32 = pred
|
||
.iter()
|
||
.zip(target.iter())
|
||
.map(|(&p, &t)| {
|
||
let diff = (p as f32 - t as f32) / 30.0; // 归一化到[-1, 1]
|
||
diff * diff
|
||
})
|
||
.sum::<f32>()
|
||
/ pred.len() as f32;
|
||
|
||
Ok(mse)
|
||
}
|
||
|
||
/// 计算约束损失
|
||
///
|
||
/// # 参数
|
||
///
|
||
/// * `weights` - 预测权重
|
||
///
|
||
/// # 返回
|
||
///
|
||
/// 约束损失(权重和约束 + 黄金层约束)
|
||
pub fn constraint_loss(weights: &DynamicWeights) -> f32 {
|
||
let mut loss = 0.0;
|
||
|
||
// 权重和约束
|
||
let sum = weights.w_fx + weights.w_au + weights.w_com;
|
||
if sum != WEIGHT_SUM {
|
||
let diff = (sum as f32 - WEIGHT_SUM as f32) / WEIGHT_SUM as f32;
|
||
loss += diff * diff;
|
||
}
|
||
|
||
// 黄金层权重约束(5%-20%)
|
||
if weights.w_au < W_AU_MIN {
|
||
let diff = (weights.w_au as f32 - W_AU_MIN as f32) / WEIGHT_SUM as f32;
|
||
loss += diff * diff;
|
||
} else if weights.w_au > W_AU_MAX {
|
||
let diff = (weights.w_au as f32 - W_AU_MAX as f32) / WEIGHT_SUM as f32;
|
||
loss += diff * diff;
|
||
}
|
||
|
||
loss
|
||
}
|
||
|
||
/// 计算总损失
|
||
///
|
||
/// # 参数
|
||
///
|
||
/// * `pred` - 预测输出
|
||
/// * `target` - 目标输出
|
||
/// * `config` - 损失函数配置
|
||
///
|
||
/// # 返回
|
||
///
|
||
/// (总损失, 各项损失)
|
||
pub fn total_loss(
|
||
pred: &ModelOutput,
|
||
target: &ModelOutput,
|
||
config: &LossConfig,
|
||
) -> Result<(f32, LossComponents)> {
|
||
let rate_l = rate_loss(pred.rate, target.rate);
|
||
let weights_l = weights_loss(&pred.weights, &target.weights);
|
||
let deltas_l = deltas_loss(&pred.commodity_deltas, &target.commodity_deltas)?;
|
||
let constraint_l = constraint_loss(&pred.weights);
|
||
|
||
let total = config.rate_weight * rate_l
|
||
+ config.weights_weight * weights_l
|
||
+ config.deltas_weight * deltas_l
|
||
+ config.constraint_weight * constraint_l;
|
||
|
||
let components = LossComponents {
|
||
rate: rate_l,
|
||
weights: weights_l,
|
||
deltas: deltas_l,
|
||
constraint: constraint_l,
|
||
total,
|
||
};
|
||
|
||
Ok((total, components))
|
||
}
|
||
|
||
/// 损失组件
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct LossComponents {
|
||
/// 汇率损失
|
||
pub rate: f32,
|
||
/// 权重损失
|
||
pub weights: f32,
|
||
/// 偏离损失
|
||
pub deltas: f32,
|
||
/// 约束损失
|
||
pub constraint: f32,
|
||
/// 总损失
|
||
pub total: f32,
|
||
}
|
||
|
||
// ============================================================================
|
||
// 测试
|
||
// ============================================================================
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use chrono::Utc;
|
||
|
||
#[test]
|
||
fn test_rate_loss() {
|
||
let loss = rate_loss(1.0, 1.1);
|
||
assert!((loss - 0.01).abs() < 1e-6);
|
||
}
|
||
|
||
#[test]
|
||
fn test_weights_loss() {
|
||
let pred = DynamicWeights {
|
||
w_fx: 4000,
|
||
w_au: 1000,
|
||
w_com: 5000,
|
||
};
|
||
let target = DynamicWeights {
|
||
w_fx: 4100,
|
||
w_au: 1000,
|
||
w_com: 4900,
|
||
};
|
||
let loss = weights_loss(&pred, &target);
|
||
assert!(loss > 0.0);
|
||
}
|
||
|
||
#[test]
|
||
fn test_deltas_loss() {
|
||
let pred = vec![5, -3, 2, 0, -1, 4, -2, 1, -4, 3, 0, -5, 2, 1, -3, 0, 4, -1];
|
||
let target = vec![5, -3, 2, 0, -1, 4, -2, 1, -4, 3, 0, -5, 2, 1, -3, 0, 4, -1];
|
||
let loss = deltas_loss(&pred, &target).unwrap();
|
||
assert!((loss - 0.0).abs() < 1e-6);
|
||
}
|
||
|
||
#[test]
|
||
fn test_constraint_loss_valid() {
|
||
let weights = DynamicWeights {
|
||
w_fx: 4000,
|
||
w_au: 1000,
|
||
w_com: 5000,
|
||
};
|
||
let loss = constraint_loss(&weights);
|
||
assert!((loss - 0.0).abs() < 1e-6);
|
||
}
|
||
|
||
#[test]
|
||
fn test_constraint_loss_invalid_sum() {
|
||
let weights = DynamicWeights {
|
||
w_fx: 4000,
|
||
w_au: 1000,
|
||
w_com: 4900, // 和不等于10000
|
||
};
|
||
let loss = constraint_loss(&weights);
|
||
assert!(loss > 0.0);
|
||
}
|
||
|
||
#[test]
|
||
fn test_constraint_loss_invalid_gold() {
|
||
let weights = DynamicWeights {
|
||
w_fx: 4000,
|
||
w_au: 300, // 低于5%
|
||
w_com: 5700,
|
||
};
|
||
let loss = constraint_loss(&weights);
|
||
assert!(loss > 0.0);
|
||
}
|
||
}
|