/// 测试辅助函数模块 /// /// 提供常用的测试辅助功能 use std::time::Duration; use tokio::time::{sleep, timeout}; /// 等待条件满足 /// /// # Arguments /// * `condition` - 条件检查函数 /// * `timeout_secs` - 超时时间(秒) /// * `check_interval_ms` - 检查间隔(毫秒) /// /// # Returns /// * `Ok(())` - 条件满足 /// * `Err(String)` - 超时 pub async fn wait_for_condition( mut condition: F, timeout_secs: u64, check_interval_ms: u64, ) -> Result<(), String> where F: FnMut() -> bool, { let timeout_duration = Duration::from_secs(timeout_secs); let check_interval = Duration::from_millis(check_interval_ms); let result = timeout(timeout_duration, async { while !condition() { sleep(check_interval).await; } }) .await; match result { Ok(_) => Ok(()), Err(_) => Err(format!("Timeout after {} seconds", timeout_secs)), } } /// 重试执行函数直到成功 /// /// # Arguments /// * `f` - 要执行的函数 /// * `max_retries` - 最大重试次数 /// * `retry_interval_ms` - 重试间隔(毫秒) /// /// # Returns /// * `Ok(T)` - 执行成功的结果 /// * `Err(String)` - 达到最大重试次数 pub async fn retry_until_success( mut f: F, max_retries: usize, retry_interval_ms: u64, ) -> Result where F: FnMut() -> Result, E: std::fmt::Display, { let retry_interval = Duration::from_millis(retry_interval_ms); for attempt in 0..max_retries { match f() { Ok(result) => return Ok(result), Err(e) => { if attempt == max_retries - 1 { return Err(format!( "Failed after {} attempts. Last error: {}", max_retries, e )); } log::debug!("Attempt {} failed: {}. Retrying...", attempt + 1, e); sleep(retry_interval).await; } } } unreachable!() } /// 并发执行多个任务 /// /// # Arguments /// * `tasks` - 任务列表 /// /// # Returns /// * 所有任务的结果 pub async fn run_concurrent(tasks: Vec) -> Vec where F: std::future::Future + Send + 'static, T: Send + 'static, { let handles: Vec<_> = tasks .into_iter() .map(|task| tokio::spawn(task)) .collect(); let mut results = Vec::new(); for handle in handles { if let Ok(result) = handle.await { results.push(result); } } results } /// 生成随机测试数据 pub mod random { use rand::Rng; /// 生成随机字节数组 pub fn random_bytes() -> [u8; N] { let mut rng = rand::thread_rng(); let mut bytes = [0u8; N]; for byte in &mut bytes { *byte = rng.gen(); } bytes } /// 生成随机u64 pub fn random_u64() -> u64 { rand::thread_rng().gen() } /// 生成指定范围内的随机u64 pub fn random_u64_range(min: u64, max: u64) -> u64 { rand::thread_rng().gen_range(min..=max) } /// 生成随机字符串 pub fn random_string(len: usize) -> String { use rand::distributions::Alphanumeric; rand::thread_rng() .sample_iter(&Alphanumeric) .take(len) .map(char::from) .collect() } } /// 性能测量工具 pub mod perf { use std::time::Instant; /// 测量函数执行时间 pub fn measure_time(f: F) -> (T, std::time::Duration) where F: FnOnce() -> T, { let start = Instant::now(); let result = f(); let duration = start.elapsed(); (result, duration) } /// 测量异步函数执行时间 pub async fn measure_time_async(f: F) -> (T, std::time::Duration) where F: std::future::Future, { let start = Instant::now(); let result = f.await; let duration = start.elapsed(); (result, duration) } /// 计算TPS pub fn calculate_tps(tx_count: usize, duration: std::time::Duration) -> f64 { tx_count as f64 / duration.as_secs_f64() } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_wait_for_condition_success() { let mut counter = 0; let result = wait_for_condition( || { counter += 1; counter >= 5 }, 5, 10, ) .await; assert!(result.is_ok()); assert!(counter >= 5); } #[tokio::test] async fn test_wait_for_condition_timeout() { let result = wait_for_condition(|| false, 1, 10).await; assert!(result.is_err()); } #[tokio::test] async fn test_retry_until_success() { let mut counter = 0; let result = retry_until_success( || { counter += 1; if counter >= 3 { Ok(counter) } else { Err("Not ready") } }, 5, 10, ) .await; assert!(result.is_ok()); assert_eq!(result.unwrap(), 3); } #[tokio::test] async fn test_retry_until_failure() { let result = retry_until_success(|| Err::<(), _>("Always fail"), 3, 10).await; assert!(result.is_err()); } // 注意:run_concurrent测试被禁用,因为Rust的impl Trait限制 // 实际使用中可以通过其他方式处理并发任务 // #[tokio::test] // async fn test_run_concurrent() { // // 测试代码 // } #[test] fn test_random_bytes() { let bytes1 = random::random_bytes::<32>(); let bytes2 = random::random_bytes::<32>(); // 随机生成的字节应该不同 assert_ne!(bytes1, bytes2); } #[test] fn test_random_u64() { let num1 = random::random_u64(); let num2 = random::random_u64(); // 随机生成的数字应该不同(概率极高) assert_ne!(num1, num2); } #[test] fn test_random_u64_range() { for _ in 0..100 { let num = random::random_u64_range(10, 20); assert!(num >= 10 && num <= 20); } } #[test] fn test_random_string() { let s1 = random::random_string(10); let s2 = random::random_string(10); assert_eq!(s1.len(), 10); assert_eq!(s2.len(), 10); assert_ne!(s1, s2); } #[test] fn test_measure_time() { let (result, duration) = perf::measure_time(|| { std::thread::sleep(Duration::from_millis(100)); 42 }); assert_eq!(result, 42); assert!(duration.as_millis() >= 100); } #[tokio::test] async fn test_measure_time_async() { let (result, duration) = perf::measure_time_async(async { tokio::time::sleep(Duration::from_millis(100)).await; 42 }) .await; assert_eq!(result, 42); assert!(duration.as_millis() >= 100); } #[test] fn test_calculate_tps() { let tps = perf::calculate_tps(1000, Duration::from_secs(1)); assert_eq!(tps, 1000.0); let tps = perf::calculate_tps(5000, Duration::from_millis(500)); assert_eq!(tps, 10000.0); } }