62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
#!/usr/bin/env python3
|
|
import os
|
|
import re
|
|
|
|
# Matches any existing `fn initialize` definition (with or without parameters
# or generics) but not longer names such as `fn initialize_state`.
_EXISTING_INITIALIZE_RE = re.compile(r'\bfn\s+initialize\b')

# Group 1: the whole `pub fn hash(data: &[u8]) -> Hash { ... }` method,
#          including its leading indentation (its body must contain no `}`).
# Group 2: the `}` that closes the enclosing impl block on the next line.
# Group 3: the `impl Default` that follows after one blank line.
_HASH_THEN_DEFAULT_RE = re.compile(
    r'^([ \t]*pub fn hash\(data: &\[u8\]\) -> Hash \{[^}]+\})\n(\})\n\n(impl Default)',
    re.MULTILINE,
)

# Rust source inserted right after the `hash` method. It starts with a blank
# line so the new method is separated from `hash`. Uses 4-space indentation;
# assumes target files use the same (rustfmt default) style.
_INITIALIZE_FN = '''

    /// 初始化协议实例(带错误处理)
    ///
    /// # Returns
    ///
    /// 返回Result包装的协议实例
    pub fn initialize() -> Result<Self> {
        Ok(Self::new())
    }'''


def add_initialize_function(file_path):
    """Add an ``initialize()`` constructor to a Rust protocol source file.

    The method is inserted at the end of every ``impl`` block whose last
    method is ``pub fn hash(data: &[u8]) -> Hash { ... }`` and whose closing
    brace is followed by a blank line and an ``impl Default`` block.

    Args:
        file_path: Path of the ``.rs`` file to patch in place. It is read and
            written as UTF-8.

    Returns:
        True if the file was rewritten. False if it already defines an
        ``initialize`` function or the expected layout was not found; in
        that case the file is left untouched.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Skip files that already define any `fn initialize...`. Adding a second
    # method with the same name would clash with the existing one.
    if _EXISTING_INITIALIZE_RE.search(content):
        return False

    def _insert(match):
        # Build the replacement literally instead of using a `\1...` template,
        # so the inserted Rust text never goes through escape processing.
        return (match.group(1) + _INITIALIZE_FN + '\n'
                + match.group(2) + '\n\n' + match.group(3))

    new_content = _HASH_THEN_DEFAULT_RE.sub(_insert, content)
    if new_content == content:
        return False

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(new_content)
    return True
|
|
|
|
def process_all_protocols(base_dir):
    """Run ``add_initialize_function`` over every protocol source file.

    Looks in ``<base_dir>/layer0`` through ``<base_dir>/layer9`` and skips
    layer paths that are not directories. Within each layer, every regular
    ``*.rs`` file except ``mod.rs`` is processed in sorted name order, so the
    log output is reproducible.

    Args:
        base_dir: Root directory that contains the ``layerN`` subdirectories.

    Returns:
        Number of files that were modified.
    """
    modified_count = 0

    for layer in range(10):  # layer0 .. layer9
        layer_dir = os.path.join(base_dir, f'layer{layer}')
        # isdir rather than exists: a regular file named `layerN` would make
        # os.listdir raise.
        if not os.path.isdir(layer_dir):
            continue

        for file_name in sorted(os.listdir(layer_dir)):
            if not file_name.endswith('.rs') or file_name == 'mod.rs':
                continue
            file_path = os.path.join(layer_dir, file_name)
            # Skip non-files (e.g. a directory named `x.rs`), which open()
            # cannot read.
            if not os.path.isfile(file_path):
                continue
            if add_initialize_function(file_path):
                modified_count += 1
                print(f'✅ 修复: {file_path}')

    return modified_count
|
|
|
|
if __name__ == '__main__':
    import sys

    # The protocols root defaults to the original workspace path. Pass a
    # different directory as the first command-line argument to override it.
    default_base_dir = '/home/ubuntu/NAC_Clean_Dev/nac-protocols/src'
    base_dir = sys.argv[1] if len(sys.argv) > 1 else default_base_dir

    count = process_all_protocols(base_dir)

    print(f'\n总共修复了 {count} 个协议文件')
|