tasm_lib/mmr/
bag_peaks.rs1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4use triton_vm::twenty_first::util_types::mmr::mmr_accumulator::MmrAccumulator;
5
6use crate::arithmetic;
7use crate::prelude::*;
8use crate::traits::basic_snippet::Reviewer;
9use crate::traits::basic_snippet::SignOffFingerprint;
10
/// Bags the peaks of an `MmrAccumulator` into a single `Digest`.
///
/// The accumulator's leaf count (a `u64`, i.e., two limbs) is padded with eight
/// zeros and hashed; that digest is the initial accumulator. Then, starting at
/// the last peak and moving toward the first, every peak is folded into the
/// accumulator, mirroring `Tip5::hash_pair(peak, accumulator)`.
///
/// The VM crashes with error ID `INCONSISTENT_NUM_PEAKS_ERROR_ID` (560) if the
/// number of stored peaks differs from the number of set bits in the leaf
/// count.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *mmra
/// AFTER:  _ [digest: Digest]
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BagPeaks;

impl BagPeaks {
    /// Error ID raised when the accumulator's peak count is not the popcount
    /// of its leaf count.
    const INCONSISTENT_NUM_PEAKS_ERROR_ID: usize = 560;
}
44
// Calling convention of the snippet: consumes a pointer to an MMR accumulator
// and leaves the bagged peaks (one digest) in its place.
impl BasicSnippet for BagPeaks {
    /// A single argument: the pointer `*mmra`.
    fn inputs(&self) -> Vec<(DataType, String)> {
        // NOTE(review): `*mmra` is handed to `MmrAccumulator::destructure()` in
        // `code`, but its declared type here is `List<Digest>` – confirm this is
        // the intended type annotation.
        let mmr_accumulator = DataType::List(Box::new(DataType::Digest));

        vec![(mmr_accumulator, "*mmra".to_string())]
    }

    /// A single result: the digest obtained by bagging the peaks.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Digest, "digest".to_owned())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_mmr_bag_peaks".to_string()
    }

    /// Emits the bagging code.
    ///
    /// Crashes with `INCONSISTENT_NUM_PEAKS_ERROR_ID` if the length of the peak
    /// list differs from the popcount of the leaf count.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let bagging_loop = format!("{entrypoint}_loop");

        let destructure_mmra = MmrAccumulator::destructure();
        let pop_count = library.import(Box::new(arithmetic::u64::popcount::PopCount));

        triton_asm!(
            // BEFORE: _ *mmra
            // AFTER:  _ [digest: Digest]
            {entrypoint}:
                {&destructure_mmra}
                // _ *peaks *leaf_count
                // (the layout the instructions below rely on)

                addi 1 read_mem 2 pop 1
                // _ *peaks [leaf_count: u64]

                hint leaf_count : u64 = stack[0..2]
                dup 1 dup 1
                // _ *peaks [leaf_count: u64] [leaf_count: u64]

                call {pop_count}
                // _ *peaks [leaf_count: u64] popcount

                dup 3 read_mem 1 pop 1
                // _ *peaks [leaf_count: u64] popcount num_peaks
                // (num_peaks is the list length stored at *peaks)

                dup 1 eq
                // _ *peaks [leaf_count: u64] popcount (num_peaks == popcount)

                assert error_id {Self::INCONSISTENT_NUM_PEAKS_ERROR_ID}
                // _ *peaks [leaf_count: u64] num_peaks

                hint num_peaks: u32 = stack[0]
                place 2
                // _ *peaks num_peaks [leaf_count: u64]

                // Initial accumulator: the leaf count padded with 8 zeros, hashed.
                push 0
                push 0
                push 0
                push 0
                push 0
                push 0
                push 0
                push 0
                pick 9
                pick 9
                // _ *peaks num_peaks 0 0 0 0 0 0 0 0 [leaf_count: u64]

                hash
                // _ *peaks num_peaks [acc: Digest]

                // Pointer to the last word of the last peak:
                // *peaks + num_peaks · Digest::LEN
                pick 5
                push {Digest::LEN}
                mul
                dup 6
                add
                place 5
                // _ *peaks *last_word [acc: Digest]

                // Enter the loop only if *last_word != *peaks, i.e., num_peaks != 0.
                dup 6 dup 6 eq push 0 eq
                skiz call {bagging_loop}
                // _ *peaks *peaks [bag: Digest]

                pick 6 pick 6 pop 2
                // _ [bag: Digest]

                return

            // Walks the peaks from last to first, folding each into the accumulator.
            // INVARIANT: _ *peaks *peak_last_word [acc: Digest]
            {bagging_loop}:
                pick 5
                // _ *peaks [acc: Digest] *peak_last_word

                read_mem {Digest::LEN}
                // _ *peaks [acc: Digest] [peak: Digest] *previous_peak_last_word

                place 10
                // _ *peaks *previous_peak_last_word [acc: Digest] [peak: Digest]

                hash
                // _ *peaks *previous_peak_last_word [acc': Digest]

                // Per the Triton VM ISA, this returns once stack[5] equals stack[6],
                // i.e., once the read pointer has reached *peaks.
                recurse_or_return
        )
    }

    /// No reviewer sign-offs are recorded for this snippet.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        [].into()
    }
}
152
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use triton_vm::twenty_first::prelude::Mmr;
    use twenty_first::math::other::random_elements;

    use super::*;
    use crate::test_prelude::*;

    impl BagPeaks {
        /// Builds an initial state: a random `MmrAccumulator` with `leaf_count`
        /// leaves (hence `leaf_count.count_ones()` random peaks) is encoded to a
        /// random address, and that address is pushed onto the stack.
        fn set_up_initial_state(&self, leaf_count: u64) -> FunctionInitialState {
            let address = rand::random();
            let mut stack = self.init_stack_for_isolated_run();
            push_encodable(&mut stack, &address);

            let mut memory = HashMap::new();
            let num_peaks = leaf_count.count_ones();
            let mmra =
                MmrAccumulator::init(random_elements::<Digest>(num_peaks as usize), leaf_count);
            encode_to_memory(&mut memory, address, &mmra);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for BagPeaks {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let address = pop_encodable(stack);
            let mmra = *MmrAccumulator::decode_from_memory(memory, address).unwrap();

            /// Reference implementation of peak bagging: hash the zero-padded
            /// leaf count, then fold the peaks in from last to first.
            fn bag_peaks(peaks: &[Digest], leaf_count: u64) -> Digest {
                let [lo_limb, hi_limb] = leaf_count.encode()[..] else {
                    panic!("internal error: unknown encoding of type `u64`")
                };
                let padded_leaf_count = bfe_array![lo_limb, hi_limb, 0, 0, 0, 0, 0, 0, 0, 0];
                let hashed_leaf_count = Digest::new(Tip5::hash_10(&padded_leaf_count));

                peaks
                    .iter()
                    .rev()
                    .fold(hashed_leaf_count, |acc, &peak| Tip5::hash_pair(peak, acc))
            }

            // No debug printing here: the shadow runs in every test and benchmark.
            let bag = bag_peaks(&mmra.peaks(), mmra.num_leafs());
            push_encodable(stack, &bag);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let num_leafs = match bench_case {
                Some(BenchmarkCase::CommonCase) => 348753,
                Some(BenchmarkCase::WorstCase) => 843759843768,
                None => StdRng::from_seed(seed).random_range(0u64..(u64::MAX >> 1)),
            };

            self.set_up_initial_state(num_leafs)
        }

        /// `2^k - 1` leaves yield exactly `k` peaks; this covers the empty MMR
        /// (k = 0), a few small peak counts, and the maximum of 63 peaks.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            (0..=5)
                .chain([63])
                .map(|num_peaks| self.set_up_initial_state((1 << num_peaks) - 1))
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(BagPeaks).test()
    }
}
234
#[cfg(test)]
mod benches {
    use super::BagPeaks;
    use crate::test_prelude::*;

    /// Runs the benchmark harness on the shadowed `BagPeaks` snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedFunction::new(BagPeaks);
        shadowed_snippet.bench();
    }
}