NAC_Blockchain/xtzh-ai/src/model/transformer.rs


//! Transformer encoder
//!
//! This module implements a 4-layer, 8-head Transformer encoder used to
//! extract temporal dependencies from 50-dimensional macro features.
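//!
//! A minimal usage sketch (the crate path `xtzh_ai` and the concrete
//! hyperparameters below are assumptions for illustration; in practice the
//! crate's `TRANSFORMER_LAYERS`, `ATTENTION_HEADS`, and `EMBED_DIM`
//! constants are used, as in the tests at the bottom of this file):
//!
//! ```ignore
//! use ndarray::Array2;
//! use xtzh_ai::model::transformer::TransformerEncoder;
//!
//! // 4 layers, 8 heads, 64-dim embeddings, 256-dim FFN hidden layer.
//! let encoder = TransformerEncoder::new(4, 8, 64, 256).unwrap();
//! let x = Array2::<f32>::zeros((10, 64)); // seq_len = 10
//! let y = encoder.forward(&x).unwrap();
//! assert_eq!(y.dim(), (10, 64));
//! ```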
use crate::constants::*;
use crate::error::{Error, Result};
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
// ============================================================================
// Multi-head attention layer
// ============================================================================

/// Multi-head attention layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    /// Number of attention heads
    num_heads: usize,
    /// Embedding dimension
    embed_dim: usize,
    /// Dimension of each head
    head_dim: usize,
    /// Query weight matrix (embed_dim × embed_dim)
    w_q: Array2<f32>,
    /// Key weight matrix (embed_dim × embed_dim)
    w_k: Array2<f32>,
    /// Value weight matrix (embed_dim × embed_dim)
    w_v: Array2<f32>,
    /// Output projection weight matrix (embed_dim × embed_dim)
    w_o: Array2<f32>,
}
impl MultiHeadAttention {
    /// Creates a new multi-head attention layer.
    ///
    /// # Arguments
    ///
    /// * `num_heads` - number of attention heads
    /// * `embed_dim` - embedding dimension
    pub fn new(num_heads: usize, embed_dim: usize) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(Error::ArchitectureError(format!(
                "embedding dimension {} must be divisible by the number of attention heads {}",
                embed_dim, num_heads
            )));
        }
        let head_dim = embed_dim / num_heads;
        // Zero-initialized placeholder weights. Note: dividing zeros by a
        // Xavier scale of sqrt(embed_dim) is a no-op, so the scaling is
        // omitted here; real parameters are expected to be loaded or
        // trained elsewhere.
        let w_q = Array2::zeros((embed_dim, embed_dim));
        let w_k = Array2::zeros((embed_dim, embed_dim));
        let w_v = Array2::zeros((embed_dim, embed_dim));
        let w_o = Array2::zeros((embed_dim, embed_dim));
        Ok(Self {
            num_heads,
            embed_dim,
            head_dim,
            w_q,
            w_k,
            w_v,
            w_o,
        })
    }

    /// Forward pass.
    ///
    /// # Arguments
    ///
    /// * `x` - input tensor (seq_len × embed_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × embed_dim)
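    ///
    /// # Algorithm
    ///
    /// Scaled dot-product attention, as implemented below (note the
    /// simplification: heads are not actually split, so attention is
    /// computed over the full embedding dimension):
    ///
    /// ```text
    /// Attention(Q, K, V) = softmax(Q Kᵀ / √head_dim) V
    /// ```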
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_seq_len, embed_dim) = x.dim();
        if embed_dim != self.embed_dim {
            return Err(Error::DimensionMismatch {
                expected: self.embed_dim,
                actual: embed_dim,
            });
        }
        // Q = X @ W_q, K = X @ W_k, V = X @ W_v
        let q = x.dot(&self.w_q);
        let k = x.dot(&self.w_k);
        let v = x.dot(&self.w_v);
        // A full multi-head implementation would reshape
        // (seq_len, embed_dim) -> (num_heads, seq_len, head_dim); this
        // simplified version computes attention over embed_dim directly.
        // Attention scores: Q @ K^T / sqrt(head_dim)
        let scale = (self.head_dim as f32).sqrt();
        let scores = q.dot(&k.t()) / scale;
        // Row-wise softmax
        let attention = Self::softmax(&scores);
        // Weighted sum: Attention @ V
        let output = attention.dot(&v);
        // Output projection: Output @ W_o
        let result = output.dot(&self.w_o);
        Ok(result)
    }

    /// Row-wise softmax (numerically stabilized by subtracting the row max).
    fn softmax(x: &Array2<f32>) -> Array2<f32> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            for val in row.iter_mut() {
                *val = (*val - max).exp();
            }
            let sum: f32 = row.iter().sum();
            for val in row.iter_mut() {
                *val /= sum;
            }
        }
        result
    }
}

// ============================================================================
// Feed-forward network
// ============================================================================

/// Feed-forward network (FFN)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedForward {
    /// Input dimension
    input_dim: usize,
    /// Hidden-layer dimension (conventionally 4× the input dimension)
    hidden_dim: usize,
    /// First-layer weights (input_dim × hidden_dim)
    w1: Array2<f32>,
    /// First-layer bias
    b1: Array1<f32>,
    /// Second-layer weights (hidden_dim × input_dim)
    w2: Array2<f32>,
    /// Second-layer bias
    b2: Array1<f32>,
}

impl FeedForward {
    /// Creates a new feed-forward network.
    ///
    /// # Arguments
    ///
    /// * `input_dim` - input dimension
    /// * `hidden_dim` - hidden-layer dimension
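    ///
    /// # Example
    ///
    /// A minimal sketch with hypothetical dimensions (64-dim input and the
    /// conventional 4× hidden layer); the crate itself constructs this with
    /// `EMBED_DIM` and `EMBED_DIM * 4`, as in the tests below:
    ///
    /// ```ignore
    /// let ffn = FeedForward::new(64, 256);
    /// ```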
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        // Zero-initialized placeholder weights (as in MultiHeadAttention,
        // a Xavier scale is a no-op on zeros and is omitted).
        let w1 = Array2::zeros((input_dim, hidden_dim));
        let b1 = Array1::zeros(hidden_dim);
        let w2 = Array2::zeros((hidden_dim, input_dim));
        let b2 = Array1::zeros(input_dim);
        Self {
            input_dim,
            hidden_dim,
            w1,
            b1,
            w2,
            b2,
        }
    }

    /// Forward pass: `FFN(x) = GELU(x·W1 + b1)·W2 + b2`.
    ///
    /// # Arguments
    ///
    /// * `x` - input tensor (seq_len × input_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × input_dim)
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_seq_len, input_dim) = x.dim();
        if input_dim != self.input_dim {
            return Err(Error::DimensionMismatch {
                expected: self.input_dim,
                actual: input_dim,
            });
        }
        // First layer: X @ W1 + b1
        let mut hidden = x.dot(&self.w1);
        for mut row in hidden.rows_mut() {
            for (val, &bias) in row.iter_mut().zip(self.b1.iter()) {
                *val += bias;
            }
        }
        // GELU activation
        Self::gelu_inplace(&mut hidden);
        // Second layer: Hidden @ W2 + b2
        let mut output = hidden.dot(&self.w2);
        for mut row in output.rows_mut() {
            for (val, &bias) in row.iter_mut().zip(self.b2.iter()) {
                *val += bias;
            }
        }
        Ok(output)
    }

    /// GELU activation (in place).
    fn gelu_inplace(x: &mut Array2<f32>) {
        for val in x.iter_mut() {
            *val = Self::gelu(*val);
        }
    }

    /// GELU activation, tanh approximation:
    /// `GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))`,
    /// where 0.7978845608 ≈ √(2/π).
    fn gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x.powi(3))).tanh())
    }
}

// ============================================================================
// Layer normalization
// ============================================================================

/// Layer normalization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Normalized dimension
    dim: usize,
    /// Scale parameter (gamma)
    gamma: Array1<f32>,
    /// Shift parameter (beta)
    beta: Array1<f32>,
    /// Numerical-stability constant
    eps: f32,
}

impl LayerNorm {
    /// Creates a new layer normalization.
    ///
    /// # Arguments
    ///
    /// * `dim` - dimension to normalize over
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps: 1e-5,
        }
    }

    /// Forward pass: `y = gamma · (x - mean) / sqrt(variance + eps) + beta`,
    /// applied row by row.
    ///
    /// # Arguments
    ///
    /// * `x` - input tensor (seq_len × dim)
    ///
    /// # Returns
    ///
    /// The normalized tensor
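    ///
    /// # Example
    ///
    /// A minimal sketch (hypothetical 4-dim input): with the default
    /// `gamma = 1` and `beta = 0`, a constant row normalizes to ~0, since
    /// `x - mean` vanishes element-wise:
    ///
    /// ```ignore
    /// let ln = LayerNorm::new(4);
    /// let x = ndarray::Array2::<f32>::ones((2, 4));
    /// let y = ln.forward(&x).unwrap();
    /// assert!(y.iter().all(|&v| v.abs() < 1e-6));
    /// ```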
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_seq_len, dim) = x.dim();
        if dim != self.dim {
            return Err(Error::DimensionMismatch {
                expected: self.dim,
                actual: dim,
            });
        }
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            // Compute mean and variance
            let mean = row.mean().unwrap_or(0.0);
            let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / dim as f32;
            let std = (variance + self.eps).sqrt();
            // Normalize: (x - mean) / std
            for val in row.iter_mut() {
                *val = (*val - mean) / std;
            }
            // Scale and shift: gamma * x + beta
            for (val, (&gamma, &beta)) in row.iter_mut().zip(self.gamma.iter().zip(self.beta.iter())) {
                *val = gamma * *val + beta;
            }
        }
        Ok(result)
    }
}

// ============================================================================
// Transformer encoder layer
// ============================================================================

/// Transformer encoder layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerEncoderLayer {
    /// Multi-head attention layer
    attention: MultiHeadAttention,
    /// Feed-forward network
    ffn: FeedForward,
    /// First layer norm
    norm1: LayerNorm,
    /// Second layer norm
    norm2: LayerNorm,
}

impl TransformerEncoderLayer {
    /// Creates a new Transformer encoder layer.
    ///
    /// # Arguments
    ///
    /// * `num_heads` - number of attention heads
    /// * `embed_dim` - embedding dimension
    /// * `ffn_dim` - hidden dimension of the feed-forward network
    pub fn new(num_heads: usize, embed_dim: usize, ffn_dim: usize) -> Result<Self> {
        Ok(Self {
            attention: MultiHeadAttention::new(num_heads, embed_dim)?,
            ffn: FeedForward::new(embed_dim, ffn_dim),
            norm1: LayerNorm::new(embed_dim),
            norm2: LayerNorm::new(embed_dim),
        })
    }

    /// Forward pass.
    ///
    /// # Arguments
    ///
    /// * `x` - input tensor (seq_len × embed_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × embed_dim)
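    ///
    /// # Algorithm
    ///
    /// This is the post-norm ("post-LN") arrangement, as implemented below:
    ///
    /// ```text
    /// x ← LayerNorm(x + Attention(x))
    /// x ← LayerNorm(x + FFN(x))
    /// ```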
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        // Multi-head attention + residual connection + layer norm
        let attn_output = self.attention.forward(x)?;
        let x = self.norm1.forward(&(x + &attn_output))?;
        // Feed-forward network + residual connection + layer norm
        let ffn_output = self.ffn.forward(&x)?;
        let x = self.norm2.forward(&(&x + &ffn_output))?;
        Ok(x)
    }
}

// ============================================================================
// Transformer encoder
// ============================================================================

/// Transformer encoder (4 layers)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerEncoder {
    /// Encoder layers
    layers: Vec<TransformerEncoderLayer>,
    /// Embedding dimension
    embed_dim: usize,
}

impl TransformerEncoder {
    /// Creates a new Transformer encoder.
    ///
    /// # Arguments
    ///
    /// * `num_layers` - number of encoder layers
    /// * `num_heads` - number of attention heads
    /// * `embed_dim` - embedding dimension
    /// * `ffn_dim` - hidden dimension of the feed-forward network
    pub fn new(
        num_layers: usize,
        num_heads: usize,
        embed_dim: usize,
        ffn_dim: usize,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layers.push(TransformerEncoderLayer::new(num_heads, embed_dim, ffn_dim)?);
        }
        Ok(Self { layers, embed_dim })
    }

    /// Forward pass: applies each encoder layer in sequence.
    ///
    /// # Arguments
    ///
    /// * `x` - input tensor (seq_len × embed_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × embed_dim)
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let mut output = x.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        Ok(output)
    }

    /// Returns the embedding dimension.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(ATTENTION_HEADS, EMBED_DIM).unwrap();
        let x = Array2::zeros((10, EMBED_DIM)); // seq_len = 10
        let output = mha.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_feed_forward() {
        let ffn = FeedForward::new(EMBED_DIM, EMBED_DIM * 4);
        let x = Array2::zeros((10, EMBED_DIM));
        let output = ffn.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(EMBED_DIM);
        let x = Array2::ones((10, EMBED_DIM));
        let output = ln.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_transformer_encoder_layer() {
        let layer = TransformerEncoderLayer::new(ATTENTION_HEADS, EMBED_DIM, EMBED_DIM * 4).unwrap();
        let x = Array2::zeros((10, EMBED_DIM));
        let output = layer.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(
            TRANSFORMER_LAYERS,
            ATTENTION_HEADS,
            EMBED_DIM,
            EMBED_DIM * 4,
        )
        .unwrap();
        let x = Array2::zeros((10, EMBED_DIM));
        let output = encoder.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }
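
    // Added sanity check (not in the original suite): each row of the
    // softmax output should form a probability distribution, i.e. sum to 1.
    #[test]
    fn test_softmax_rows_sum_to_one() {
        let x = Array2::from_shape_vec((2, 3), vec![1.0_f32, 2.0, 3.0, 0.0, 0.0, 0.0]).unwrap();
        let y = MultiHeadAttention::softmax(&x);
        for row in y.rows() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }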

    #[test]
    fn test_dimension_mismatch() {
        let mha = MultiHeadAttention::new(ATTENTION_HEADS, EMBED_DIM).unwrap();
        let x = Array2::zeros((10, 128)); // wrong dimension
        assert!(mha.forward(&x).is_err());
    }
}
}