tasm_lib/mmr/
bag_peaks.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4use triton_vm::twenty_first::util_types::mmr::mmr_accumulator::MmrAccumulator;
5
6use crate::arithmetic;
7use crate::prelude::*;
8use crate::traits::basic_snippet::Reviewer;
9use crate::traits::basic_snippet::SignOffFingerprint;
10
/// [Bag the peaks][bag] of an MMR into a single [`Digest`].
///
/// # Behavior
///
/// ```text
/// BEFORE: _ *mmr_accumulator
/// AFTER:  _ [bagged_peaks: Digest]
/// ```
///
/// # Preconditions
///
/// - the input argument points to a properly [`BFieldCodec`]-encoded
///   [`MmrAccumulator`] in memory, *i.e.*, its leaf count together with its
///   list of peaks (each peak a [`Digest`])
/// - the pointed-to MMR accumulator is consistent, *i.e.*, the number of peaks
///   matches the number of set bits in the leaf count
///
/// # Postconditions
///
/// - the output is a single [`Digest`] computed like in [`bag_peaks`][bag]:
///   the (zero-padded) leaf count is hashed first, then the peaks are absorbed
///   one at a time, from the last peak to the first
/// - the output is properly [`BFieldCodec`] encoded
///
/// # Crashes
///
/// - if the MMR accumulator is inconsistent, *i.e.*, if the number of peaks
///   does not match the number of set bits in the leaf count; the failing
///   assertion carries the error ID `INCONSISTENT_NUM_PEAKS_ERROR_ID`
///
/// [bag]: twenty_first::util_types::mmr::mmr_trait::Mmr::bag_peaks
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct BagPeaks;

impl BagPeaks {
    /// Error ID of the assertion that fails when the number of peaks stored in
    /// the accumulator differs from the popcount of its leaf count.
    const INCONSISTENT_NUM_PEAKS_ERROR_ID: usize = 560;
}
44
45impl BasicSnippet for BagPeaks {
46    fn inputs(&self) -> Vec<(DataType, String)> {
47        let mmr_accumulator = DataType::List(Box::new(DataType::Digest));
48
49        vec![(mmr_accumulator, "*mmra".to_string())]
50    }
51
52    fn outputs(&self) -> Vec<(DataType, String)> {
53        vec![(DataType::Digest, "digest".to_owned())]
54    }
55
56    fn entrypoint(&self) -> String {
57        "tasmlib_mmr_bag_peaks".to_string()
58    }
59
60    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
61        let entrypoint = self.entrypoint();
62        let bagging_loop = format!("{entrypoint}_loop");
63
64        let destructure_mmra = MmrAccumulator::destructure();
65        let pop_count = library.import(Box::new(arithmetic::u64::popcount::PopCount));
66
67        triton_asm!(
68        // BEFORE: _ *mmra
69        // AFTER:  _ [bagged_peaks: Digest]
70        {entrypoint}:
71            {&destructure_mmra}
72            // _ *peaks *leaf_count
73
74            addi 1 read_mem 2 pop 1
75            hint leaf_count : u64 = stack[0..2]
76            // _ *peaks [leaf_count]
77
78            dup 1 dup 1
79            call {pop_count}
80            // _ *peaks [leaf_count] popcount
81
82            dup 3 read_mem 1 pop 1
83            // _ *peaks [leaf_count] popcount num_peaks
84
85            dup 1 eq
86            // _ *peaks [leaf_count] pop_count (num_peaks==pop_count)
87
88            assert error_id {Self::INCONSISTENT_NUM_PEAKS_ERROR_ID}
89            hint num_peaks: u32 = stack[0]
90            // _ *peaks [leaf_count] num_peaks
91
92            place 2
93            // _ *peaks num_peaks [leaf_count]
94            // _ *peaks len [leaf_count] <-- rename
95
96            push 0
97            push 0
98            push 0
99            push 0
100            push 0
101            push 0
102            push 0
103            push 0
104            pick 9
105            pick 9
106            // _ *peaks len 0 0 0 0 0 0 0 0 [leaf_count; 2]
107
108            hash
109            // _ *peaks len [hash_of_leaf_count]
110
111            pick 5
112            push {Digest::LEN}
113            mul
114            // _ *peaks [hash_of_leaf_count] size_of_peaks_list
115
116            dup 6
117            add
118            // _ *peaks [hash_of_leaf_count] *peaks[last]_lw
119
120            place 5
121            // _ *peaks *peaks[last]_lw [hash_of_leaf_count]
122
123            dup 6 dup 6 eq push 0 eq
124            // _ *peaks *peaks[last]_lw [hash_of_leaf_count] (num_peaks == 0)
125
126            skiz call {bagging_loop}
127            // _ *peaks *peaks [bag_hash]
128
129            pick 6 pick 6 pop 2
130            // _ [bagged_peaks: Digest]
131
132            return
133
134        // INVARIANT: _ *peaks *peaks[i]_lw [acc: Digest]
135        {bagging_loop}:
136            pick 5
137            read_mem {Digest::LEN}
138            place 10
139            // _*peaks *peaks[i-1]_lw [acc: Digest] [peaks[i - 1]: Digest]
140
141            hash
142            // _*peaks *peaks[i-1]_lw [acc': Digest]
143
144            recurse_or_return
145        )
146    }
147
148    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
149        [].into()
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use triton_vm::twenty_first::prelude::Mmr;
    use twenty_first::math::other::random_elements;

    use super::*;
    use crate::test_prelude::*;

    impl BagPeaks {
        /// Builds an initial state for a consistent MMR accumulator with
        /// `leaf_count` leafs and `leaf_count.count_ones()` random peaks.
        ///
        /// The accumulator is encoded at a random address, and that address is
        /// pushed onto the stack.
        fn set_up_initial_state(&self, leaf_count: u64) -> FunctionInitialState {
            let address = rand::random();
            let mut stack = self.init_stack_for_isolated_run();
            push_encodable(&mut stack, &address);

            let mut memory = HashMap::new();
            let num_peaks = leaf_count.count_ones();
            let mmra =
                MmrAccumulator::init(random_elements::<Digest>(num_peaks as usize), leaf_count);
            encode_to_memory(&mut memory, address, &mmra);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for BagPeaks {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let address = pop_encodable(stack);
            let mmra = *MmrAccumulator::decode_from_memory(memory, address).unwrap();

            /// Reference bagging: hash the zero-padded leaf count, then fold
            /// the peaks in from last to first with `hash_pair(peak, acc)`.
            fn bag_peaks(peaks: &[Digest], leaf_count: u64) -> Digest {
                // use `hash_10` over `hash` or `hash_varlen` to simplify hashing in Triton VM
                let [lo_limb, hi_limb] = leaf_count.encode()[..] else {
                    panic!("internal error: unknown encoding of type `u64`")
                };
                let padded_leaf_count = bfe_array![lo_limb, hi_limb, 0, 0, 0, 0, 0, 0, 0, 0];
                let hashed_leaf_count = Digest::new(Tip5::hash_10(&padded_leaf_count));

                peaks
                    .iter()
                    .rev()
                    .fold(hashed_leaf_count, |acc, &peak| Tip5::hash_pair(peak, acc))
            }

            let bag = bag_peaks(&mmra.peaks(), mmra.num_leafs());
            push_encodable(stack, &bag);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            // fixed leaf counts for the benchmarks; otherwise a random leaf
            // count below `u64::MAX >> 1`
            let num_leafs = match bench_case {
                Some(BenchmarkCase::CommonCase) => 348753,
                Some(BenchmarkCase::WorstCase) => 843759843768,
                None => StdRng::from_seed(seed).random_range(0u64..(u64::MAX >> 1)),
            };

            self.set_up_initial_state(num_leafs)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            // a leaf count of 2^k - 1 has exactly k peaks: covers the empty
            // accumulator, a handful of small peak counts, and 63 peaks
            (0..=5)
                .chain([63])
                .map(|num_peaks| self.set_up_initial_state((1 << num_peaks) - 1))
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(BagPeaks).test()
    }
}
234
#[cfg(test)]
mod benches {
    use super::BagPeaks;
    use crate::test_prelude::*;

    /// Benchmarks [`BagPeaks`] through its shadowed-function harness.
    #[test]
    fn benchmark() {
        let shadowed_bag_peaks = ShadowedFunction::new(BagPeaks);
        shadowed_bag_peaks.bench();
    }
}
244}