simple_merkle_tree/
lib.rs1use std::fmt::Debug;
42
43pub use tiny_keccak::Hasher;
44
/// Raw element type accepted by [`MerkleTree::new`] and [`MerkleTree::get_proof`].
pub type Buffer = Vec<u8>;
/// A 32-byte Keccak-256 digest, the type of every node in the tree.
type Hash = [u8; 32];

/// Merkle tree over a set of byte-string elements, hashed with Keccak-256.
///
/// All node hashes are stored flat in `hashed_elements`: index `0` is the
/// root and the children of node `i` are at `2i + 1` and `2i + 2`.
/// `Clone`/`PartialEq`/`Eq` are derived so trees can be copied and compared
/// (two trees are equal exactly when every stored node hash is equal).
#[derive(Clone, PartialEq, Eq)]
pub struct MerkleTree {
    /// Node hashes in heap order (root first, leaves last).
    hashed_elements: Vec<Hash>,
}
52
53impl MerkleTree {
55 pub fn new(elements: Vec<Buffer>) -> Self {
56 let elements = {
57 let mut elements: Vec<Buffer> = elements
58 .into_iter()
59 .filter(|e| !e.iter().all(|e| *e == 0))
61 .collect();
62
63 elements.sort();
65
66 let el_len = elements.len();
68 let elements = elements
69 .into_iter()
70 .fold(Vec::with_capacity(el_len), |mut acc, i| {
71 if !acc.contains(&i) {
72 acc.push(i);
73 }
74 acc
75 });
76 elements
77 };
78
79 let el_len = elements.len();
81 let (capacity, levels) = MerkleTree::calculate_levels(&el_len);
82
83 let vector_size = 2 * el_len - 1;
84 let mut result = vec![[0; 32]; vector_size];
85 log::debug!("Creating a vector with size {:}", vector_size);
86
87 let mut prior_elements = 0;
88 for level in 1..=levels {
89 let elem_count_in_level = el_len / level as usize;
90 let start_index = capacity - prior_elements - elem_count_in_level;
91
92 let end_index = start_index + elem_count_in_level; prior_elements += elem_count_in_level;
94 log::trace!(
95 "start_index: {}| end_index {}| elem_count_in_level {}",
96 start_index,
97 end_index,
98 elem_count_in_level
99 );
100
101 if level == 1 {
102 for (idx, elem) in elements.iter().enumerate() {
103 let hashed = MerkleTree::hash(&elem);
104 log::trace!(
105 "Setting idx {:} to {:}",
106 start_index + idx,
107 hex::encode(hashed)
108 );
109 result[start_index + idx] = hashed;
110 }
111 } else {
112 for idx in start_index..end_index {
113 let left = (2_usize * idx) + 1;
114 let right = (2_usize * idx) + 2;
115
116 log::trace!("Getting child of {}| L: {}| R: {}", idx, left, right);
117 let left = result[left];
118 let right = result[right];
119 let parent = MerkleTree::combined_hash(&left, &right);
120 result[idx] = parent;
122 }
123 }
124 }
125
126 let res = Self {
127 hashed_elements: result,
128 };
129 log::debug!("Constructed merkle tree {:#?}", &res);
130 res
131 }
132
133
134 pub fn get_root(&self) -> &[u8; 32] {
135 &self.hashed_elements[0]
136 }
137
138 pub fn get_proof(&self, el: &Buffer) -> Option<Vec<&[u8; 32]>> {
139 let hashed = MerkleTree::hash(&el);
140 log::debug!("Finding proof for {:}", hex::encode(hashed));
141
142 let index = self.hashed_elements.iter().position(|e| e == &hashed);
143
144 match index {
145 Some(mut index) => {
146 let mut res = vec![];
147
148 while index > 0 {
149 let sibling = self.get_pair_element(index);
151
152 if let Some(sibling) = sibling {
153 log::trace!(
154 "getting pair elem for index {:}; res {:}",
155 index,
156 hex::encode(sibling)
157 );
158 res.push(sibling);
159 }
160
161 index = MerkleTree::calculate_parent_idx(index);
162 log::trace!("Parent {:}", index);
163 }
164 Some(res)
165 }
166 None => None,
167 }
168 }
169
170}
171
172impl MerkleTree {
174
175 pub fn combined_hash(first: &[u8], second: &[u8]) -> [u8; 32] {
189 let mut keccak = tiny_keccak::Keccak::v256();
190 keccak.update(first);
191 keccak.update(second);
192 let mut result: [u8; 32] = Default::default();
193 keccak.finalize(&mut result);
194 result
195 }
196
197 pub fn hash(data: &[u8]) -> [u8; 32] {
209 let mut keccak = tiny_keccak::Keccak::v256();
210 keccak.update(&data);
211 let mut result: [u8; 32] = Default::default();
212 keccak.finalize(&mut result);
213 result
214 }
215
216 fn get_pair_element(&self, idx: usize) -> Option<&[u8; 32]> {
217 let pair_idx = MerkleTree::calculate_sibling_idx(idx);
218
219 if pair_idx < self.hashed_elements.len() {
220 return Some(&self.hashed_elements[pair_idx]);
221 }
222 return None;
223 }
224
225 fn calculate_sibling_idx(idx: usize) -> usize {
226 if idx % 2 == 0 {
227 idx - 1
228 } else {
229 idx + 1
230 }
231 }
232
233 fn calculate_parent_idx(child_idx: usize) -> usize {
234 let child_offset = {
235 if child_idx % 2 == 0 {
236 2
238 } else {
239 1
241 }
242 };
243
244 (child_idx - child_offset) / 2
245 }
246
247
248 fn calculate_levels(el_len: &usize) -> (usize, u32) {
249 let capacity = 2 * el_len - 1;
250 let levels: u32 = ((capacity as f32).log2() + 1.) as u32;
251 (capacity, levels)
252 }
253}
254
255impl Debug for MerkleTree {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 let hashed_elements: Vec<String> = self
258 .hashed_elements
259 .iter()
260 .map(|e| hex::encode(e))
261 .collect();
262
263 f.debug_struct("MerkleTree")
264 .field("hashed_elements", &hashed_elements)
265 .finish()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `items` distinct sample elements: `item-string-0`, `item-string-1`, ...
    fn generate_sample_vec(items: u32) -> Vec<Vec<u8>> {
        let mut samples = Vec::new();
        for n in 0..items {
            samples.push(format!("item-string-{:}", n).into_bytes());
        }
        samples
    }

    /// Hex-encodes every hash of a proof so proofs compare as readable strings.
    fn to_hex_list(proof: &[&[u8; 32]]) -> Vec<String> {
        proof.iter().map(hex::encode).collect()
    }

    /// `(capacity, levels)` for a tree over `count` sample elements.
    fn levels_for(count: u32) -> (usize, u32) {
        let len = generate_sample_vec(count).len();
        MerkleTree::calculate_levels(&len)
    }

    #[test]
    fn construct_tree() {
        simple_logger::init_with_level(log::Level::Trace).unwrap();
        let elements = generate_sample_vec(4);
        let tree = MerkleTree::new(elements.clone());

        // The sample elements are already in sorted order, so they map to leaves a..d.
        let (a, b, c, d) = (&elements[0], &elements[1], &elements[2], &elements[3]);
        let (h_a, h_b, h_c, h_d) = (
            MerkleTree::hash(a),
            MerkleTree::hash(b),
            MerkleTree::hash(c),
            MerkleTree::hash(d),
        );
        let h_ab = MerkleTree::combined_hash(&h_a, &h_b);
        let h_cd = MerkleTree::combined_hash(&h_c, &h_d);
        let h_abcd = MerkleTree::combined_hash(&h_ab, &h_cd);

        log::debug!("h_abcd = {:}", hex::encode(h_abcd));
        log::debug!("h_ab = {:}", hex::encode(h_ab));
        log::debug!("h_cd = {:}", hex::encode(h_cd));
        log::debug!("h_a = {:}", hex::encode(h_a));
        log::debug!("h_b = {:}", hex::encode(h_b));
        log::debug!("h_c = {:}", hex::encode(h_c));
        log::debug!("h_d = {:}", hex::encode(h_d));

        // Proof for the right-most leaf: its sibling, then the left subtree.
        let proof_d = tree.get_proof(d).unwrap();
        assert_eq!(proof_d.len(), 2);
        assert_eq!(
            vec![hex::encode(h_c), hex::encode(h_ab)],
            to_hex_list(&proof_d)
        );

        // Proof for the left-most leaf: its sibling, then the right subtree.
        let proof_a = tree.get_proof(a).unwrap();
        assert_eq!(proof_a.len(), 2);
        assert_eq!(
            vec![hex::encode(h_b), hex::encode(h_cd)],
            to_hex_list(&proof_a)
        );

        assert_eq!(hex::encode(h_abcd), hex::encode(tree.get_root()));
    }

    #[test]
    fn levels_get_calculated() {
        assert_eq!(levels_for(4), (7, 3));
    }

    #[test]
    fn levels_get_calculated_v2() {
        assert_eq!(levels_for(3), (5, 3));
    }

    #[test]
    fn levels_get_calculated_v3() {
        assert_eq!(levels_for(2), (3, 2));
    }
}