smoldot/trie/calculate_root.rs
1// Smoldot
2// Copyright (C) 2023 Pierre Krieger
3// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
4
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU General Public License for more details.
14
15// You should have received a copy of the GNU General Public License
16// along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18//! Freestanding function that calculates the root of a radix-16 Merkle-Patricia trie.
19//!
20//! See the parent module documentation for an explanation of what the trie is.
21//!
//! This module is meant to be used in situations where all the nodes of the trie that have a
//! storage value associated with them are known and easily accessible, and where no cache is
//! available.
25//!
26//! # Usage
27//!
28//! Calling the [`root_merkle_value`] function creates a [`RootMerkleValueCalculation`] object
29//! which you have to drive to completion.
30//!
31//! Example:
32//!
33//! ```
34//! use std::{collections::BTreeMap, ops::Bound};
35//! use smoldot::trie::{HashFunction, TrieEntryVersion, calculate_root};
36//!
//! // In this example, the storage consists of a binary tree map.
38//! let mut storage = BTreeMap::<Vec<u8>, (Vec<u8>, TrieEntryVersion)>::new();
39//! storage.insert(b"foo".to_vec(), (b"bar".to_vec(), TrieEntryVersion::V1));
40//!
41//! let trie_root = {
42//! let mut calculation = calculate_root::root_merkle_value(HashFunction::Blake2);
43//! loop {
44//! match calculation {
45//! calculate_root::RootMerkleValueCalculation::Finished { hash, .. } => break hash,
46//! calculate_root::RootMerkleValueCalculation::NextKey(next_key) => {
47//! let key_before = next_key.key_before().collect::<Vec<_>>();
48//! let lower_bound = if next_key.or_equal() {
49//! Bound::Included(key_before)
50//! } else {
51//! Bound::Excluded(key_before)
52//! };
53//! let outcome = storage
54//! .range((lower_bound, Bound::Unbounded))
55//! .next()
56//! .filter(|(k, _)| {
57//! k.iter()
58//! .copied()
59//! .zip(next_key.prefix())
60//! .all(|(a, b)| a == b)
61//! })
62//! .map(|(k, _)| k);
63//! calculation = next_key.inject_key(outcome.map(|k| k.iter().copied()));
64//! }
65//! calculate_root::RootMerkleValueCalculation::StorageValue(value_request) => {
66//! let key = value_request.key().collect::<Vec<u8>>();
67//! calculation = value_request.inject(storage.get(&key).map(|(val, v)| (val, *v)));
68//! }
69//! }
70//! }
71//! };
72//!
73//! assert_eq!(
74//! trie_root,
75//! [204, 86, 28, 213, 155, 206, 247, 145, 28, 169, 212, 146, 182, 159, 224, 82,
76//! 116, 162, 143, 156, 19, 43, 183, 8, 41, 178, 204, 69, 41, 37, 224, 91]
77//! );
78//! ```
79//!
80
81use super::{
82 EMPTY_BLAKE2_TRIE_MERKLE_VALUE, EMPTY_KECCAK256_TRIE_MERKLE_VALUE, HashFunction,
83 TrieEntryVersion, branch_search,
84 nibble::{Nibble, nibbles_to_bytes_suffix_extend},
85 trie_node,
86};
87
88use alloc::vec::Vec;
89use core::array;
90
91/// Start calculating the Merkle value of the root node.
92pub fn root_merkle_value(hash_function: HashFunction) -> RootMerkleValueCalculation {
93 CalcInner {
94 hash_function,
95 stack: Vec::with_capacity(8),
96 }
97 .next()
98}
99
100/// Current state of the [`RootMerkleValueCalculation`] and how to continue.
101#[must_use]
102pub enum RootMerkleValueCalculation {
103 /// The calculation is finished.
104 Finished {
105 /// Root hash that has been calculated.
106 hash: [u8; 32],
107 },
108
109 /// Request to return the key that follows (in lexicographic order) a given one in the storage.
110 /// Call [`NextKey::inject_key`] to indicate this list.
111 NextKey(NextKey),
112
113 /// Request the value of the node with a specific key. Call [`StorageValue::inject`] to
114 /// indicate the value.
115 StorageValue(StorageValue),
116}
117
118/// Calculation of the Merkle value is ready to continue.
119/// Shared by all the public-facing structs.
120struct CalcInner {
121 /// Hash function used by the trie.
122 hash_function: HashFunction,
123 /// Stack of nodes whose value is currently being calculated.
124 stack: Vec<Node>,
125}
126
127#[derive(Debug)]
128struct Node {
129 /// Partial key of the node currently being calculated.
130 partial_key: Vec<Nibble>,
131 /// Merkle values of the children of the node. Filled up to 16 elements, then popped. Each
132 /// element is `Some` or `None` depending on whether a child exists.
133 children: arrayvec::ArrayVec<Option<trie_node::MerkleValueOutput>, 16>,
134}
135
136impl CalcInner {
137 /// Returns the full key of the node currently being iterated.
138 fn current_iter_node_full_key(&self) -> impl Iterator<Item = Nibble> {
139 self.stack.iter().flat_map(|node| {
140 let child_nibble = if node.children.len() == 16 {
141 None
142 } else {
143 Some(Nibble::try_from(u8::try_from(node.children.len()).unwrap()).unwrap())
144 };
145
146 node.partial_key.iter().copied().chain(child_nibble)
147 })
148 }
149
150 /// Advances the calculation to the next step.
151 fn next(mut self) -> RootMerkleValueCalculation {
152 loop {
153 // If all the children of the node at the end of the stack are known, calculate the Merkle
154 // value of that node. To do so, we need to ask the user for the storage value.
155 if self
156 .stack
157 .last()
158 .map_or(false, |node| node.children.len() == 16)
159 {
160 // If the key has an even number of nibbles, we need to ask the user for the
161 // storage value.
162 if self.current_iter_node_full_key().count() % 2 == 0 {
163 break RootMerkleValueCalculation::StorageValue(StorageValue {
164 calculation: self,
165 });
166 }
167
168 // Otherwise we can calculate immediately.
169 let calculated_elem = self.stack.pop().unwrap();
170
171 // Calculate the Merkle value of the node.
172 let merkle_value = trie_node::calculate_merkle_value(
173 trie_node::Decoded {
174 children: array::from_fn(|n| calculated_elem.children[n].as_ref()),
175 partial_key: calculated_elem.partial_key.iter().copied(),
176 storage_value: trie_node::StorageValue::None,
177 },
178 self.hash_function,
179 self.stack.is_empty(),
180 )
181 .unwrap_or_else(|_| unreachable!());
182
183 // Insert Merkle value into the stack, or, if no parent, we have our result!
184 if let Some(parent) = self.stack.last_mut() {
185 parent.children.push(Some(merkle_value));
186 } else {
187 // Because we pass `is_root_node: true` in the calculation above, it is
188 // guaranteed that the Merkle value is always 32 bytes.
189 let hash = *<&[u8; 32]>::try_from(merkle_value.as_ref()).unwrap();
190 break RootMerkleValueCalculation::Finished { hash };
191 }
192 } else {
193 // Need to find the closest descendant to the first unknown child at the top of the
194 // stack.
195 break RootMerkleValueCalculation::NextKey(NextKey {
196 branch_search: branch_search::start_branch_search(branch_search::Config {
197 key_before: self.current_iter_node_full_key(),
198 or_equal: true,
199 prefix: self.current_iter_node_full_key(),
200 no_branch_search: false,
201 }),
202 calculation: self,
203 });
204 }
205 }
206 }
207}
208
209/// Request to return the key that follows (in lexicographic order) a given one in the storage.
210/// Call [`NextKey::inject_key`] to indicate this list.
211#[must_use]
212pub struct NextKey {
213 calculation: CalcInner,
214
215 /// Current branch search running to find the closest descendant to the node at the top of
216 /// the trie.
217 branch_search: branch_search::NextKey,
218}
219
220impl NextKey {
221 /// Returns the key whose next key must be passed back.
222 pub fn key_before(&self) -> impl Iterator<Item = u8> {
223 self.branch_search.key_before()
224 }
225
226 /// If `true`, then the provided value must the one superior or equal to the requested key.
227 /// If `false`, then the provided value must be strictly superior to the requested key.
228 pub fn or_equal(&self) -> bool {
229 self.branch_search.or_equal()
230 }
231
232 /// Returns the prefix the next key must start with. If the next key doesn't start with the
233 /// given prefix, then `None` should be provided.
234 pub fn prefix(&self) -> impl Iterator<Item = u8> {
235 self.branch_search.prefix()
236 }
237
238 /// Injects the key.
239 ///
240 /// # Panic
241 ///
242 /// Panics if the key passed as parameter isn't strictly superior to the requested key.
243 ///
244 pub fn inject_key(
245 mut self,
246 key: Option<impl Iterator<Item = u8>>,
247 ) -> RootMerkleValueCalculation {
248 match self.branch_search.inject(key) {
249 branch_search::BranchSearch::NextKey(next_key) => {
250 RootMerkleValueCalculation::NextKey(NextKey {
251 calculation: self.calculation,
252 branch_search: next_key,
253 })
254 }
255 branch_search::BranchSearch::Found {
256 branch_trie_node_key,
257 } => {
258 // Add the closest descendant to the stack.
259 if let Some(branch_trie_node_key) = branch_trie_node_key {
260 let partial_key = branch_trie_node_key
261 .skip(self.calculation.current_iter_node_full_key().count())
262 .collect();
263 self.calculation.stack.push(Node {
264 partial_key,
265 children: arrayvec::ArrayVec::new(),
266 });
267 self.calculation.next()
268 } else if let Some(stack_top) = self.calculation.stack.last_mut() {
269 stack_top.children.push(None);
270 self.calculation.next()
271 } else {
272 // Trie is completely empty.
273 RootMerkleValueCalculation::Finished {
274 hash: match self.calculation.hash_function {
275 HashFunction::Blake2 => EMPTY_BLAKE2_TRIE_MERKLE_VALUE,
276 HashFunction::Keccak256 => EMPTY_KECCAK256_TRIE_MERKLE_VALUE,
277 },
278 }
279 }
280 }
281 }
282 }
283}
284
285/// Request the value of the node with a specific key. Call [`StorageValue::inject`] to indicate
286/// the value.
287#[must_use]
288pub struct StorageValue {
289 calculation: CalcInner,
290}
291
292impl StorageValue {
293 /// Returns the key whose value is being requested.
294 pub fn key(&self) -> impl Iterator<Item = u8> {
295 // This function can never be reached if the number of nibbles is uneven.
296 debug_assert_eq!(self.calculation.current_iter_node_full_key().count() % 2, 0);
297 nibbles_to_bytes_suffix_extend(self.calculation.current_iter_node_full_key())
298 }
299
300 /// Indicates the storage value and advances the calculation.
301 pub fn inject(
302 mut self,
303 storage_value: Option<(impl AsRef<[u8]>, TrieEntryVersion)>,
304 ) -> RootMerkleValueCalculation {
305 let calculated_elem = self.calculation.stack.pop().unwrap();
306
307 // Due to some borrow checker troubles, we need to calculate the storage value
308 // hash ahead of time if relevant.
309 let storage_value_hash = if let Some((value, TrieEntryVersion::V1)) = storage_value.as_ref()
310 {
311 if value.as_ref().len() >= 33 {
312 Some(blake2_rfc::blake2b::blake2b(32, &[], value.as_ref()))
313 } else {
314 None
315 }
316 } else {
317 None
318 };
319
320 // Calculate the Merkle value of the node.
321 let merkle_value = trie_node::calculate_merkle_value(
322 trie_node::Decoded {
323 children: array::from_fn(|n| calculated_elem.children[n].as_ref()),
324 partial_key: calculated_elem.partial_key.iter().copied(),
325 storage_value: match (storage_value.as_ref(), storage_value_hash.as_ref()) {
326 (_, Some(storage_value_hash)) => trie_node::StorageValue::Hashed(
327 <&[u8; 32]>::try_from(storage_value_hash.as_bytes())
328 .unwrap_or_else(|_| unreachable!()),
329 ),
330 (Some((value, _)), _) => trie_node::StorageValue::Unhashed(value.as_ref()),
331 (None, _) => trie_node::StorageValue::None,
332 },
333 },
334 self.calculation.hash_function,
335 self.calculation.stack.is_empty(),
336 )
337 .unwrap_or_else(|_| unreachable!());
338
339 // Insert Merkle value into the stack, or, if no parent, we have our result!
340 if let Some(parent) = self.calculation.stack.last_mut() {
341 parent.children.push(Some(merkle_value));
342 self.calculation.next()
343 } else {
344 // Because we pass `is_root_node: true` in the calculation above, it is guaranteed
345 // that the Merkle value is always 32 bytes.
346 let hash = *<&[u8; 32]>::try_from(merkle_value.as_ref()).unwrap();
347 RootMerkleValueCalculation::Finished { hash }
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use crate::trie::{HashFunction, TrieEntryVersion};
    use alloc::collections::BTreeMap;
    use core::ops::Bound;

    /// Drives a calculation to completion, using `trie` as the storage. Every entry is reported
    /// with the given trie entry `version`.
    fn calculate_root(version: TrieEntryVersion, trie: &BTreeMap<Vec<u8>, Vec<u8>>) -> [u8; 32] {
        let mut calculation = super::root_merkle_value(HashFunction::Blake2);

        loop {
            match calculation {
                super::RootMerkleValueCalculation::Finished { hash } => {
                    return hash;
                }
                super::RootMerkleValueCalculation::NextKey(next_key) => {
                    let lower_bound = if next_key.or_equal() {
                        Bound::Included(next_key.key_before().collect::<Vec<_>>())
                    } else {
                        Bound::Excluded(next_key.key_before().collect::<Vec<_>>())
                    };

                    let k = trie
                        .range((lower_bound, Bound::Unbounded))
                        .next()
                        .filter(|(k, _)| {
                            k.iter()
                                .copied()
                                .zip(next_key.prefix())
                                .all(|(a, b)| a == b)
                        })
                        .map(|(k, _)| k);

                    calculation = next_key.inject_key(k.map(|k| k.iter().copied()));
                }
                super::RootMerkleValueCalculation::StorageValue(value) => {
                    let key = value.key().collect::<Vec<u8>>();
                    calculation = value.inject(trie.get(&key).map(|v| (v, version)));
                }
            }
        }
    }

    #[test]
    fn trie_root_one_node() {
        let mut trie = BTreeMap::new();
        trie.insert(b"abcd".to_vec(), b"hello world".to_vec());

        let expected = [
            122, 177, 134, 89, 211, 178, 120, 158, 242, 64, 13, 16, 113, 4, 199, 212, 251, 147,
            208, 109, 154, 182, 168, 182, 65, 165, 222, 124, 63, 236, 200, 81,
        ];

        assert_eq!(calculate_root(TrieEntryVersion::V0, &trie), &expected[..]);
        assert_eq!(calculate_root(TrieEntryVersion::V1, &trie), &expected[..]);
    }

    #[test]
    fn trie_root_empty() {
        let trie = BTreeMap::new();
        let expected = blake2_rfc::blake2b::blake2b(32, &[], &[0x0]);
        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected.as_bytes()
        );
    }

    #[test]
    fn trie_root_single_tuple() {
        let mut trie = BTreeMap::new();
        trie.insert([0xaa].to_vec(), [0xbb].to_vec());

        let expected = blake2_rfc::blake2b::blake2b(
            32,
            &[],
            &[
                0x42,   // leaf 0x40 (2^6) with (+) key of 2 nibbles (0x02)
                0xaa,   // key data
                1 << 2, // length of value in bytes as Compact
                0xbb,   // value data
            ],
        );

        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected.as_bytes()
        );
    }

    #[test]
    fn trie_root_hashed_value() {
        // Storage values of 33 bytes or more are inlined in version 0 of the trie format, but
        // replaced with their hash in version 1.
        let value = [0xbb_u8; 33];
        let mut trie = BTreeMap::new();
        trie.insert([0xaa].to_vec(), value.to_vec());

        let mut v0_node: Vec<u8> = vec![
            0x42,    // leaf 0x40 (2^6) with (+) key of 2 nibbles (0x02)
            0xaa,    // key data
            33 << 2, // length of value in bytes as Compact
        ];
        v0_node.extend_from_slice(&value);
        let expected_v0 = blake2_rfc::blake2b::blake2b(32, &[], &v0_node);

        let mut v1_node: Vec<u8> = vec![
            0x22, // leaf with hashed value 0x20 (0b001 << 5) with key of 2 nibbles (0x02)
            0xaa, // key data
        ];
        v1_node.extend_from_slice(blake2_rfc::blake2b::blake2b(32, &[], &value).as_bytes());
        let expected_v1 = blake2_rfc::blake2b::blake2b(32, &[], &v1_node);

        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected_v0.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected_v1.as_bytes()
        );
    }

    #[test]
    fn trie_root_example() {
        let mut trie = BTreeMap::new();
        trie.insert([0x48, 0x19].to_vec(), [0xfe].to_vec());
        trie.insert([0x13, 0x14].to_vec(), [0xff].to_vec());

        let ex = vec![
            0x80,      // branch, no value (0b_10..) no nibble
            0x12,      // slots 1 & 4 are taken from 0-7
            0x00,      // no slots from 8-15
            0x05 << 2, // first slot: LEAF, 5 bytes long.
            0x43,      // leaf 0x40 with 3 nibbles
            0x03,      // first nibble
            0x14,      // second & third nibble
            0x01 << 2, // 1 byte data
            0xff,      // value data
            0x05 << 2, // second slot: LEAF, 5 bytes long.
            0x43,      // leaf with 3 nibbles
            0x08,      // first nibble
            0x19,      // second & third nibble
            0x01 << 2, // 1 byte data
            0xfe,      // value data
        ];

        let expected = blake2_rfc::blake2b::blake2b(32, &[], &ex);
        assert_eq!(
            calculate_root(TrieEntryVersion::V0, &trie),
            expected.as_bytes()
        );
        assert_eq!(
            calculate_root(TrieEntryVersion::V1, &trie),
            expected.as_bytes()
        );
    }
}