212 lines
5.8 KiB
Rust
212 lines
5.8 KiB
Rust
//! 错误类型定义
|
|
//!
|
|
//! 本模块定义XTZH AI系统中所有可能的错误类型。
|
|
|
|
use thiserror::Error;
|
|
|
|
/// XTZH AI系统的错误类型
|
|
#[derive(Error, Debug)]
|
|
pub enum Error {
|
|
// ========================================================================
|
|
// 数据工程层错误
|
|
// ========================================================================
|
|
/// 数据源错误
|
|
#[error("数据源错误: {0}")]
|
|
DataSourceError(String),
|
|
|
|
/// 特征提取错误
|
|
#[error("特征提取错误: {0}")]
|
|
FeatureExtractionError(String),
|
|
|
|
/// 数据预处理错误
|
|
#[error("数据预处理错误: {0}")]
|
|
PreprocessingError(String),
|
|
|
|
/// 数据验证错误
|
|
#[error("数据验证错误: {0}")]
|
|
DataValidationError(String),
|
|
|
|
// ========================================================================
|
|
// AI模型层错误
|
|
// ========================================================================
|
|
/// 模型加载错误
|
|
#[error("模型加载错误: {0}")]
|
|
ModelLoadError(String),
|
|
|
|
/// 模型推理错误
|
|
#[error("模型推理错误: {0}")]
|
|
InferenceError(String),
|
|
|
|
/// 模型架构错误
|
|
#[error("模型架构错误: {0}")]
|
|
ArchitectureError(String),
|
|
|
|
/// 输入维度错误
|
|
#[error("输入维度错误: 期望 {expected}, 实际 {actual}")]
|
|
DimensionMismatch {
|
|
/// 期望的维度
|
|
expected: usize,
|
|
/// 实际的维度
|
|
actual: usize
|
|
},
|
|
|
|
// ========================================================================
|
|
// 训练系统错误
|
|
// ========================================================================
|
|
/// 训练错误
|
|
#[error("训练错误: {0}")]
|
|
TrainingError(String),
|
|
|
|
/// 验证错误
|
|
#[error("验证错误: {0}")]
|
|
ValidationError(String),
|
|
|
|
/// 优化器错误
|
|
#[error("优化器错误: {0}")]
|
|
OptimizerError(String),
|
|
|
|
// ========================================================================
|
|
// ONNX导出错误
|
|
// ========================================================================
|
|
/// ONNX导出错误
|
|
#[error("ONNX导出错误: {0}")]
|
|
OnnxExportError(String),
|
|
|
|
/// 量化错误
|
|
#[error("量化错误: {0}")]
|
|
QuantizationError(String),
|
|
|
|
/// 模型优化错误
|
|
#[error("模型优化错误: {0}")]
|
|
OptimizationError(String),
|
|
|
|
// ========================================================================
|
|
// 预言机层错误
|
|
// ========================================================================
|
|
/// 预言机节点错误
|
|
#[error("预言机节点错误: {0}")]
|
|
OracleNodeError(String),
|
|
|
|
/// 聚合器错误
|
|
#[error("聚合器错误: {0}")]
|
|
AggregatorError(String),
|
|
|
|
/// 证明生成错误
|
|
#[error("证明生成错误: {0}")]
|
|
ProofGenerationError(String),
|
|
|
|
/// 证明验证错误
|
|
#[error("证明验证错误: {0}")]
|
|
ProofVerificationError(String),
|
|
|
|
/// API错误
|
|
#[error("API错误: {0}")]
|
|
ApiError(String),
|
|
|
|
// ========================================================================
|
|
// 通用错误
|
|
// ========================================================================
|
|
/// IO错误
|
|
#[error("IO错误: {0}")]
|
|
IoError(#[from] std::io::Error),
|
|
|
|
/// 序列化错误
|
|
#[error("序列化错误: {0}")]
|
|
SerializationError(String),
|
|
|
|
/// 反序列化错误
|
|
#[error("反序列化错误: {0}")]
|
|
DeserializationError(String),
|
|
|
|
/// 配置错误
|
|
#[error("配置错误: {0}")]
|
|
ConfigError(String),
|
|
|
|
/// 网络错误
|
|
#[error("网络错误: {0}")]
|
|
NetworkError(String),
|
|
|
|
/// 超时错误
|
|
#[error("超时错误: {0}")]
|
|
TimeoutError(String),
|
|
|
|
/// 未实现错误
|
|
#[error("未实现: {0}")]
|
|
NotImplemented(String),
|
|
|
|
/// 内部错误
|
|
#[error("内部错误: {0}")]
|
|
InternalError(String),
|
|
}
|
|
|
|
/// XTZH AI系统的Result类型
|
|
pub type Result<T> = std::result::Result<T, Error>;
|
|
|
|
// ============================================================================
|
|
// 错误转换实现
|
|
// ============================================================================
|
|
|
|
impl From<serde_json::Error> for Error {
|
|
fn from(err: serde_json::Error) -> Self {
|
|
Error::SerializationError(err.to_string())
|
|
}
|
|
}
|
|
|
|
impl From<bincode::Error> for Error {
|
|
fn from(err: bincode::Error) -> Self {
|
|
Error::SerializationError(err.to_string())
|
|
}
|
|
}
|
|
|
|
impl From<reqwest::Error> for Error {
|
|
fn from(err: reqwest::Error) -> Self {
|
|
Error::NetworkError(err.to_string())
|
|
}
|
|
}
|
|
|
|
impl From<tokio::time::error::Elapsed> for Error {
|
|
fn from(err: tokio::time::error::Elapsed) -> Self {
|
|
Error::TimeoutError(err.to_string())
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// 测试
|
|
// ============================================================================
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_error_display() {
|
|
let err = Error::DataSourceError("测试错误".to_string());
|
|
assert_eq!(err.to_string(), "数据源错误: 测试错误");
|
|
}
|
|
|
|
#[test]
|
|
fn test_dimension_mismatch() {
|
|
let err = Error::DimensionMismatch {
|
|
expected: 50,
|
|
actual: 48,
|
|
};
|
|
assert_eq!(err.to_string(), "输入维度错误: 期望 50, 实际 48");
|
|
}
|
|
|
|
#[test]
|
|
fn test_io_error_conversion() {
|
|
let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "文件未找到");
|
|
let err: Error = io_err.into();
|
|
assert!(matches!(err, Error::IoError(_)));
|
|
}
|
|
|
|
#[test]
|
|
fn test_result_type() {
|
|
let result: Result<i32> = Ok(42);
|
|
assert_eq!(result.unwrap(), 42);
|
|
|
|
let result: Result<i32> = Err(Error::InternalError("测试".to_string()));
|
|
assert!(result.is_err());
|
|
}
|
|
}
|