//! Transformer encoder.
//!
//! This module implements a 4-layer, 8-head Transformer encoder used to
//! capture temporal dependencies in sequences of 50-dimensional macro features.
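//!
//! A minimal usage sketch (types and constants as defined in this module and
//! `crate::constants`; not compiled as a doc test):
//!
//! ```ignore
//! let encoder = TransformerEncoder::new(
//!     TRANSFORMER_LAYERS, // e.g. 4
//!     ATTENTION_HEADS,    // e.g. 8
//!     EMBED_DIM,
//!     EMBED_DIM * 4,
//! )?;
//! let x = ndarray::Array2::<f32>::zeros((10, EMBED_DIM)); // 10 time steps
//! let encoded = encoder.forward(&x)?; // (10, EMBED_DIM)
//! ```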

use crate::constants::*;
use crate::error::{Error, Result};
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};

// ============================================================================
// Multi-head attention
// ============================================================================

/// Multi-head attention layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
    /// Number of attention heads
    num_heads: usize,
    /// Embedding dimension
    embed_dim: usize,
    /// Dimension of each head
    head_dim: usize,
    /// Query weight matrix (embed_dim × embed_dim)
    w_q: Array2<f32>,
    /// Key weight matrix (embed_dim × embed_dim)
    w_k: Array2<f32>,
    /// Value weight matrix (embed_dim × embed_dim)
    w_v: Array2<f32>,
    /// Output projection matrix (embed_dim × embed_dim)
    w_o: Array2<f32>,
}

impl MultiHeadAttention {
    /// Creates a new multi-head attention layer.
    ///
    /// # Arguments
    ///
    /// * `num_heads` - Number of attention heads
    /// * `embed_dim` - Embedding dimension
    pub fn new(num_heads: usize, embed_dim: usize) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(Error::ArchitectureError(format!(
                "embedding dimension {} must be divisible by the number of attention heads {}",
                embed_dim, num_heads
            )));
        }

        let head_dim = embed_dim / num_heads;

        // Zero-initialized weight placeholders; trained parameters are
        // expected to be loaded before inference. (A true Xavier
        // initialization would instead draw random values scaled by
        // 1/sqrt(embed_dim); dividing a zero matrix by that scale is a no-op.)
        let w_q = Array2::zeros((embed_dim, embed_dim));
        let w_k = Array2::zeros((embed_dim, embed_dim));
        let w_v = Array2::zeros((embed_dim, embed_dim));
        let w_o = Array2::zeros((embed_dim, embed_dim));

        Ok(Self {
            num_heads,
            embed_dim,
            head_dim,
            w_q,
            w_k,
            w_v,
            w_o,
        })
    }

    /// Forward pass.
    ///
    /// Computes scaled dot-product attention,
    /// `softmax(Q K^T / sqrt(head_dim)) V`, followed by the output
    /// projection `W_o`.
    ///
    /// # Arguments
    ///
    /// * `x` - Input tensor (seq_len × embed_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × embed_dim)
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_seq_len, embed_dim) = x.dim();

        if embed_dim != self.embed_dim {
            return Err(Error::DimensionMismatch {
                expected: self.embed_dim,
                actual: embed_dim,
            });
        }

        // Q = X @ W_q, K = X @ W_k, V = X @ W_v
        let q = x.dot(&self.w_q);
        let k = x.dot(&self.w_k);
        let v = x.dot(&self.w_v);

        // A full multi-head implementation would split (seq_len, embed_dim)
        // into (num_heads, seq_len, head_dim) and attend per head; this
        // simplified version attends over the whole embed_dim at once
        // (see `forward_per_head` below for a per-head sketch).

        // Attention scores: Q @ K^T / sqrt(head_dim)
        let scale = (self.head_dim as f32).sqrt();
        let scores = q.dot(&k.t()) / scale;

        // Row-wise softmax
        let attention = Self::softmax(&scores);

        // Weighted sum: Attention @ V
        let output = attention.dot(&v);

        // Output projection: Output @ W_o
        let result = output.dot(&self.w_o);

        Ok(result)
    }

    /// Row-wise softmax, numerically stabilized by subtracting each row's
    /// maximum before exponentiating.
    fn softmax(x: &Array2<f32>) -> Array2<f32> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            for val in row.iter_mut() {
                *val = (*val - max).exp();
            }
            let sum: f32 = row.iter().sum();
            for val in row.iter_mut() {
                *val /= sum;
            }
        }
        result
    }
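
    /// Per-head attention, included as an illustrative sketch of the head
    /// splitting that the simplified `forward` above skips. This variant is
    /// not wired into the encoder, and the contiguous column-block-per-head
    /// layout is an assumption made for illustration.
    #[allow(dead_code)]
    fn forward_per_head(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        use ndarray::s;

        let q = x.dot(&self.w_q);
        let k = x.dot(&self.w_k);
        let v = x.dot(&self.w_v);
        let scale = (self.head_dim as f32).sqrt();

        // Attend within each head's column block, then concatenate.
        let mut concat = Array2::<f32>::zeros((x.nrows(), self.embed_dim));
        for h in 0..self.num_heads {
            let cols = h * self.head_dim..(h + 1) * self.head_dim;
            let scores =
                q.slice(s![.., cols.clone()]).dot(&k.slice(s![.., cols.clone()]).t()) / scale;
            let attn = Self::softmax(&scores);
            let out_h = attn.dot(&v.slice(s![.., cols.clone()]));
            concat.slice_mut(s![.., cols]).assign(&out_h);
        }

        // Final output projection, as in `forward`.
        Ok(concat.dot(&self.w_o))
    }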
}

// ============================================================================
// Feed-forward network
// ============================================================================

/// Position-wise feed-forward network (FFN).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedForward {
    /// Input dimension
    input_dim: usize,
    /// Hidden-layer dimension (typically 4 × input_dim)
    hidden_dim: usize,
    /// First-layer weights (input_dim × hidden_dim)
    w1: Array2<f32>,
    /// First-layer bias
    b1: Array1<f32>,
    /// Second-layer weights (hidden_dim × input_dim)
    w2: Array2<f32>,
    /// Second-layer bias
    b2: Array1<f32>,
}

impl FeedForward {
    /// Creates a new feed-forward network.
    ///
    /// # Arguments
    ///
    /// * `input_dim` - Input dimension
    /// * `hidden_dim` - Hidden-layer dimension
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        // Zero-initialized placeholders; trained parameters are expected to
        // be loaded before inference.
        let w1 = Array2::zeros((input_dim, hidden_dim));
        let b1 = Array1::zeros(hidden_dim);
        let w2 = Array2::zeros((hidden_dim, input_dim));
        let b2 = Array1::zeros(input_dim);

        Self {
            input_dim,
            hidden_dim,
            w1,
            b1,
            w2,
            b2,
        }
    }

    /// Forward pass: `GELU(X @ W1 + b1) @ W2 + b2`.
    ///
    /// # Arguments
    ///
    /// * `x` - Input tensor (seq_len × input_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × input_dim)
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_seq_len, input_dim) = x.dim();

        if input_dim != self.input_dim {
            return Err(Error::DimensionMismatch {
                expected: self.input_dim,
                actual: input_dim,
            });
        }

        // First layer: X @ W1 + b1 (bias added row by row)
        let mut hidden = x.dot(&self.w1);
        for mut row in hidden.rows_mut() {
            for (val, &bias) in row.iter_mut().zip(self.b1.iter()) {
                *val += bias;
            }
        }

        // GELU activation
        Self::gelu_inplace(&mut hidden);

        // Second layer: Hidden @ W2 + b2
        let mut output = hidden.dot(&self.w2);
        for mut row in output.rows_mut() {
            for (val, &bias) in row.iter_mut().zip(self.b2.iter()) {
                *val += bias;
            }
        }

        Ok(output)
    }

    /// Applies GELU element-wise, in place.
    fn gelu_inplace(x: &mut Array2<f32>) {
        for val in x.iter_mut() {
            *val = Self::gelu(*val);
        }
    }

    /// GELU activation, tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`,
    /// where sqrt(2/pi) ≈ 0.7978845608.
    fn gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x.powi(3))).tanh())
    }
}

// ============================================================================
// Layer normalization
// ============================================================================

/// Layer normalization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Normalized dimension
    dim: usize,
    /// Scale parameter (gamma)
    gamma: Array1<f32>,
    /// Shift parameter (beta)
    beta: Array1<f32>,
    /// Numerical-stability constant
    eps: f32,
}

impl LayerNorm {
    /// Creates a new layer-normalization module with `gamma = 1`,
    /// `beta = 0`, and `eps = 1e-5`.
    ///
    /// # Arguments
    ///
    /// * `dim` - Normalized dimension
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps: 1e-5,
        }
    }

    /// Forward pass: `y = gamma * (x - mean) / sqrt(var + eps) + beta`,
    /// with mean and variance computed per row.
    ///
    /// # Arguments
    ///
    /// * `x` - Input tensor (seq_len × dim)
    ///
    /// # Returns
    ///
    /// The normalized tensor
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_seq_len, dim) = x.dim();

        if dim != self.dim {
            return Err(Error::DimensionMismatch {
                expected: self.dim,
                actual: dim,
            });
        }

        let mut result = x.clone();

        for mut row in result.rows_mut() {
            // Per-row mean and (biased) variance
            let mean = row.mean().unwrap_or(0.0);
            let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / dim as f32;
            let std = (variance + self.eps).sqrt();

            // Normalize: (x - mean) / std
            for val in row.iter_mut() {
                *val = (*val - mean) / std;
            }

            // Scale and shift: gamma * x + beta
            for (val, (&gamma, &beta)) in
                row.iter_mut().zip(self.gamma.iter().zip(self.beta.iter()))
            {
                *val = gamma * *val + beta;
            }
        }

        Ok(result)
    }
}

// ============================================================================
// Transformer encoder layer
// ============================================================================

/// A single Transformer encoder layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerEncoderLayer {
    /// Multi-head attention
    attention: MultiHeadAttention,
    /// Feed-forward network
    ffn: FeedForward,
    /// First layer norm
    norm1: LayerNorm,
    /// Second layer norm
    norm2: LayerNorm,
}

impl TransformerEncoderLayer {
    /// Creates a new Transformer encoder layer.
    ///
    /// # Arguments
    ///
    /// * `num_heads` - Number of attention heads
    /// * `embed_dim` - Embedding dimension
    /// * `ffn_dim` - Hidden-layer dimension of the feed-forward network
    pub fn new(num_heads: usize, embed_dim: usize, ffn_dim: usize) -> Result<Self> {
        Ok(Self {
            attention: MultiHeadAttention::new(num_heads, embed_dim)?,
            ffn: FeedForward::new(embed_dim, ffn_dim),
            norm1: LayerNorm::new(embed_dim),
            norm2: LayerNorm::new(embed_dim),
        })
    }

    /// Forward pass (post-norm residual layout, as in the original
    /// Transformer): `x = LayerNorm(x + Attention(x))`, then
    /// `x = LayerNorm(x + FFN(x))`.
    ///
    /// # Arguments
    ///
    /// * `x` - Input tensor (seq_len × embed_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × embed_dim)
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        // Multi-head attention + residual connection + layer norm
        let attn_output = self.attention.forward(x)?;
        let x = x + &attn_output;
        let x = self.norm1.forward(&x)?;

        // Feed-forward network + residual connection + layer norm
        let ffn_output = self.ffn.forward(&x)?;
        let x = &x + &ffn_output;
        let x = self.norm2.forward(&x)?;

        Ok(x)
    }
}

// ============================================================================
// Transformer encoder
// ============================================================================

/// Transformer encoder (4 layers in this project's configuration).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerEncoder {
    /// Encoder layers
    layers: Vec<TransformerEncoderLayer>,
    /// Embedding dimension
    embed_dim: usize,
}

impl TransformerEncoder {
    /// Creates a new Transformer encoder.
    ///
    /// # Arguments
    ///
    /// * `num_layers` - Number of encoder layers
    /// * `num_heads` - Number of attention heads
    /// * `embed_dim` - Embedding dimension
    /// * `ffn_dim` - Hidden-layer dimension of the feed-forward network
    pub fn new(
        num_layers: usize,
        num_heads: usize,
        embed_dim: usize,
        ffn_dim: usize,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layers.push(TransformerEncoderLayer::new(num_heads, embed_dim, ffn_dim)?);
        }

        Ok(Self { layers, embed_dim })
    }

    /// Forward pass: applies each encoder layer in sequence.
    ///
    /// # Arguments
    ///
    /// * `x` - Input tensor (seq_len × embed_dim)
    ///
    /// # Returns
    ///
    /// Output tensor (seq_len × embed_dim)
    pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let mut output = x.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        Ok(output)
    }

    /// Returns the embedding dimension.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(ATTENTION_HEADS, EMBED_DIM).unwrap();
        let x = Array2::zeros((10, EMBED_DIM)); // seq_len = 10
        let output = mha.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_feed_forward() {
        let ffn = FeedForward::new(EMBED_DIM, EMBED_DIM * 4);
        let x = Array2::zeros((10, EMBED_DIM));
        let output = ffn.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(EMBED_DIM);
        let x = Array2::ones((10, EMBED_DIM));
        let output = ln.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
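        // Added check, assuming the default gamma = 1 and beta = 0: with a
        // constant input row the per-row variance is 0, so every normalized
        // value collapses to beta.
        assert!(output.iter().all(|&v| v.abs() < 1e-6));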
    }

    #[test]
    fn test_transformer_encoder_layer() {
        let layer = TransformerEncoderLayer::new(ATTENTION_HEADS, EMBED_DIM, EMBED_DIM * 4).unwrap();
        let x = Array2::zeros((10, EMBED_DIM));
        let output = layer.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(
            TRANSFORMER_LAYERS,
            ATTENTION_HEADS,
            EMBED_DIM,
            EMBED_DIM * 4,
        )
        .unwrap();
        let x = Array2::zeros((10, EMBED_DIM));
        let output = encoder.forward(&x).unwrap();
        assert_eq!(output.dim(), (10, EMBED_DIM));
    }

    #[test]
    fn test_dimension_mismatch() {
        let mha = MultiHeadAttention::new(ATTENTION_HEADS, EMBED_DIM).unwrap();
        let x = Array2::zeros((10, 128)); // wrong embedding dimension
        assert!(mha.forward(&x).is_err());
    }
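
    // Added sanity check for the private softmax helper: each output row
    // should sum to 1.
    #[test]
    fn test_softmax_rows_sum_to_one() {
        let x = Array2::from_shape_vec((2, 3), vec![1.0_f32, 2.0, 3.0, -1.0, 0.0, 1.0]).unwrap();
        let sm = MultiHeadAttention::softmax(&x);
        for row in sm.rows() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }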
}