wnfs_hamt/
node.rs

1use super::{
2    HAMT_BITMASK_BIT_SIZE, HAMT_BITMASK_BYTE_SIZE, HashPrefix, Pair, Pointer,
3    error::HamtError,
4    hash::{HashNibbles, Hasher},
5};
6use crate::{HAMT_VALUES_BUCKET_SIZE, serializable::NodeSerializable};
7use anyhow::{Result, bail};
8use async_once_cell::OnceCell;
9use async_recursion::async_recursion;
10use bitvec::array::BitArray;
11use either::{Either, Either::*};
12#[cfg(feature = "log")]
13use log::debug;
14use serde::{Serialize, de::DeserializeOwned};
15use serde_byte_array::ByteArray;
16use std::{
17    collections::HashMap,
18    fmt::{self, Debug, Formatter},
19    hash::Hash,
20    marker::PhantomData,
21};
22use wnfs_common::{
23    BlockStore, Cid, HashOutput, Link, Storable,
24    utils::{Arc, BoxFuture, CondSend, CondSync, boxed_fut},
25};
26
27//--------------------------------------------------------------------------------------------------
28// Type Definitions
29//--------------------------------------------------------------------------------------------------
30
/// The storage type of the bitmask used by the HAMT: a 16-bit mask held as `[u8; 2]`.
32pub type BitMaskType = [u8; HAMT_BITMASK_BYTE_SIZE];
33
34/// Represents a node in the HAMT tree structure.
35///
36/// # Examples
37///
38/// ```
39/// use std::sync::Arc;
40/// use wnfs_hamt::Node;
41/// use wnfs_common::MemoryBlockStore;
42///
43/// let store = &MemoryBlockStore::new();
44/// let node = Arc::new(Node::<String, usize>::default());
45///
46/// assert!(node.is_empty());
47/// ```
pub struct Node<K, V, H = blake3::Hasher>
where
    H: Hasher + CondSync,
    K: CondSync,
    V: CondSync,
{
    /// Cached CID of this node. It is filled in when the node is deserialized with a known
    /// CID and replaced by an empty cell whenever the node is about to be mutated.
    persisted_as: OnceCell<Cid>,
    /// One bit per hash nibble; a set bit means `pointers` holds an entry for that nibble.
    pub(crate) bitmask: BitArray<BitMaskType>,
    /// Densely packed entries, one per set bit of `bitmask`, ordered by bit index.
    pub(crate) pointers: Vec<Pointer<K, V, H>>,
    /// Ties the node to its hasher type without storing a hasher value.
    hasher: PhantomData<H>,
}
59
60//--------------------------------------------------------------------------------------------------
61// Implementations
62//--------------------------------------------------------------------------------------------------
63
64impl<K, V, H> Node<K, V, H>
65where
66    H: Hasher + CondSync,
67    K: CondSync,
68    V: CondSync,
69{
70    /// Sets a new value at the given key.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use std::sync::Arc;
76    /// use wnfs_hamt::Node;
77    /// use wnfs_common::MemoryBlockStore;
78    ///
79    /// #[async_std::main]
80    /// async fn main() {
81    ///     let store = &MemoryBlockStore::new();
82    ///     let mut node = Arc::new(Node::<String, usize>::default());
83    ///
84    ///     node.set("key".into(), 42, store).await.unwrap();
85    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), Some(&42));
86    /// }
87    /// ```
88    pub async fn set(self: &mut Arc<Self>, key: K, value: V, store: &impl BlockStore) -> Result<()>
89    where
90        K: Storable + AsRef<[u8]> + Clone,
91        V: Storable + Clone,
92        K::Serializable: Serialize + DeserializeOwned,
93        V::Serializable: Serialize + DeserializeOwned,
94    {
95        let hash = &H::hash(&key);
96
97        #[cfg(feature = "log")]
98        debug!("set: hash = {:02x?}", hash);
99
100        self.set_value(&mut HashNibbles::new(hash), key, value, store)
101            .await
102    }
103
104    /// Gets the value at the given key.
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use std::sync::Arc;
110    /// use wnfs_hamt::Node;
111    /// use wnfs_common::MemoryBlockStore;
112    ///
113    /// #[async_std::main]
114    /// async fn main() {
115    ///     let store = &MemoryBlockStore::new();
116    ///     let mut node = Arc::new(Node::<String, usize>::default());
117    ///
118    ///     node.set("key".into(), 42, store).await.unwrap();
119    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), Some(&42));
120    /// }
121    /// ```
122    pub async fn get<'a>(&'a self, key: &K, store: &impl BlockStore) -> Result<Option<&'a V>>
123    where
124        K: Storable + AsRef<[u8]>,
125        V: Storable,
126        K::Serializable: Serialize + DeserializeOwned,
127        V::Serializable: Serialize + DeserializeOwned,
128    {
129        let hash = &H::hash(key);
130
131        #[cfg(feature = "log")]
132        debug!("get: hash = {:02x?}", hash);
133
134        Ok(self
135            .get_value(&mut HashNibbles::new(hash), store)
136            .await?
137            .map(|pair| &pair.value))
138    }
139
140    /// Obtain a mutable reference to a given key.
141    ///
142    /// Will copy parts of the tree to prepare for changes, if necessary.
143    ///
144    /// # Examples
145    ///
146    /// ```
147    /// use std::sync::Arc;
148    /// use wnfs_hamt::Node;
149    /// use wnfs_common::MemoryBlockStore;
150    ///
151    /// #[async_std::main]
152    /// async fn main() {
153    ///     let store = &mut MemoryBlockStore::new();
154    ///     let mut node = Arc::new(Node::<String, usize>::default());
155    ///     node.set("key".into(), 40, store).await.unwrap();
156    ///
157    ///     let value = node.get_mut(&String::from("key"), store).await.unwrap().unwrap();
158    ///     *value += 2;
159    ///
160    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), Some(&42));
161    /// }
162    /// ```
163    // TODO(matheus23): Eventually provide a HashMap::Entry-similar API
164    pub async fn get_mut<'a>(
165        self: &'a mut Arc<Self>,
166        key: &K,
167        store: &'a impl BlockStore,
168    ) -> Result<Option<&'a mut V>>
169    where
170        K: Storable + AsRef<[u8]> + Clone,
171        V: Storable + Clone,
172        K::Serializable: Serialize + DeserializeOwned,
173        V::Serializable: Serialize + DeserializeOwned,
174    {
175        let hash = &H::hash(key);
176
177        #[cfg(feature = "log")]
178        debug!("get_mut: hash = {:02x?}", hash);
179
180        Ok(self
181            .get_value_mut(&mut HashNibbles::new(hash), store)
182            .await?
183            .map(|pair| &mut pair.value))
184    }
185
186    /// Removes the value at the given key.
187    ///
188    /// # Examples
189    ///
190    /// ```
191    /// use std::sync::Arc;
192    /// use wnfs_hamt::{Node, Pair};
193    /// use wnfs_common::MemoryBlockStore;
194    ///
195    /// #[async_std::main]
196    /// async fn main() {
197    ///     let store = &MemoryBlockStore::new();
198    ///     let mut node = Arc::new(Node::<String, usize>::default());
199    ///
200    ///     node.set("key".into(), 42, store).await.unwrap();
201    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), Some(&42));
202    ///
203    ///     let value = node.remove(&String::from("key"), store).await.unwrap();
204    ///     assert_eq!(value, Some(Pair::new("key".into(), 42)));
205    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), None);
206    /// }
207    /// ```
208    pub async fn remove(
209        self: &mut Arc<Self>,
210        key: &K,
211        store: &impl BlockStore,
212    ) -> Result<Option<Pair<K, V>>>
213    where
214        K: Storable + AsRef<[u8]> + Clone,
215        V: Storable + Clone,
216        K::Serializable: Serialize + DeserializeOwned,
217        V::Serializable: Serialize + DeserializeOwned,
218    {
219        let hash = &H::hash(key);
220
221        #[cfg(feature = "log")]
222        debug!("remove: hash = {:02x?}", hash);
223
224        self.remove_value(&mut HashNibbles::new(hash), store).await
225    }
226
227    /// Gets the value at the key matching the provided hash.
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use std::sync::Arc;
233    /// use wnfs_hamt::{Node, Hasher};
234    /// use wnfs_common::MemoryBlockStore;
235    ///
236    /// #[async_std::main]
237    /// async fn main() {
238    ///     let store = &MemoryBlockStore::new();
239    ///     let mut node = Arc::new(Node::<String, usize>::default());
240    ///
241    ///     node.set("key".into(), 42, store).await.unwrap();
242    ///
243    ///     let key_hash = &blake3::Hasher::hash(&String::from("key"));
244    ///     assert_eq!(node.get_by_hash(key_hash, store).await.unwrap(), Some(&42));
245    /// }
246    /// ```
247    pub async fn get_by_hash<'a>(
248        &'a self,
249        hash: &HashOutput,
250        store: &impl BlockStore,
251    ) -> Result<Option<&'a V>>
252    where
253        K: Storable + AsRef<[u8]>,
254        V: Storable,
255        K::Serializable: Serialize + DeserializeOwned,
256        V::Serializable: Serialize + DeserializeOwned,
257    {
258        #[cfg(feature = "log")]
259        debug!("get_by_hash: hash = {:02x?}", hash);
260
261        Ok(self
262            .get_value(&mut HashNibbles::new(hash), store)
263            .await?
264            .map(|pair| &pair.value))
265    }
266
267    /// Removes the value at the key matching the provided hash.
268    ///
269    /// # Examples
270    ///
271    /// ```
272    /// use std::sync::Arc;
273    /// use wnfs_hamt::{Node, Hasher, Pair};
274    /// use wnfs_common::MemoryBlockStore;
275    ///
276    /// #[async_std::main]
277    /// async fn main() {
278    ///     let store = &MemoryBlockStore::new();
279    ///     let mut node = Arc::new(Node::<String, usize>::default());
280    ///
281    ///     node.set("key".into(), 42, store).await.unwrap();
282    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), Some(&42));
283    ///
284    ///     let key_hash = &blake3::Hasher::hash(&String::from("key"));
285    ///     let value = node.remove_by_hash(key_hash, store).await.unwrap();
286    ///
287    ///     assert_eq!(value, Some(Pair::new("key".into(), 42)));
288    ///     assert_eq!(node.get(&String::from("key"), store).await.unwrap(), None);
289    /// }
290    /// ```
291    pub async fn remove_by_hash(
292        self: &mut Arc<Self>,
293        hash: &HashOutput,
294        store: &impl BlockStore,
295    ) -> Result<Option<Pair<K, V>>>
296    where
297        K: Storable + AsRef<[u8]> + Clone,
298        V: Storable + Clone,
299        K::Serializable: Serialize + DeserializeOwned,
300        V::Serializable: Serialize + DeserializeOwned,
301    {
302        self.remove_value(&mut HashNibbles::new(hash), store).await
303    }
304
305    /// Checks if the node is empty.
306    ///
307    /// # Examples
308    ///
309    /// ```
310    /// use std::sync::Arc;
311    /// use wnfs_hamt::Node;
312    /// use wnfs_common::MemoryBlockStore;
313    ///
314    /// #[async_std::main]
315    /// async fn main() {
316    ///     let store = &MemoryBlockStore::new();
317    ///
318    ///     let mut node = Arc::new(Node::<String, usize>::default());
319    ///     assert!(node.is_empty());
320    ///
321    ///     node.set("key".into(), 42, store).await.unwrap();
322    ///     assert!(!node.is_empty());
323    /// }
324    /// ```
325    pub fn is_empty(&self) -> bool {
326        self.bitmask.count_ones() == 0
327    }
328
329    /// Calculates the value index from the bitmask index.
330    pub(crate) fn get_value_index(&self, bit_index: usize) -> usize {
331        let shift_amount = HAMT_BITMASK_BIT_SIZE - bit_index;
332        let mask = if shift_amount < HAMT_BITMASK_BIT_SIZE {
333            let mut tmp = BitArray::<BitMaskType>::new([0xff, 0xff]);
334            tmp.shift_left(shift_amount);
335            tmp
336        } else {
337            BitArray::ZERO
338        };
339        debug_assert_eq!(mask.count_ones(), bit_index);
340        (mask & self.bitmask).count_ones()
341    }
342
343    pub fn set_value<'a>(
344        self: &'a mut Arc<Self>,
345        hashnibbles: &'a mut HashNibbles,
346        key: K,
347        value: V,
348        store: &'a impl BlockStore,
349    ) -> BoxFuture<'a, Result<()>>
350    where
351        K: Storable + Clone + AsRef<[u8]> + 'a,
352        V: Storable + Clone + 'a,
353        K::Serializable: Serialize + DeserializeOwned,
354        V::Serializable: Serialize + DeserializeOwned,
355    {
356        Box::pin(async move {
357            let bit_index = hashnibbles.try_next()?;
358            let value_index = self.get_value_index(bit_index);
359
360            #[cfg(feature = "log")]
361            debug!(
362                "set_value: bit_index = {}, value_index = {}",
363                bit_index, value_index
364            );
365
366            let node = Arc::make_mut(self);
367            node.persisted_as = OnceCell::new();
368
369            // If the bit is not set yet, insert a new pointer.
370            if !node.bitmask[bit_index] {
371                node.pointers
372                    .insert(value_index, Pointer::Values(vec![Pair { key, value }]));
373
374                node.bitmask.set(bit_index, true);
375
376                return Ok(());
377            }
378
379            match &mut node.pointers[value_index] {
380                Pointer::Values(values) => {
381                    if let Some(i) = values
382                        .iter()
383                        .position(|p| &H::hash(&p.key) == hashnibbles.digest)
384                    {
385                        // If the key is already present, update the value.
386                        values[i] = Pair::new(key, value);
387                    } else {
388                        // Otherwise, insert the new value if bucket is not full. Create new node if it is.
389                        if values.len() < HAMT_VALUES_BUCKET_SIZE {
390                            // Insert in order of key.
391                            let index = values
392                                .iter()
393                                .position(|p| &H::hash(&p.key) > hashnibbles.digest)
394                                .unwrap_or(values.len());
395                            values.insert(index, Pair::new(key, value));
396                        } else {
397                            // If values has reached threshold, we need to create a node link that splits it.
398                            let mut sub_node = Arc::new(Node::<K, V, H>::default());
399                            let cursor = hashnibbles.get_cursor();
400                            // We can take because
401                            // Pointer::Values() gets replaced with Pointer::Link at the end
402                            let values = std::mem::take(values);
403                            for Pair { key, value } in
404                                values.into_iter().chain(Some(Pair::new(key, value)))
405                            {
406                                let hash = &H::hash(&key);
407                                let hashnibbles = &mut HashNibbles::with_cursor(hash, cursor);
408                                sub_node.set_value(hashnibbles, key, value, store).await?;
409                            }
410                            node.pointers[value_index] = Pointer::Link(Link::from(sub_node));
411                        }
412                    }
413                }
414                Pointer::Link(link) => {
415                    let mut child: Arc<Node<K, V, H>> =
416                        Arc::clone(link.resolve_value(store).await?);
417                    child.set_value(hashnibbles, key, value, store).await?;
418                    node.pointers[value_index] = Pointer::Link(Link::from(child));
419                }
420            }
421
422            Ok(())
423        })
424    }
425
426    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
427    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
428    pub async fn get_value<'a>(
429        &'a self,
430        hashnibbles: &mut HashNibbles,
431        store: &impl BlockStore,
432    ) -> Result<Option<&'a Pair<K, V>>>
433    where
434        K: Storable + AsRef<[u8]>,
435        V: Storable,
436        K::Serializable: Serialize + DeserializeOwned,
437        V::Serializable: Serialize + DeserializeOwned,
438    {
439        let bit_index = hashnibbles.try_next()?;
440
441        // If the bit is not set yet, return None.
442        if !self.bitmask[bit_index] {
443            return Ok(None);
444        }
445
446        let value_index = self.get_value_index(bit_index);
447        match &self.pointers[value_index] {
448            Pointer::Values(values) => Ok({
449                values
450                    .iter()
451                    .find(|p| &H::hash(&p.key) == hashnibbles.digest)
452            }),
453            Pointer::Link(link) => {
454                let child = link.resolve_value(store).await?;
455                child.get_value(hashnibbles, store).await
456            }
457        }
458    }
459
460    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
461    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
462    pub async fn get_value_mut<'a>(
463        self: &'a mut Arc<Self>,
464        hashnibbles: &mut HashNibbles,
465        store: &'a impl BlockStore,
466    ) -> Result<Option<&'a mut Pair<K, V>>>
467    where
468        K: Storable + AsRef<[u8]> + Clone,
469        V: Storable + Clone,
470        K::Serializable: Serialize + DeserializeOwned,
471        V::Serializable: Serialize + DeserializeOwned,
472    {
473        let bit_index = hashnibbles.try_next()?;
474
475        // If the bit is not set yet, return None.
476        if !self.bitmask[bit_index] {
477            return Ok(None);
478        }
479
480        let value_index = self.get_value_index(bit_index);
481        let node = Arc::make_mut(self);
482        node.persisted_as = OnceCell::new();
483
484        match &mut node.pointers[value_index] {
485            Pointer::Values(values) => Ok({
486                values
487                    .iter_mut()
488                    .find(|p| &H::hash(&p.key) == hashnibbles.digest)
489            }),
490            Pointer::Link(link) => {
491                let child = link.resolve_value_mut(store).await?;
492                child.get_value_mut(hashnibbles, store).await
493            }
494        }
495    }
496
497    pub fn remove_value<'a>(
498        self: &'a mut Arc<Self>,
499        hashnibbles: &'a mut HashNibbles,
500        store: &'a impl BlockStore,
501    ) -> BoxFuture<'a, Result<Option<Pair<K, V>>>>
502    where
503        K: Storable + AsRef<[u8]> + Clone + 'a,
504        V: Storable + Clone + 'a,
505        K::Serializable: Serialize + DeserializeOwned,
506        V::Serializable: Serialize + DeserializeOwned,
507    {
508        Box::pin(async move {
509            let bit_index = hashnibbles.try_next()?;
510
511            // If the bit is not set yet, return None.
512            if !self.bitmask[bit_index] {
513                return Ok(None);
514            }
515
516            let value_index = self.get_value_index(bit_index);
517
518            let node = Arc::make_mut(self);
519            node.persisted_as = OnceCell::new();
520
521            Ok(match &mut node.pointers[value_index] {
522                // If there is only one value, we can remove the entire pointer.
523                Pointer::Values(values) if values.len() == 1 => {
524                    // If the key doesn't match, return without removing.
525                    if &H::hash(&values[0].key) != hashnibbles.digest {
526                        None
527                    } else {
528                        node.bitmask.set(bit_index, false);
529                        match node.pointers.remove(value_index) {
530                            Pointer::Values(mut values) => Some(values.pop().unwrap()),
531                            _ => unreachable!(),
532                        }
533                    }
534                }
535                // Otherwise, remove just the value.
536                Pointer::Values(values) => {
537                    match values
538                        .iter()
539                        .position(|p| &H::hash(&p.key) == hashnibbles.digest)
540                    {
541                        Some(i) => {
542                            let value = values.remove(i);
543                            // We can take here because we replace the node.pointers here afterwards anyway
544                            let values = std::mem::take(values);
545                            node.pointers[value_index] = Pointer::Values(values);
546                            Some(value)
547                        }
548                        None => None,
549                    }
550                }
551                Pointer::Link(link) => {
552                    let mut child = Arc::clone(link.resolve_value(store).await?);
553                    let removed = child.remove_value(hashnibbles, store).await?;
554                    if removed.is_some() {
555                        // If something has been deleted, we attempt to canonicalize the pointer.
556                        match Pointer::Link(Link::from(child)).canonicalize(store).await? {
557                            Some(pointer) => {
558                                node.pointers[value_index] = pointer;
559                            }
560                            _ => {
561                                // This is None if the pointer now points to an empty node.
562                                // In that case, we remove it from the parent.
563                                node.bitmask.set(bit_index, false);
564                                node.pointers.remove(value_index);
565                            }
566                        }
567                    } else {
568                        node.pointers[value_index] = Pointer::Link(Link::from(child))
569                    };
570                    removed
571                }
572            })
573        })
574    }
575
576    /// Visits all the leaf nodes in the trie and calls the given function on each of them.
577    ///
578    /// # Examples
579    ///
580    /// ```
581    /// use std::sync::Arc;
582    /// use wnfs_hamt::{Node, Pair, Hasher};
583    /// use wnfs_common::{utils, MemoryBlockStore};
584    ///
585    /// #[async_std::main]
586    /// async fn main() {
587    ///     let store = &MemoryBlockStore::new();
588    ///     let mut node = Arc::new(Node::<[u8; 4], String>::default());
589    ///     for i in 0..99_u32 {
590    ///         node
591    ///             .set(i.to_le_bytes(), i.to_string(), store)
592    ///             .await
593    ///             .unwrap();
594    ///     }
595    ///
596    ///     let keys = node
597    ///         .flat_map(&|Pair { key, .. }| Ok(*key), store)
598    ///         .await
599    ///         .unwrap();
600    ///
601    ///     assert_eq!(keys.len(), 99);
602    /// }
603    /// ```
604    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
605    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
606    pub async fn flat_map<F, T>(&self, f: &F, store: &impl BlockStore) -> Result<Vec<T>>
607    where
608        F: Fn(&Pair<K, V>) -> Result<T> + CondSync,
609        K: Storable + AsRef<[u8]>,
610        V: Storable,
611        K::Serializable: Serialize + DeserializeOwned,
612        V::Serializable: Serialize + DeserializeOwned,
613        T: CondSend,
614    {
615        let mut items = <Vec<T>>::new();
616        for p in self.pointers.iter() {
617            match p {
618                Pointer::Values(values) => {
619                    for pair in values {
620                        items.push(f(pair)?);
621                    }
622                }
623                Pointer::Link(link) => {
624                    let child = link.resolve_value(store).await?;
625                    items.extend(child.flat_map(f, store).await?);
626                }
627            }
628        }
629
630        Ok(items)
631    }
632
633    /// Given a hashprefix representing the path to a node in the trie. This function will
634    /// return the key-value pair or the intermediate node that the hashprefix points to.
635    ///
636    /// # Examples
637    ///
638    /// ```
639    /// use std::sync::Arc;
640    /// use wnfs_hamt::{Node, HashPrefix, Hasher};
641    /// use wnfs_common::{MemoryBlockStore, utils};
642    ///
643    /// #[async_std::main]
644    /// async fn main() {
645    ///     let store = &MemoryBlockStore::new();
646    ///
647    ///     let mut node = Arc::new(Node::<[u8; 4], String>::default());
648    ///     for i in 0..100_u32 {
649    ///         node
650    ///             .set(i.to_le_bytes(), i.to_string(), store)
651    ///             .await
652    ///             .unwrap();
653    ///     }
654    ///
655    ///     let hashprefix = HashPrefix::with_length(utils::to_hash_output(&[0x8C]), 2);
656    ///     let result = node.get_node_at(&hashprefix, store).await.unwrap();
657    ///
658    ///     println!("Result: {:#?}", result);
659    /// }
660    /// ```
661    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
662    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
663    pub async fn get_node_at<'a>(
664        &'a self,
665        hashprefix: &HashPrefix,
666        store: &impl BlockStore,
667    ) -> Result<Option<Either<&'a Pair<K, V>, &'a Arc<Self>>>>
668    where
669        K: Storable + AsRef<[u8]>,
670        V: Storable,
671        K::Serializable: Serialize + DeserializeOwned,
672        V::Serializable: Serialize + DeserializeOwned,
673    {
674        self.get_node_at_helper(hashprefix, 0, store).await
675    }
676
677    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
678    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
679    async fn get_node_at_helper<'a>(
680        &'a self,
681        hashprefix: &HashPrefix,
682        index: u8,
683        store: &impl BlockStore,
684    ) -> Result<Option<Either<&'a Pair<K, V>, &'a Arc<Self>>>>
685    where
686        K: Storable + AsRef<[u8]>,
687        V: Storable,
688        K::Serializable: Serialize + DeserializeOwned,
689        V::Serializable: Serialize + DeserializeOwned,
690    {
691        let bit_index = hashprefix
692            .get(index)
693            .ok_or(HamtError::HashPrefixIndexOutOfBounds(index))? as usize;
694
695        if !self.bitmask[bit_index] {
696            return Ok(None);
697        }
698
699        let value_index = self.get_value_index(bit_index);
700        match &self.pointers[value_index] {
701            Pointer::Values(values) => Ok({
702                values
703                    .iter()
704                    .find(|p| hashprefix.is_prefix_of(&H::hash(&p.key)))
705                    .map(Left)
706            }),
707            Pointer::Link(link) => {
708                let child = link.resolve_value(store).await?;
709                if index == hashprefix.len() as u8 - 1 {
710                    return Ok(Some(Right(child)));
711                }
712
713                child.get_node_at_helper(hashprefix, index + 1, store).await
714            }
715        }
716    }
717
718    /// Generates a hashmap from the node.
719    ///
720    /// # Examples
721    ///
722    /// ```
723    /// use std::sync::Arc;
724    /// use wnfs_hamt::{Node, Hasher};
725    /// use wnfs_common::MemoryBlockStore;
726    ///
727    /// #[async_std::main]
728    /// async fn main() {
729    ///     let store = &MemoryBlockStore::new();
730    ///
731    ///     let mut node = Arc::new(Node::<[u8; 4], String>::default());
732    ///     for i in 0..100_u32 {
733    ///         node
734    ///             .set(i.to_le_bytes(), i.to_string(), store)
735    ///             .await
736    ///             .unwrap();
737    ///     }
738    ///
739    ///     let map = node.to_hashmap(store).await.unwrap();
740    ///
741    ///     assert_eq!(map.len(), 100);
742    /// }
743    /// ```
744    pub async fn to_hashmap<B: BlockStore>(&self, store: &B) -> Result<HashMap<K, V>>
745    where
746        K: Storable + AsRef<[u8]> + Clone + Eq + Hash,
747        V: Storable + Clone,
748        K::Serializable: Serialize + DeserializeOwned,
749        V::Serializable: Serialize + DeserializeOwned,
750    {
751        let mut map = HashMap::new();
752        let key_values = self
753            .flat_map(
754                &|Pair { key, value }| Ok((key.clone(), value.clone())),
755                store,
756            )
757            .await?;
758
759        for (key, value) in key_values {
760            map.insert(key, value);
761        }
762
763        Ok(map)
764    }
765
766    /// Returns the count of the values in all the values pointer of a node.
767    pub fn count_values(self: &Arc<Self>) -> Result<usize> {
768        let mut len = 0;
769        for i in self.pointers.iter() {
770            if let Pointer::Values(values) = i {
771                len += values.len();
772            } else {
773                bail!(HamtError::ValuesPointerExpected);
774            }
775        }
776
777        Ok(len)
778    }
779}
780
781impl<K: Clone + CondSync, V: CondSync + Clone, H: Hasher + CondSync> Clone for Node<K, V, H> {
782    fn clone(&self) -> Self {
783        Self {
784            persisted_as: self
785                .persisted_as
786                .get()
787                .cloned()
788                .map(OnceCell::new_with)
789                .unwrap_or_default(),
790            bitmask: self.bitmask,
791            pointers: self.pointers.clone(),
792            hasher: PhantomData,
793        }
794    }
795}
796
797impl<K: CondSync, V: CondSync, H: Hasher + CondSync> Default for Node<K, V, H> {
798    fn default() -> Self {
799        Node {
800            persisted_as: OnceCell::new(),
801            bitmask: BitArray::ZERO,
802            pointers: Vec::with_capacity(HAMT_BITMASK_BIT_SIZE),
803            hasher: PhantomData,
804        }
805    }
806}
807
808impl<K, V, H> PartialEq for Node<K, V, H>
809where
810    K: Storable + PartialEq + CondSync,
811    V: Storable + PartialEq + CondSync,
812    K::Serializable: Serialize + DeserializeOwned,
813    V::Serializable: Serialize + DeserializeOwned,
814    H: Hasher + CondSync,
815{
816    fn eq(&self, other: &Self) -> bool {
817        self.bitmask == other.bitmask && self.pointers == other.pointers
818    }
819}
820
821impl<K, V, H> Debug for Node<K, V, H>
822where
823    K: Debug + CondSync,
824    V: Debug + CondSync,
825    H: Hasher + CondSync,
826{
827    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
828        let mut bitmask_str = String::new();
829        for i in self.bitmask.as_raw_slice().iter().rev() {
830            bitmask_str.push_str(&format!("{i:08b}"));
831        }
832
833        f.debug_struct("Node")
834            .field("bitmask", &bitmask_str)
835            .field("pointers", &self.pointers)
836            .finish()
837    }
838}
839
840impl<K, V, H> Storable for Node<K, V, H>
841where
842    K: Storable + CondSync,
843    V: Storable + CondSync,
844    K::Serializable: Serialize + DeserializeOwned,
845    V::Serializable: Serialize + DeserializeOwned,
846    H: Hasher + CondSync,
847{
848    type Serializable = NodeSerializable<K::Serializable, V::Serializable>;
849
850    async fn to_serializable(&self, store: &impl BlockStore) -> Result<Self::Serializable> {
851        let bitmask = ByteArray::from(self.bitmask.into_inner());
852
853        let mut pointers = Vec::with_capacity(self.pointers.len());
854        for pointer in self.pointers.iter() {
855            // Boxing the future due to recursion
856            pointers.push(boxed_fut(pointer.to_serializable(store)).await?);
857        }
858
859        Ok(NodeSerializable(bitmask, pointers))
860    }
861
862    async fn from_serializable(
863        cid: Option<&Cid>,
864        serializable: Self::Serializable,
865    ) -> Result<Self> {
866        let NodeSerializable(bitmask, ser_pointers) = serializable;
867
868        let bitmask = BitArray::<BitMaskType>::new(bitmask.into());
869        let bitmask_bits_set = bitmask.count_ones();
870
871        if ser_pointers.len() != bitmask_bits_set {
872            bail!(
873                "pointers length does not match bitmask, bitmask bits set: {}, pointers length: {}",
874                bitmask_bits_set,
875                ser_pointers.len()
876            );
877        }
878
879        let mut pointers = Vec::with_capacity(ser_pointers.len());
880        for ser_pointer in ser_pointers {
881            pointers.push(Pointer::from_serializable(cid, ser_pointer).await?);
882        }
883
884        Ok(Self {
885            persisted_as: cid.cloned().map(OnceCell::new_with).unwrap_or_default(),
886            bitmask,
887            pointers,
888            hasher: PhantomData,
889        })
890    }
891
892    fn persisted_as(&self) -> Option<&OnceCell<Cid>> {
893        Some(&self.persisted_as)
894    }
895}
896
897//--------------------------------------------------------------------------------------------------
898// Tests
899//--------------------------------------------------------------------------------------------------
900
#[cfg(test)]
mod tests {
    use super::*;
    use helper::*;
    use wnfs_common::{MemoryBlockStore, utils};

    /// Fixtures shared by the tests below.
    mod helper {
        use crate::Hasher;
        use once_cell::sync::Lazy;
        use wnfs_common::{HashOutput, utils};

        /// Four `(digest, name)` pairs whose digests are built from the single bytes
        /// `0xE0`..`0xE3`. The name is used as both key and value by the tests.
        pub(super) static HASH_KV_PAIRS: Lazy<Vec<(HashOutput, &'static str)>> = Lazy::new(|| {
            vec![
                (utils::to_hash_output(&[0xE0]), "first"),
                (utils::to_hash_output(&[0xE1]), "second"),
                (utils::to_hash_output(&[0xE2]), "third"),
                (utils::to_hash_output(&[0xE3]), "fourth"),
            ]
        });

        /// Hasher that maps each name in `HASH_KV_PAIRS` to its paired digest.
        ///
        /// Panics when asked to hash a key that is not one of those names.
        #[derive(Debug, Clone)]
        pub(super) struct MockHasher;
        impl Hasher for MockHasher {
            fn hash<K: AsRef<[u8]>>(key: &K) -> HashOutput {
                HASH_KV_PAIRS
                    .iter()
                    .find(|(_, v)| key.as_ref() == <dyn AsRef<[u8]>>::as_ref(v))
                    .unwrap()
                    .0
            }
        }
    }

    /// Values stored behind a linked node can still be read back through `get_value`.
    #[async_std::test]
    async fn get_value_fetches_deeply_linked_value() {
        let store = &MemoryBlockStore::default();

        // Insert 4 values to trigger the creation of a linked node.
        let working_node = &mut Arc::new(Node::<String, String, MockHasher>::default());
        for (digest, kv) in HASH_KV_PAIRS.iter().take(4) {
            let hashnibbles = &mut HashNibbles::new(digest);
            working_node
                .set_value(hashnibbles, kv.to_string(), kv.to_string(), store)
                .await
                .unwrap();
        }

        // Get the values.
        for (digest, kv) in HASH_KV_PAIRS.iter().take(4) {
            let hashnibbles = &mut HashNibbles::new(digest);
            let value = working_node.get_value(hashnibbles, store).await.unwrap();

            assert_eq!(value, Some(&Pair::new(kv.to_string(), kv.to_string())));
        }
    }

    /// Removing one of four values turns the root's single pointer back into a values
    /// bucket holding the remaining three.
    #[async_std::test]
    async fn remove_value_canonicalizes_linked_node() {
        let store = &MemoryBlockStore::default();

        // Insert 4 values to trigger the creation of a linked node.
        let working_node = &mut Arc::new(Node::<String, String, MockHasher>::default());
        for (digest, kv) in HASH_KV_PAIRS.iter().take(4) {
            let hashnibbles = &mut HashNibbles::new(digest);
            working_node
                .set_value(hashnibbles, kv.to_string(), kv.to_string(), store)
                .await
                .unwrap();
        }

        assert_eq!(working_node.pointers.len(), 1);

        // Remove the third value.
        let third_hashnibbles = &mut HashNibbles::new(&HASH_KV_PAIRS[2].0);
        working_node
            .remove_value(third_hashnibbles, store)
            .await
            .unwrap();

        // Check that the third value is gone.
        match &working_node.pointers[0] {
            Pointer::Values(values) => {
                assert_eq!(values.len(), 3);
            }
            _ => panic!("Expected values pointer"),
        }

        // Look the key up with a fresh `HashNibbles`. The instance above was handed to
        // `remove_value` as `&mut` and has presumably had its cursor advanced; reusing it
        // could make this lookup miss for reasons unrelated to the removal.
        let lookup_hashnibbles = &mut HashNibbles::new(&HASH_KV_PAIRS[2].0);
        let value = working_node
            .get_value(lookup_hashnibbles, store)
            .await
            .unwrap();

        assert!(value.is_none());
    }

    /// The first three values share one bucket in insertion order; the fourth replaces
    /// the bucket with a link to a child node holding four pointers.
    #[async_std::test]
    async fn set_value_splits_when_bucket_threshold_reached() {
        let store = &MemoryBlockStore::default();

        // Insert 3 values into the HAMT.
        let working_node = &mut Arc::new(Node::<String, String, MockHasher>::default());
        for (idx, (digest, kv)) in HASH_KV_PAIRS.iter().take(3).enumerate() {
            let kv = kv.to_string();
            let hashnibbles = &mut HashNibbles::new(digest);
            working_node
                .set_value(hashnibbles, kv.clone(), kv.clone(), store)
                .await
                .unwrap();

            match &working_node.pointers[0] {
                Pointer::Values(values) => {
                    assert_eq!(values.len(), idx + 1);
                    assert_eq!(values[idx].key, kv.clone());
                    assert_eq!(values[idx].value, kv.clone());
                }
                _ => panic!("Expected values pointer"),
            }
        }

        // Inserting the fourth value should introduce a link indirection.
        working_node
            .set_value(
                &mut HashNibbles::new(&HASH_KV_PAIRS[3].0),
                "fourth".to_string(),
                "fourth".to_string(),
                store,
            )
            .await
            .unwrap();

        match &working_node.pointers[0] {
            Pointer::Link(link) => {
                let node = link.get_value().unwrap();
                assert_eq!(node.bitmask.count_ones(), 4);
                assert_eq!(node.pointers.len(), 4);
            }
            _ => panic!("Expected link pointer"),
        }
    }

    /// After each insertion, the new value sits at the expected position in the root's
    /// pointer list.
    #[async_std::test]
    async fn get_value_index_gets_correct_index() {
        let store = &MemoryBlockStore::default();
        // (leading hash byte, index the new pointer must occupy right after insertion).
        // The entries are deliberately out of order so that some land between existing
        // pointers.
        let hash_expected_idx_samples = [
            (&[0x00], 0),
            (&[0x20], 1),
            (&[0x10], 1),
            (&[0x30], 3),
            (&[0x50], 4),
            (&[0x60], 5),
            (&[0x70], 6),
            (&[0x40], 4),
            (&[0x80], 8),
            (&[0xA0], 9),
            (&[0xB0], 10),
            (&[0xC0], 11),
            (&[0x90], 9),
            (&[0xE0], 13),
            (&[0xD0], 13),
            (&[0xF0], 15),
        ];

        let working_node = &mut Arc::new(Node::<String, String>::default());
        for (hash, expected_idx) in hash_expected_idx_samples.into_iter() {
            let bytes = utils::to_hash_output(&hash[..]);
            let hashnibbles = &mut HashNibbles::new(&bytes);

            working_node
                .set_value(
                    hashnibbles,
                    expected_idx.to_string(),
                    expected_idx.to_string(),
                    store,
                )
                .await
                .unwrap();

            assert_eq!(
                working_node.pointers[expected_idx],
                Pointer::Values(vec![Pair::new(
                    expected_idx.to_string(),
                    expected_idx.to_string()
                )])
            );
        }
    }

    /// A value stored with `set` is returned unchanged by `get`.
    #[async_std::test]
    async fn node_can_insert_pair_and_retrieve() {
        let store = MemoryBlockStore::default();
        let node = &mut Arc::new(Node::<String, (i32, f64)>::default());

        node.set("pill".into(), (10, 0.315), &store).await.unwrap();

        let value = node.get(&"pill".into(), &store).await.unwrap().unwrap();

        assert_eq!(value, &(10, 0.315));
    }

    /// Removing a key that was never inserted leaves the single stored value in place.
    #[async_std::test]
    async fn node_is_same_with_irrelevant_remove() {
        // These two keys' hashes have the same first nibble (7)
        let insert_key: String = "GL59 Tg4phDb  bv".into();
        let remove_key: String = "hK i3b4V4152EPOdA".into();

        let store = &MemoryBlockStore::default();
        let node0: &mut Arc<Node<String, u64>> = &mut Arc::new(Node::default());

        node0.set(insert_key.clone(), 0, store).await.unwrap();
        node0.remove(&remove_key, store).await.unwrap();

        assert_eq!(node0.count_values().unwrap(), 1);
    }

    /// Regression case: two different operation orders that end with the same entries
    /// must store to the same CID.
    #[async_std::test]
    async fn node_history_independence_regression() {
        let store = &MemoryBlockStore::default();

        let node1: &mut Arc<Node<String, u64>> = &mut Arc::new(Node::default());
        let node2: &mut Arc<Node<String, u64>> = &mut Arc::new(Node::default());

        node1.set("key 17".into(), 508, store).await.unwrap();
        node1.set("key 81".into(), 971, store).await.unwrap();
        node1.set("key 997".into(), 365, store).await.unwrap();
        node1.remove(&"key 17".into(), store).await.unwrap();
        node1.set("key 68".into(), 870, store).await.unwrap();
        node1.set("key 304".into(), 331, store).await.unwrap();

        node2.set("key 81".into(), 971, store).await.unwrap();
        node2.set("key 17".into(), 508, store).await.unwrap();
        node2.set("key 997".into(), 365, store).await.unwrap();
        node2.set("key 304".into(), 331, store).await.unwrap();
        node2.set("key 68".into(), 870, store).await.unwrap();
        node2.remove(&"key 17".into(), store).await.unwrap();

        let cid1 = node1.store(store).await.unwrap();
        let cid2 = node2.store(store).await.unwrap();

        assert_eq!(cid1, cid2);
    }

    /// `flat_map` yields one mapped item per stored pair.
    #[async_std::test]
    async fn can_map_over_leaf_nodes() {
        let store = &MemoryBlockStore::default();

        let node = &mut Arc::new(Node::<[u8; 4], String>::default());
        for i in 0..99_u32 {
            node.set(i.to_le_bytes(), i.to_string(), store)
                .await
                .unwrap();
        }

        let keys = node
            .flat_map(&|Pair { key, .. }| Ok(*key), store)
            .await
            .unwrap();

        assert_eq!(keys.len(), 99);
    }

    /// `get_node_at` returns the stored pair for each length-2 prefix and an
    /// `Either::Right` result for the shared length-1 prefix.
    #[async_std::test]
    async fn can_fetch_node_at_hashprefix() {
        let store = &MemoryBlockStore::default();

        let node = &mut Arc::new(Node::<String, String, MockHasher>::default());
        for (digest, kv) in HASH_KV_PAIRS.iter() {
            let hashnibbles = &mut HashNibbles::new(digest);
            node.set_value(hashnibbles, kv.to_string(), kv.to_string(), store)
                .await
                .unwrap();
        }

        for (digest, kv) in HASH_KV_PAIRS.iter().take(4) {
            let hashprefix = HashPrefix::with_length(*digest, 2);
            let result = node.get_node_at(&hashprefix, store).await.unwrap();
            let (key, value) = (kv.to_string(), kv.to_string());
            assert_eq!(result, Some(Either::Left(&Pair { key, value })));
        }

        let hashprefix = HashPrefix::with_length(utils::to_hash_output(&[0xE0]), 1);
        let result = node.get_node_at(&hashprefix, store).await.unwrap();

        assert!(matches!(result, Some(Either::Right(_))));
    }

    /// `to_hashmap` returns every inserted key with its value.
    #[async_std::test]
    async fn can_generate_hashmap_from_node() {
        let store = &MemoryBlockStore::default();

        let node = &mut Arc::new(Node::<[u8; 4], String>::default());
        const NUM_VALUES: u32 = 1000;
        // Keys are the `NUM_VALUES` integers just below `u32::MAX`, inserted in
        // descending order.
        for i in (u32::MAX - NUM_VALUES..u32::MAX).rev() {
            node.set(i.to_le_bytes(), i.to_string(), store)
                .await
                .unwrap();
        }

        let map = node.to_hashmap(store).await.unwrap();
        assert_eq!(map.len(), NUM_VALUES as usize);
        for i in (u32::MAX - NUM_VALUES..u32::MAX).rev() {
            assert_eq!(map.get(&i.to_le_bytes()).unwrap(), &i.to_string());
        }
    }
}
1205
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::strategies::{
        Operations, node_from_operations, operations, operations_and_shuffled,
    };
    use proptest::prelude::*;
    use test_strategy::proptest;
    use wnfs_common::MemoryBlockStore;

    /// Strategy producing one of the 1000 keys `key 0` through `key 999`.
    fn small_key() -> impl Strategy<Value = String> {
        (0..1000).prop_map(|n| format!("key {n}"))
    }

    /// Setting the same key/value twice in a row stores to the same CID both times.
    #[proptest(cases = 50)]
    fn test_insert_idempotence(
        #[strategy(operations(small_key(), 0u64..1000, 0..100))] operations: Operations<
            String,
            u64,
        >,
        #[strategy(small_key())] key: String,
        #[strategy(0..1000u64)] value: u64,
    ) {
        async_std::task::block_on(async move {
            let store = &MemoryBlockStore::default();
            let hamt = &mut node_from_operations(&operations, store).await.unwrap();

            hamt.set(key.clone(), value, store).await.unwrap();
            let cid_after_first = hamt.store(store).await.unwrap();

            hamt.set(key, value, store).await.unwrap();
            let cid_after_second = hamt.store(store).await.unwrap();

            prop_assert_eq!(cid_after_first, cid_after_second);
            Ok(())
        })?;
    }

    /// Removing the same key twice in a row stores to the same CID both times.
    #[proptest(cases = 50)]
    fn test_remove_idempotence(
        #[strategy(operations(small_key(), 0u64..1000, 0..100))] operations: Operations<
            String,
            u64,
        >,
        #[strategy(small_key())] key: String,
    ) {
        async_std::task::block_on(async move {
            let store = &MemoryBlockStore::default();
            let hamt = &mut node_from_operations(&operations, store).await.unwrap();

            hamt.remove(&key, store).await.unwrap();
            let cid_after_first = hamt.store(store).await.unwrap();

            hamt.remove(&key, store).await.unwrap();
            let cid_after_second = hamt.store(store).await.unwrap();

            prop_assert_eq!(cid_after_first, cid_after_second);
            Ok(())
        })?;
    }

    /// A node that is stored and then loaded from its CID compares equal to the original.
    #[proptest(cases = 100)]
    fn node_can_encode_decode_as_cbor(
        #[strategy(operations(small_key(), 0u64..1000, 0..1000))] operations: Operations<
            String,
            u64,
        >,
    ) {
        async_std::task::block_on(async move {
            let store = &MemoryBlockStore::default();

            let original = node_from_operations(&operations, store).await.unwrap();
            let stored_cid = original.store(store).await.unwrap();

            let reloaded = Node::<String, u64>::load(&stored_cid, store)
                .await
                .unwrap();

            prop_assert_eq!(&*original, &reloaded);
            Ok(())
        })?;
    }

    /// Nodes built from an operation list and from its shuffled counterpart store to the
    /// same CID.
    #[proptest(cases = 1000, max_shrink_iters = 10_000)]
    fn node_operations_are_history_independent(
        #[strategy(operations_and_shuffled(small_key(), 0u64..1000, 0..100))] pair: (
            Operations<String, u64>,
            Operations<String, u64>,
        ),
    ) {
        async_std::task::block_on(async move {
            let store = &MemoryBlockStore::default();
            let (original, shuffled) = pair;

            let from_original = node_from_operations(&original, store).await.unwrap();
            let from_shuffled = node_from_operations(&shuffled, store).await.unwrap();

            let original_cid = from_original.store(store).await.unwrap();
            let shuffled_cid = from_shuffled.store(store).await.unwrap();

            prop_assert_eq!(original_cid, shuffled_cid);
            Ok(())
        })?;
    }

    // This is sort of a "control group" for making sure that operations_and_shuffled is correct.
    #[proptest(cases = 200, max_shrink_iters = 10_000)]
    fn hash_map_is_history_independent(
        #[strategy(operations_and_shuffled(small_key(), 0u64..1000, 0..1000))] pair: (
            Operations<String, u64>,
            Operations<String, u64>,
        ),
    ) {
        let (original, shuffled) = pair;

        let from_original = HashMap::from(&original);
        let from_shuffled = HashMap::from(&shuffled);

        prop_assert_eq!(from_original, from_shuffled);
    }

    /// The HAMT's contents match a `HashMap` built from the same operations.
    #[proptest]
    fn hamt_is_like_hash_map(
        #[strategy(operations(small_key(), 0u64..1000, 0..1000))] operations: Operations<
            String,
            u64,
        >,
    ) {
        async_std::task::block_on(async move {
            let store = &MemoryBlockStore::new();

            let hamt = node_from_operations(&operations, store).await.unwrap();
            let expected = HashMap::from(&operations);
            let actual = hamt.to_hashmap(store).await.unwrap();

            prop_assert_eq!(expected, actual);
            Ok(())
        })?;
    }
}
1344
#[cfg(test)]
mod snapshot_tests {
    use super::*;
    use wnfs_common::utils::SnapshotBlockStore;

    /// Stores a node holding 99 entries and snapshots the block written for its root CID.
    #[async_std::test]
    async fn test_node() {
        let store = &SnapshotBlockStore::default();

        let hamt = &mut Arc::new(Node::<[u8; 4], String>::default());
        for n in 0..99_u32 {
            let key = n.to_le_bytes();
            let value = n.to_string();
            hamt.set(key, value, store).await.unwrap();
        }

        let root_cid = hamt.store(store).await.unwrap();

        // NOTE(review): the binding stays named `node` so the expression handed to the
        // snapshot macro is textually unchanged (insta appears to record it in metadata).
        let node = store.get_block_snapshot(&root_cid).await.unwrap();

        insta::assert_json_snapshot!(node);
    }
}