tiny_sparse_merkle_tree/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5#[cfg(any(test, feature = "keccak"))]
6pub mod hash;
7#[cfg(test)]
8mod tests;
9
10// --- core ---
11use core::fmt::Debug;
12// --- alloc ---
13use alloc::vec::Vec;
14
/// Hash function used to combine two sibling nodes into their parent node.
///
/// The tree always passes the left child first and the right child second, so
/// non-commutative hashes are supported.
pub trait Merge {
	/// Node (hash) type consumed and produced by [`Merge::merge`].
	type Item;

	/// Combines the `left` and `right` children into their parent.
	fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item;
}
20
/// A binary Merkle tree stored in a single flat array.
///
/// Node `1` is the root, the children of node `i` are `2 * i` and `2 * i + 1`,
/// and index `0` is an unused placeholder. The leaves occupy the upper half of
/// the array, padded with `H::default()` up to a power-of-two width.
///
/// > Assume the hash algorithm is `a + b`.
///
/// ## Tree
/// ```text
///         10
///       /    \
///      3      7
///     / \    / \
///    1   2  3   4
/// ```
///
/// ## Layout of `nodes`
/// ```text
/// index  0  1  2  3  4  5  6  7
/// value  0 10  3  7  1  2  3  4
/// ```
///
/// ## Merge steps
/// ```text
/// [0,0,0,0,1,2,3,4]
/// [0,0,0,3+4,1,2,3,4]
/// [0,0,1+2,3+4,1,2,3,4]
/// [0,1+2+3+4,1+2,3+4,1,2,3,4]
/// ```
#[cfg_attr(all(feature = "debug", not(test)), derive(Debug))]
pub struct SparseMerkleTree<H> {
	/// Flat node array: `[unused, root, inner nodes..., leaves...]`.
	pub nodes: Vec<H>,
	/// Number of non-padding leaves; `proof_of` rejects leaf indices at or
	/// above this value.
	pub non_empty_leaves_count: u32,
}
43impl<H> SparseMerkleTree<H>
44where
45	H: Clone + Debug + Default + PartialEq,
46{
47	pub fn new<L, M>(leaves: L) -> Self
48	where
49		L: Iterator<Item = H>,
50		M: Merge<Item = H>,
51	{
52		let non_empty_leaves_count = leaves.size_hint().0 as u32;
53		let half_leaves_count = non_empty_to_half_leaves_count(non_empty_leaves_count);
54		let leaves_count = half_leaves_count * 2;
55		let mut nodes = Vec::with_capacity(leaves_count as _);
56
57		#[cfg(feature = "debug")]
58		{
59			log::debug!("new::non_empty_leaves_count: {}", non_empty_leaves_count);
60			log::debug!("new::half_leaves_count: {}", half_leaves_count);
61		}
62
63		// Fill the empty leaves.
64		(0..half_leaves_count).for_each(|_| nodes.push(Default::default()));
65		// Fill the leaves.
66		leaves.for_each(|leaf| nodes.push(leaf));
67		// Fill the empty leaves.
68		// `x.next_power_of_two()` must grater/equal than/to `x`; qed
69		(0..half_leaves_count - non_empty_leaves_count)
70			.for_each(|_| nodes.push(Default::default()));
71		// Build the SMT.
72		(1..half_leaves_count).rev().for_each(|i| {
73			let i = i as usize;
74			let l = &nodes[i * 2];
75			let r = &nodes[i * 2 + 1];
76
77			nodes[i] = M::merge(l, r);
78		});
79
80		Self {
81			nodes,
82			non_empty_leaves_count,
83		}
84	}
85
86	pub fn leaves_count(&self) -> u32 {
87		self.nodes.len() as _
88	}
89
90	#[cfg(test)]
91	pub fn half_leaves_count(&self) -> u32 {
92		self.leaves_count() / 2
93	}
94
95	pub fn non_empty_leaves_count(&self) -> u32 {
96		self.non_empty_leaves_count
97	}
98
99	pub fn root(&self) -> H {
100		if self.leaves_count() == 0 {
101			Default::default()
102		} else {
103			self.nodes[1].clone()
104		}
105	}
106
107	/// ## Indices
108	/// ```text
109	// leaves  0 0 0 0 0 0 0 0 1 2 3 4 5 0 0 0
110	// indices                 0 1 2 3 4 5 6 7
111	/// ```
112	pub fn proof_of<I>(&self, indices: I) -> Proof<H>
113	where
114		I: AsRef<[u32]>,
115	{
116		let indices = indices.as_ref();
117		let leaves_count = self.leaves_count();
118		let half_leaves_count = leaves_count / 2;
119
120		if indices.iter().any(|i| *i >= self.non_empty_leaves_count()) {
121			log::warn!("proof_of::Index out of bounds.");
122
123			return Default::default();
124		}
125
126		let mut known = Vec::with_capacity(leaves_count as _);
127
128		(0..leaves_count).for_each(|_| known.push(false));
129		indices
130			.iter()
131			.for_each(|i| known[(half_leaves_count + *i) as usize] = true);
132
133		let mut proof = Vec::new();
134
135		(1..half_leaves_count).rev().for_each(|i| {
136			let i = i as usize;
137			let j = i * 2;
138			let k = j + 1;
139			let l = known[j];
140			let r = known[k];
141
142			if l && !r {
143				proof.push(self.nodes[k].clone());
144			}
145			if !l && r {
146				proof.push(self.nodes[j].clone());
147			}
148
149			known[i] = l || r;
150		});
151
152		Proof {
153			root: self.root(),
154			leaves_with_index: indices
155				.iter()
156				.map(|i| {
157					let i = half_leaves_count + *i;
158
159					(i, self.nodes[i as usize].clone())
160				})
161				.collect(),
162			proof,
163		}
164	}
165
166	pub fn verify<M>(proof: Proof<H>) -> bool
167	where
168		M: Merge<Item = H>,
169	{
170		let Proof {
171			root,
172			leaves_with_index: mut nodes_with_indices,
173			proof,
174		} = proof;
175
176		if nodes_with_indices.is_empty() {
177			return false;
178		}
179
180		#[cfg(feature = "debug")]
181		{
182			log::debug!("verify::root: {:?}", root);
183			log::debug!("verify::nodes_with_indices: {:?}", nodes_with_indices);
184			log::debug!("verify::proof: {:?}", proof);
185		}
186
187		// Use ptr to avoid extra vector allocation(`remove`).
188		let mut p_i = 0;
189		let mut n_i = 0;
190
191		while n_i < nodes_with_indices.len() {
192			let i = nodes_with_indices[n_i].0;
193			// Cache the current `n_i`.
194			let n_j = n_i;
195
196			n_i += 1;
197
198			if i == 1 {
199				return &root == &nodes_with_indices[n_j].1;
200			}
201			// Index starts from `0`, left nodes' index is an even number.
202			else if i % 2 == 0 {
203				if p_i == proof.len() {
204					return false;
205				}
206
207				nodes_with_indices.push((i / 2, M::merge(&nodes_with_indices[n_j].1, &proof[p_i])));
208				p_i += 1;
209			}
210			// Check the next node if exists.
211			// Notice that the `n_i` was already `+1`.
212			else if n_i != nodes_with_indices.len() && nodes_with_indices[n_i].0 == i - 1 {
213				nodes_with_indices.push((
214					i / 2,
215					M::merge(&nodes_with_indices[n_i].1, &nodes_with_indices[n_j].1),
216				));
217				n_i += 1;
218			} else {
219				if p_i == proof.len() {
220					return false;
221				}
222
223				nodes_with_indices.push((i / 2, M::merge(&proof[p_i], &nodes_with_indices[n_j].1)));
224				p_i += 1;
225			}
226
227			#[cfg(feature = "debug")]
228			log::debug!("verify::nodes_with_indices: {:?}", nodes_with_indices);
229		}
230
231		false
232	}
233}
234
/// A multi-leaf inclusion proof produced by `SparseMerkleTree::proof_of`.
///
/// `Proof::default()` has no leaves and is always rejected by
/// `SparseMerkleTree::verify`.
#[cfg_attr(feature = "debug", derive(Debug))]
#[derive(Default)]
pub struct Proof<H> {
	/// Root hash the proof claims to reproduce.
	root: H,
	/// `(node-array index, leaf hash)` for every proven leaf.
	leaves_with_index: Vec<(u32, H)>,
	/// Sibling hashes, in the order the verifier consumes them.
	proof: Vec<H>,
}
impl<H> Proof<H> {
	/// Sorts the proven leaves into descending index order, as required by
	/// [`SparseMerkleTree::verify`].
	///
	/// Prefer to avoid this call: pass the `indices` in descending order to
	/// [`SparseMerkleTree::proof_of`] and the leaves will already come out in
	/// descending order. The sibling list is unaffected by the order of the
	/// indices, so sorting afterwards is also sound.
	pub fn sort(&mut self) -> &mut Self {
		// Stable sort, descending by node index.
		self.leaves_with_index.sort_by(|(a, _), (b, _)| b.cmp(a));

		self
	}
}
256		self
257	}
258}
259
/// Width of the padded leaf level for `non_empty_leaves_count` leaves.
///
/// This is the smallest power of two that is `>= non_empty_leaves_count`, so
/// both `0` and `1` map to `1`. It is also the number of slots in front of the
/// leaves in the tree's flat node array, hence "half" of that array.
///
/// Like [`u32::next_power_of_two`], this panics in debug builds (and wraps to
/// `0` in release builds) when the result does not fit in a `u32`.
pub fn non_empty_to_half_leaves_count(non_empty_leaves_count: u32) -> u32 {
	u32::next_power_of_two(non_empty_leaves_count)
}