NAC_Blockchain/nvm_v2/nvm-l0/src/merkle.rs

133 lines
4.7 KiB
Rust

// NVM-L0 Merkle tree implementation
use crate::types::Hash;
use serde::{Deserialize, Serialize};
/// A node of the in-memory Merkle tree.
///
/// NOTE(review): this type derives serde; variant order and field names are part of
/// the derived encoding, so reordering or renaming them can change serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MerkleNode {
/// Leaf node. When built by `MerkleTree::new`, `hash` is `Hash::sha3_384(&data)` and
/// `data` is the original input item.
Leaf { hash: Hash, data: Vec<u8> },
/// Inner node. When built by `MerkleTree::build_tree`, `hash` is
/// `Hash::sha3_384(left.hash() || right.hash())` (plain byte concatenation).
Internal { hash: Hash, left: Box<MerkleNode>, right: Box<MerkleNode> },
/// Root of a tree built from no data; `MerkleNode::hash` reports `Hash::zero()` for it.
Empty,
}
impl MerkleNode {
    /// Returns the digest cached in this node.
    ///
    /// Leaf and inner nodes carry their own precomputed digest; the `Empty`
    /// placeholder has none and reports the all-zero hash instead.
    fn hash(&self) -> Hash {
        match self {
            MerkleNode::Leaf { hash, .. } | MerkleNode::Internal { hash, .. } => *hash,
            MerkleNode::Empty => Hash::zero(),
        }
    }
}
/// A binary Merkle tree whose leaves are SHA3-384 hashes of the input items.
///
/// NOTE(review): this type derives serde; field order and names are part of the
/// derived encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleTree {
/// Root node; `MerkleNode::Empty` when the tree was built from no data.
root: MerkleNode,
/// Number of input items the tree was built from (padding duplicates are not counted).
leaf_count: usize,
}
impl MerkleTree {
    /// Builds a tree with one leaf per item of `data`, in input order.
    ///
    /// Each leaf stores its item and the item's `Hash::sha3_384` digest. An empty
    /// `data` yields a tree with an `Empty` root whose hash is `Hash::zero()`.
    pub fn new(data: Vec<Vec<u8>>) -> Self {
        let leaves: Vec<MerkleNode> = data
            .into_iter()
            .map(|d| {
                let hash = Hash::sha3_384(&d);
                MerkleNode::Leaf { hash, data: d }
            })
            .collect();
        let leaf_count = leaves.len();
        let root = Self::build_tree(leaves);
        Self { root, leaf_count }
    }

    /// Combines one level of nodes into the level above until a single root remains.
    ///
    /// Nodes are paired left to right. A level with an odd number of nodes has its
    /// last node duplicated so it pairs with itself. Each parent hash is
    /// `sha3_384(left.hash() || right.hash())`. Every leaf therefore ends up at the
    /// same depth, and the tree is a perfect binary tree over
    /// `leaf_count.next_power_of_two()` positions.
    ///
    /// NOTE(review): because of the duplication, inputs `[a, b, c]` and
    /// `[a, b, c, c]` produce the same root. Leaves and inner nodes also share one
    /// hash function without a domain-separation prefix. Changing either would
    /// change every root hash, so it is only flagged here.
    fn build_tree(mut nodes: Vec<MerkleNode>) -> MerkleNode {
        while nodes.len() > 1 {
            if nodes.len() % 2 == 1 {
                if let Some(last) = nodes.last().cloned() {
                    nodes.push(last);
                }
            }
            let mut parents = Vec::with_capacity(nodes.len() / 2);
            // Move the children out of the current level instead of cloning them.
            // Cloning would re-copy every subtree, including leaf payloads, per level.
            let mut level = nodes.into_iter();
            while let (Some(left), Some(right)) = (level.next(), level.next()) {
                let mut combined = Vec::new();
                combined.extend_from_slice(left.hash().as_bytes());
                combined.extend_from_slice(right.hash().as_bytes());
                let hash = Hash::sha3_384(&combined);
                parents.push(MerkleNode::Internal {
                    hash,
                    left: Box::new(left),
                    right: Box::new(right),
                });
            }
            nodes = parents;
        }
        nodes.pop().unwrap_or(MerkleNode::Empty)
    }

    /// Returns the root digest, or `Hash::zero()` for a tree built from no data.
    pub fn root_hash(&self) -> Hash {
        self.root.hash()
    }

    /// Returns the number of input items the tree was built from.
    pub fn leaf_count(&self) -> usize {
        self.leaf_count
    }

    /// Returns an inclusion proof for the leaf at `index`, or `None` when
    /// `index >= leaf_count()`.
    ///
    /// Sibling hashes are ordered from the leaf's own sibling up to the root's
    /// child, which is the order `MerkleProof::verify` folds them in.
    /// `proof_directions[i]` is `true` when the `i`-th sibling is the left operand.
    /// A single-leaf tree yields an empty proof.
    pub fn generate_proof(&self, index: usize) -> Option<MerkleProof> {
        if index >= self.leaf_count {
            return None;
        }
        let mut proof_hashes = Vec::new();
        let mut proof_directions = Vec::new();
        // `build_tree` pads odd levels by duplication, so each left child spans
        // exactly half of its parent's power-of-two range of leaf positions.
        // Splitting at `count / 2` would pick the wrong branch for
        // non-power-of-two leaf counts.
        let mut span = self.leaf_count.next_power_of_two();
        let mut start = 0;
        let mut node = &self.root;
        while let MerkleNode::Internal { left, right, .. } = node {
            span /= 2;
            let mid = start + span;
            if index < mid {
                proof_hashes.push(right.hash());
                proof_directions.push(false);
                node = left;
            } else {
                proof_hashes.push(left.hash());
                proof_directions.push(true);
                node = right;
                start = mid;
            }
        }
        // The walk collects siblings root-first; verification folds leaf-first.
        proof_hashes.reverse();
        proof_directions.reverse();
        Some(MerkleProof { leaf_index: index, proof_hashes, proof_directions })
    }
}
/// An inclusion proof for one leaf of a `MerkleTree`, checked with `MerkleProof::verify`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleProof {
/// Position of the proven leaf among the tree's input items. `verify` does not read it.
pub leaf_index: usize,
/// Sibling hashes. `verify` folds them in order starting from the leaf hash, so
/// entry 0 must be the leaf's own sibling.
pub proof_hashes: Vec<Hash>,
/// One flag per entry of `proof_hashes`: `true` when that sibling is the left
/// operand of the concatenation, `false` when it is the right operand.
pub proof_directions: Vec<bool>,
}
impl MerkleProof {
pub fn verify(&self, leaf_hash: Hash, root_hash: Hash) -> bool {
let mut current_hash = leaf_hash;
for (sibling_hash, is_left) in self.proof_hashes.iter().zip(self.proof_directions.iter()) {
let mut combined = Vec::new();
if *is_left {
combined.extend_from_slice(sibling_hash.as_bytes());
combined.extend_from_slice(current_hash.as_bytes());
} else {
combined.extend_from_slice(current_hash.as_bytes());
combined.extend_from_slice(sibling_hash.as_bytes());
}
current_hash = Hash::sha3_384(&combined);
}
current_hash == root_hash
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parent digest as the tree builds it: SHA3-384 over `left || right`.
    fn hash_pair(left: Hash, right: Hash) -> Hash {
        let mut combined = Vec::new();
        combined.extend_from_slice(left.as_bytes());
        combined.extend_from_slice(right.as_bytes());
        Hash::sha3_384(&combined)
    }

    #[test]
    fn test_merkle_tree() {
        let data = vec![b"data1".to_vec(), b"data2".to_vec()];
        let tree = MerkleTree::new(data.clone());
        assert_eq!(tree.leaf_count(), 2);
        let proof = tree.generate_proof(0).unwrap();
        let leaf_hash = Hash::sha3_384(&data[0]);
        assert!(proof.verify(leaf_hash, tree.root_hash()));
    }

    #[test]
    fn test_two_leaf_right_proof() {
        let data = vec![b"data1".to_vec(), b"data2".to_vec()];
        let tree = MerkleTree::new(data.clone());
        let proof = tree.generate_proof(1).unwrap();
        assert_eq!(proof.proof_directions, vec![true]);
        assert!(proof.verify(Hash::sha3_384(&data[1]), tree.root_hash()));
        assert!(!proof.verify(Hash::sha3_384(&data[0]), tree.root_hash()));
    }

    #[test]
    fn test_empty_tree() {
        let tree = MerkleTree::new(Vec::new());
        assert_eq!(tree.leaf_count(), 0);
        assert_eq!(tree.root_hash(), Hash::zero());
        assert!(tree.generate_proof(0).is_none());
    }

    #[test]
    fn test_single_leaf() {
        let data = vec![b"only".to_vec()];
        let tree = MerkleTree::new(data.clone());
        let leaf_hash = Hash::sha3_384(&data[0]);
        assert_eq!(tree.root_hash(), leaf_hash);
        let proof = tree.generate_proof(0).unwrap();
        assert!(proof.proof_hashes.is_empty());
        assert!(proof.verify(leaf_hash, tree.root_hash()));
        assert!(tree.generate_proof(1).is_none());
    }

    #[test]
    fn test_odd_leaf_count_duplicates_last_leaf() {
        let data = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec()];
        let tree = MerkleTree::new(data.clone());
        assert_eq!(tree.leaf_count(), 3);
        let h: Vec<Hash> = data.iter().map(|d| Hash::sha3_384(d)).collect();
        let expected = hash_pair(hash_pair(h[0], h[1]), hash_pair(h[2], h[2]));
        assert_eq!(tree.root_hash(), expected);
    }
}