NAC_Blockchain/xtzh-ai/src/error.rs

212 lines
5.8 KiB
Rust

//! 错误类型定义
//!
//! 本模块定义XTZH AI系统中所有可能的错误类型。
use thiserror::Error;
/// XTZH AI系统的错误类型
#[derive(Error, Debug)]
pub enum Error {
// ========================================================================
// 数据工程层错误
// ========================================================================
/// 数据源错误
#[error("数据源错误: {0}")]
DataSourceError(String),
/// 特征提取错误
#[error("特征提取错误: {0}")]
FeatureExtractionError(String),
/// 数据预处理错误
#[error("数据预处理错误: {0}")]
PreprocessingError(String),
/// 数据验证错误
#[error("数据验证错误: {0}")]
DataValidationError(String),
// ========================================================================
// AI模型层错误
// ========================================================================
/// 模型加载错误
#[error("模型加载错误: {0}")]
ModelLoadError(String),
/// 模型推理错误
#[error("模型推理错误: {0}")]
InferenceError(String),
/// 模型架构错误
#[error("模型架构错误: {0}")]
ArchitectureError(String),
/// 输入维度错误
#[error("输入维度错误: 期望 {expected}, 实际 {actual}")]
DimensionMismatch {
/// 期望的维度
expected: usize,
/// 实际的维度
actual: usize
},
// ========================================================================
// 训练系统错误
// ========================================================================
/// 训练错误
#[error("训练错误: {0}")]
TrainingError(String),
/// 验证错误
#[error("验证错误: {0}")]
ValidationError(String),
/// 优化器错误
#[error("优化器错误: {0}")]
OptimizerError(String),
// ========================================================================
// ONNX导出错误
// ========================================================================
/// ONNX导出错误
#[error("ONNX导出错误: {0}")]
OnnxExportError(String),
/// 量化错误
#[error("量化错误: {0}")]
QuantizationError(String),
/// 模型优化错误
#[error("模型优化错误: {0}")]
OptimizationError(String),
// ========================================================================
// 预言机层错误
// ========================================================================
/// 预言机节点错误
#[error("预言机节点错误: {0}")]
OracleNodeError(String),
/// 聚合器错误
#[error("聚合器错误: {0}")]
AggregatorError(String),
/// 证明生成错误
#[error("证明生成错误: {0}")]
ProofGenerationError(String),
/// 证明验证错误
#[error("证明验证错误: {0}")]
ProofVerificationError(String),
/// API错误
#[error("API错误: {0}")]
ApiError(String),
// ========================================================================
// 通用错误
// ========================================================================
/// IO错误
#[error("IO错误: {0}")]
IoError(#[from] std::io::Error),
/// 序列化错误
#[error("序列化错误: {0}")]
SerializationError(String),
/// 反序列化错误
#[error("反序列化错误: {0}")]
DeserializationError(String),
/// 配置错误
#[error("配置错误: {0}")]
ConfigError(String),
/// 网络错误
#[error("网络错误: {0}")]
NetworkError(String),
/// 超时错误
#[error("超时错误: {0}")]
TimeoutError(String),
/// 未实现错误
#[error("未实现: {0}")]
NotImplemented(String),
/// 内部错误
#[error("内部错误: {0}")]
InternalError(String),
}
/// XTZH AI系统的Result类型
pub type Result<T> = std::result::Result<T, Error>;
// ============================================================================
// 错误转换实现
// ============================================================================
impl From<serde_json::Error> for Error {
fn from(err: serde_json::Error) -> Self {
Error::SerializationError(err.to_string())
}
}
impl From<bincode::Error> for Error {
fn from(err: bincode::Error) -> Self {
Error::SerializationError(err.to_string())
}
}
impl From<reqwest::Error> for Error {
fn from(err: reqwest::Error) -> Self {
Error::NetworkError(err.to_string())
}
}
impl From<tokio::time::error::Elapsed> for Error {
fn from(err: tokio::time::error::Elapsed) -> Self {
Error::TimeoutError(err.to_string())
}
}
// ============================================================================
// 测试
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_display() {
let err = Error::DataSourceError("测试错误".to_string());
assert_eq!(err.to_string(), "数据源错误: 测试错误");
}
#[test]
fn test_dimension_mismatch() {
let err = Error::DimensionMismatch {
expected: 50,
actual: 48,
};
assert_eq!(err.to_string(), "输入维度错误: 期望 50, 实际 48");
}
#[test]
fn test_io_error_conversion() {
let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "文件未找到");
let err: Error = io_err.into();
assert!(matches!(err, Error::IoError(_)));
}
#[test]
fn test_result_type() {
let result: Result<i32> = Ok(42);
assert_eq!(result.unwrap(), 42);
let result: Result<i32> = Err(Error::InternalError("测试".to_string()));
assert!(result.is_err());
}
}