1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
use crate::{Commitment, Error, Index, Node, Proof, Result, VectorCommitment};
use std::collections::VecDeque;
use zkp_error_utils::require;
use zkp_hash::{Hash, Hashable};
use zkp_mmap_vec::MmapVec;

#[cfg(feature = "std")]
use rayon::prelude::*;

// Applies `f` to every `(position, hash)` pair of `slice`.
//
// With the `std` feature enabled the slice is traversed in parallel using
// rayon; without it the traversal is a plain sequential iteration. The
// `Sync + Send` bounds on `f` are what the parallel variant requires.
fn for_each<F>(slice: &mut [Hash], f: F)
where
    F: Fn((usize, &mut Hash)) + Sync + Send,
{
    // Parallel traversal (rayon).
    #[cfg(feature = "std")]
    slice.par_iter_mut().enumerate().for_each(f);

    // Sequential fallback.
    #[cfg(not(feature = "std"))]
    slice.iter_mut().enumerate().for_each(f);
}

// Computes the hash of the node at `index` directly from the leaves.
//
// A node at leaf depth is the hash of the corresponding leaf; any node above
// it is the hash of its two children, which are computed recursively (left
// child first). Panics if `index` lies below the leaf layer.
fn compute<C: VectorCommitment>(leaves: &C, index: Index) -> Hash {
    let leaf_depth = Index::depth_for_size(leaves.len());
    assert!(index.depth() <= leaf_depth);

    // Bottom of the recursion: the node is a leaf.
    if index.depth() == leaf_depth {
        return leaves.leaf_hash(index.offset());
    }

    // Interior node: combine the hashes of both subtrees.
    let left = compute(leaves, index.left_child());
    let right = compute(leaves, index.right_child());
    Node(&left, &right).hash()
}

/// Merkle tree
///
/// The tree becomes the owner of the `Container` holding the leaves. This is
/// necessary because the lowest layers can be left out of storage (see
/// `from_leaves_skip_layers`); hashes of nodes that are not stored are
/// recomputed from the leaves on demand, which requires immutable access to
/// them. If shared ownership is required the `Container` can be an `Rc<_>`.
// NOTE(review): the two OPT items below look like they are already covered by
// the `skip_layers` parameter of `from_leaves_skip_layers` — confirm and
// remove them if so.
// OPT: Do not store leaf hashes but re-create.
// OPT: Allow up to `n` lower layers to be skipped.
// TODO: Make hash depend on type.
#[cfg_attr(feature = "std", derive(Debug))]
pub struct Tree<Container: VectorCommitment> {
    /// Commitment built from the number of leaves and the root hash
    /// (`Hash::default()` for an empty tree).
    commitment: Commitment,
    /// Hashes of the stored layers, addressed by `Index::as_index()`. Empty
    /// when the tree is empty or when every layer is skipped.
    nodes:      MmapVec<Hash>,
    /// The committed leaves, used to hand out leaf values and to recompute
    /// hashes that are not stored in `nodes`.
    leaves:     Container,
}

impl<Container: VectorCommitment> Tree<Container> {
    /// Builds a tree over `leaves` with the default amount of layer skipping.
    ///
    /// Equivalent to `Self::from_leaves_skip_layers(leaves, 1)`, and fails
    /// under the same conditions.
    pub fn from_leaves(leaves: Container) -> Result<Self> {
        // Number of bottom layers that are not stored by default.
        let default_skip_layers = 1;
        Self::from_leaves_skip_layers(leaves, default_skip_layers)
    }

    pub fn from_leaves_skip_layers(leaves: Container, skip_layers: usize) -> Result<Self> {
        let size = leaves.len();
        if size == 0 {
            return Ok(Self {
                // TODO: Ideally give the empty tree a unique flag value.
                // Size zero commitment always exists
                commitment: Commitment::from_size_hash(size, &Hash::default()).unwrap(),
                nodes: MmapVec::with_capacity(0),
                leaves,
            });
        }
        // TODO: Support non power of two sizes
        require!(size.is_power_of_two(), Error::NumLeavesNotPowerOfTwo);
        require!(size <= Index::max_size(), Error::TreeToLarge);

        // Allocate result
        let leaf_depth = Index::depth_for_size(size);
        let mut nodes = if leaf_depth >= skip_layers {
            // The array size is the largest index + 1
            let depth = leaf_depth - skip_layers;
            let max_index = Index::from_depth_offset(depth, Index::size_at_depth(depth) - 1)
                .unwrap()
                .as_index();
            let mut nodes = MmapVec::with_capacity(max_index + 1);
            for _ in 0..=max_index {
                nodes.push(Hash::default());
            }
            nodes
        } else {
            MmapVec::with_capacity(0)
        };

        // Hash the tree nodes
        // OPT: Instead of layer at a time, have each thread compute a subtree.
        if leaf_depth >= skip_layers {
            let depth = leaf_depth - skip_layers;
            let leaf_layer = &mut nodes[Index::layer_range(depth)];
            // First layer
            for_each(leaf_layer, |(i, hash)| {
                *hash = compute(&leaves, Index::from_depth_offset(depth, i).unwrap())
            });
            // Upper layers
            for depth in (0..depth).rev() {
                // TODO: This makes assumptions about how Index works.
                let (tree, previous) =
                    nodes.split_at_mut(Index::from_depth_offset(depth + 1, 0).unwrap().as_index());
                let current = &mut tree[Index::layer_range(depth)];
                for_each(current, |(i, hash)| {
                    *hash = Node(&previous[i << 1], &previous[i << 1 | 1]).hash()
                });
            }
        }

        let root_hash = if nodes.is_empty() {
            compute(&leaves, Index::root())
        } else {
            nodes[0].clone()
        };
        let commitment = Commitment::from_size_hash(size, &root_hash).unwrap();
        Ok(Self {
            commitment,
            nodes,
            leaves,
        })
    }

    pub fn commitment(&self) -> &Commitment {
        &self.commitment
    }

    /// Returns the depth of the leaf layer, derived from the number of leaves.
    pub fn leaf_depth(&self) -> usize {
        let num_leaves = self.leaves.len();
        Index::depth_for_size(num_leaves)
    }

    pub fn leaves(&self) -> &Container {
        &self.leaves
    }

    /// Returns the leaf at position `index`, as provided by the container.
    pub fn leaf(&self, index: usize) -> Container::Leaf {
        self.leaves().leaf(index)
    }

    /// Returns the hash of the node at `index`.
    ///
    /// Stored nodes are cloned out of `nodes`. Anything else is recomputed:
    /// a node at leaf depth is the hash of its leaf, a node above it is the
    /// hash of its two children (left child first).
    ///
    /// # Panics
    ///
    /// Panics if `index` is not stored and lies below the leaf layer.
    pub fn node_hash(&self, index: Index) -> Hash {
        // Fast path: the node is part of the stored layers.
        let position = index.as_index();
        if position < self.nodes.len() {
            return self.nodes[position].clone();
        }

        assert!(index.depth() <= self.leaf_depth());
        if index.depth() != self.leaf_depth() {
            // Interior node that was skipped: rebuild it from its children.
            let left = self.node_hash(index.left_child());
            let right = self.node_hash(index.right_child());
            Node(&left, &right).hash()
        } else {
            self.leaves.leaf_hash(index.offset())
        }
    }

    /// Creates a proof (decommitment) for the leaves at `indices`.
    ///
    /// The indices are first sorted by the commitment. They are then worked
    /// off as a queue: each node enqueues its parent, and a sibling hash is
    /// added to the proof unless the sibling is the next node in the queue,
    /// in which case both are covered without an extra hash.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Commitment::sort_indices` and
    /// `Proof::from_hashes`.
    pub fn open(&self, indices: &[usize]) -> Result<Proof> {
        let sorted = self.commitment().sort_indices(indices)?;
        let leaf_offsets: Vec<usize> = sorted.iter().map(|i| i.offset()).collect();
        let mut queue: VecDeque<Index> = sorted.into_iter().collect();
        let mut decommitment: Vec<Hash> = Vec::new();

        while let Some(node) = queue.pop_front() {
            // Root node has no parent and means we are done with this entry
            let parent = match node.parent() {
                Some(parent) => parent,
                None => continue,
            };
            // Add parent index to the queue for the next pass
            queue.push_back(parent);

            // Since we have a parent, we must have a sibling
            let sibling = node.sibling().unwrap();

            if queue.front() == Some(&sibling) {
                // The sibling is opened as well: skip it and don't write a
                // decommitment for either
                let _ = queue.pop_front();
            } else {
                // Add a sibling hash to the decommitment
                decommitment.push(self.node_hash(sibling));
            }
        }
        Proof::from_hashes(self.commitment(), &leaf_offsets, &decommitment)
    }
}

// Quickcheck requires pass by value
#[allow(clippy::needless_pass_by_value)]
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck_macros::quickcheck;
    use zkp_macros_decl::hex;
    use zkp_u256::U256;

    // Checks root, proof hashes and verification against fixed reference values
    // for a tree with 2^6 leaves of the form (i + 10)^3.
    #[test]
    fn test_explicit_values() {
        let depth = 6;
        let num_leaves = 2_u64.pow(depth);
        let leaves: Vec<_> = (0..num_leaves)
            .map(|i| U256::from((i + 10).pow(3)))
            .collect();

        // Build the tree
        let tree = Tree::from_leaves(leaves).unwrap();
        let commitment = tree.commitment();
        assert_eq!(
            commitment.hash().as_bytes(),
            hex!("fd112f44bc944f33e2567f86eea202350913b11c000000000000000000000000")
        );

        // Open indices
        let indices = vec![1, 11, 14];
        assert_eq!(commitment.proof_size(&indices).unwrap(), 9);
        let proof = tree.open(&indices).unwrap();
        #[rustfmt::skip]
        let expected_hashes = [
            Hash::new(hex!("00000000000000000000000000000000000000000000000000000000000003e8")),
            Hash::new(hex!("0000000000000000000000000000000000000000000000000000000000001f40")),
            Hash::new(hex!("0000000000000000000000000000000000000000000000000000000000003d09")),
            Hash::new(hex!("4ea8b9bafb11dafcfe132a26f8e343eaef0651d9000000000000000000000000")),
            Hash::new(hex!("023a7ce535cadd222093be053ac26f9b800ee476000000000000000000000000")),
            Hash::new(hex!("70b0744af2583d10e7e3236c731d37605e196e06000000000000000000000000")),
            Hash::new(hex!("221aea6e87862ba2d03543d0aa82c6bffee310ae000000000000000000000000")),
            Hash::new(hex!("68b58e5131703684edb16d41b763017dfaa24a35000000000000000000000000")),
            Hash::new(hex!("e108b7dc670810e8588c67c2fde7ec4cc00165e8000000000000000000000000")),
        ];
        assert_eq!(proof.hashes(), &expected_hashes);

        // Verify proof
        let opened: Vec<_> = indices.iter().map(|&i| (i, tree.leaf(i))).collect();
        proof.verify(opened.as_slice()).unwrap();

        // Verify non-root: the same proof against a different root must fail
        let wrong_root = Hash::new(hex!(
            "ed112f44bc944f33e2567f86eea202350913b11c000000000000000000000000"
        ));
        let wrong_commitment = Commitment::from_size_hash(commitment.size(), &wrong_root).unwrap();
        let wrong_proof =
            Proof::from_hashes(&wrong_commitment, &indices, &proof.hashes()).unwrap();
        assert_eq!(wrong_proof.verify(&opened), Err(Error::RootHashMismatch));
    }

    // A tree over zero leaves must build, open with no indices, and verify.
    #[test]
    fn test_empty_tree() {
        let no_indices: Vec<usize> = Vec::new();
        let no_leaves: Vec<U256> = Vec::new();

        let tree = Tree::from_leaves(no_leaves).unwrap();
        let commitment = tree.commitment();

        // Open indices
        let proof = tree.open(&no_indices).unwrap();
        let expected_len = commitment.proof_size(&no_indices).unwrap();
        assert_eq!(expected_len, proof.hashes().len());

        // Verify proof
        let opened: Vec<(usize, U256)> = Vec::new();
        proof.verify(&opened).unwrap();
    }

    // Property test: for arbitrary depth, skipped layers and opened indices the
    // proof has the predicted size and verifies against the opened leaves.
    #[quickcheck]
    fn test_merkle_tree(depth: usize, skip: usize, indices: Vec<usize>, seed: U256) {
        // Adjust the inputs: tests go up to depth 8 and skip up to 3 layers
        let (depth, skip) = (depth % 9, skip % 4);
        let num_leaves = 1_usize << depth;
        let indices: Vec<usize> = indices.into_iter().map(|i| i % num_leaves).collect();
        let leaves: Vec<_> = (0..num_leaves)
            .map(|i| (&seed + U256::from(i)).pow(3).unwrap())
            .collect();

        // Build the tree
        let tree = Tree::from_leaves_skip_layers(leaves, skip).unwrap();
        let commitment = tree.commitment();

        // Open indices
        let proof = tree.open(&indices).unwrap();
        let expected_len = commitment.proof_size(&indices).unwrap();
        assert_eq!(expected_len, proof.hashes().len());

        // Verify proof
        let opened: Vec<_> = indices.iter().map(|&i| (i, tree.leaf(i))).collect();
        proof.verify(&opened).unwrap();
    }
}