tasm_lib/neptune/mutator_set/
get_swbf_indices.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u128::shift_left_static;
4use crate::arithmetic::u128::shift_right_static;
5use crate::hashing::algebraic_hasher::sample_indices::SampleIndices;
6use crate::list::higher_order::inner_function::InnerFunction;
7use crate::list::higher_order::inner_function::RawCode;
8use crate::list::higher_order::map::Map;
9use crate::prelude::*;
10
/// Base-2 logarithm of the chunk size, i.e., the chunk size is `2^12 = 4096`.
///
/// Used as the static left-shift amount that turns a batch index into a
/// batch offset.
const LOG2_CHUNK_SIZE: u8 = 12;

/// Base-2 logarithm of the batch size, i.e., the batch size is `2^3 = 8`.
///
/// Used as the static right-shift amount that turns an AOCL leaf index into a
/// batch index.
const LOG2_BATCH_SIZE: u8 = 3;
13
/// Snippet deriving the SWBF indices that make up a removal record.
///
/// The indices are computed from four stack arguments: the item (a digest),
/// the sender randomness (a digest), the receiver preimage (a digest), and the
/// item's AOCL leaf index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GetSwbfIndices {
    /// Exclusive upper bound for every sampled (relative) index.
    pub window_size: u32,

    /// Number of indices to sample.
    pub num_trials: usize,
}
22
23impl BasicSnippet for GetSwbfIndices {
24    fn inputs(&self) -> Vec<(DataType, String)> {
25        vec![
26            (DataType::U64, "aocl_leaf".to_string()),
27            (DataType::Digest, "receiver_preimage".to_string()),
28            (DataType::Digest, "sender_randomness".to_string()),
29            (DataType::Digest, "item".to_string()),
30        ]
31    }
32
33    fn outputs(&self) -> Vec<(DataType, String)> {
34        vec![(DataType::VoidPointer, "*index_list".to_string())]
35    }
36
37    fn entrypoint(&self) -> String {
38        format!(
39            "tasmlib_neptune_mutator_get_swbf_indices_{}_{}",
40            self.window_size, self.num_trials
41        )
42    }
43
44    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
45        let num_trials = self.num_trials;
46        let window_size = self.window_size;
47        let sample_indices = library.import(Box::new(SampleIndices));
48
49        let entrypoint = self.entrypoint();
50
51        let map_add_batch_offset = library.import(Box::new(Map::new(InnerFunction::RawCode(
52            u32_to_u128_add_another_u128(),
53        ))));
54
55        // TODO: This can be replaced by a bit-mask to save some clock cycles
56        let divide_by_batch_size = library.import(Box::new(
57            shift_right_static::ShiftRightStatic::<LOG2_BATCH_SIZE>,
58        ));
59        let mul_by_chunk_size = library.import(Box::new(
60            shift_left_static::ShiftLeftStatic::<LOG2_CHUNK_SIZE>,
61        ));
62
63        triton_asm!(
64        // BEFORE: _ li_hi li_lo r4 r3 r2 r1 r0 s4 s3 s2 s1 s0 i4 i3 i2 i1 i0
65        // AFTER:  _ index_list
66        {entrypoint}:
67
68            sponge_init
69            sponge_absorb
70            // _ li_hi li_lo r4 r3 r2 r1 r0
71
72            push 0
73            push 0
74            // _ li_hi li_lo r4 r3 r2 r1 r0 0 0
75
76            dup 8 dup 8
77            // _ li_hi li_lo r4 r3 r2 r1 r0 0 0 li_hi li_lo
78
79            push 0
80            push 0
81            push 1
82            // _ li_hi li_lo r4 r3 r2 r1 r0 0 0 li_hi li_lo {0 0 1
83
84            dup 4 dup 4
85            // _ li_hi li_lo r4 r3 r2 r1 r0 0 0 li_hi li_lo {0 0 1 li_hi li_lo
86
87            dup 13 dup 13 dup 13 dup 13 dup 13
88            // _ li_hi li_lo r4 r3 r2 r1 r0 0 0 li_hi li_lo {0 0 1 li_hi li_lo r4 r3 r2 r1 r0}
89
90            sponge_absorb
91            // _ li_hi li_lo r4 r3 r2 r1 r0 0 0 li_hi li_lo
92
93            call {divide_by_batch_size}
94            // _ li_hi li_lo r4 r3 r2 r1 r0 (li / bs)_0 (li / bs)_1 (li / bs)_2 (li / bs)_3
95
96            call {mul_by_chunk_size}
97            // _ li_hi li_lo r4 r3 r2 r1 r0 [batch_offset_u128]
98
99            push {num_trials} // _ li_hi li_lo r4 r3 r2 r1 r0 [batch_offset_u128] number
100            push {window_size} // _ li_hi li_lo r4 r3 r2 r1 r0 [batch_offset_u128] number upper_bound
101            call {sample_indices} // _ li_hi li_lo r4 r3 r2 r1 r0 [batch_offset_u128] *list_of_indices_as_u32s
102
103
104            call {map_add_batch_offset}
105            // _ li_hi li_lo r4 r3 r2 r1 r0 [batch_offset_u128] *list_of_absolute_indices_as_u128s
106
107            swap 11 pop 5 pop 5 pop 1
108            // *list_of_absolute_indices_as_u128s
109
110            return
111        )
112    }
113}
114
115/// ```text
116/// BEFORE: _ [x_3, x_2, x_1, x_0] [bu ff er] input_u32
117/// AFTER:  _ [x_3, x_2, x_1, x_0] [bu ff er] output_3 output_2 output_1 output_0
118/// ```
119pub(crate) fn u32_to_u128_add_another_u128() -> RawCode {
120    let buffer_len = Map::NUM_INTERNAL_REGISTERS;
121    let assembly = triton_asm!(
122        u32_to_u128_add_another_u128:
123        dup {buffer_len + 1}
124        add     // _ [x_3, x_2, x_1, x_0] [bu ff er] (input_u32 + x_0)
125        split   // _ [x_3, x_2, x_1, x_0] [bu ff er] carry_to_1 output_0
126        pick 1  // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 carry_to_1
127        dup {buffer_len + 3}
128        add
129        split   // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 carry_to_2 output_1
130        pick 1  // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 output_1 carry_to_2
131        dup {buffer_len + 5}
132        add
133        split   // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 output_1 carry_to_3 output_2
134        pick 1  // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 output_1 output_2 carry_to_3
135        dup {buffer_len + 7}
136        add
137        split   // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 output_1 output_2 overflow output_3
138        pick 1  // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 output_1 output_2 output_3 overflow
139
140        // verify no overflow
141        push 0
142        eq
143        assert  // _ [x_3, x_2, x_1, x_0] [bu ff er] output_0 output_1 output_2 output_3
144        place 3
145        place 2
146        place 1 // _ [x_3, x_2, x_1, x_0] [bu ff er] output_3 output_2 output_1 output_0
147        return
148    );
149    RawCode::new(assembly, DataType::U32, DataType::U128)
150}
151
#[cfg(test)]
mod tests {
    use twenty_first::prelude::Sponge;

    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;

    /// Number of indices per removal record assumed by the reference
    /// implementation [`get_swbf_indices`].
    const NUM_TRIALS: usize = 45;

    /// Base-2 logarithm of the window size assumed by the reference
    /// implementation [`get_swbf_indices`].
    const LOG2_WINDOW_SIZE: u32 = 20;

    /// Reference implementation, copied from the mutator set implementation.
    ///
    /// It is fixed to [`NUM_TRIALS`] trials and a window size of
    /// `1 << LOG2_WINDOW_SIZE`. The snippet under test must be instantiated
    /// with the same parameters for the sanity check in `rust_shadow` to be
    /// meaningful.
    fn get_swbf_indices(
        item: &Digest,
        sender_randomness: &Digest,
        receiver_preimage: &Digest,
        aocl_leaf_index: u64,
    ) -> [u128; NUM_TRIALS] {
        let batch_index: u128 = aocl_leaf_index as u128 / (1 << LOG2_BATCH_SIZE) as u128;
        let batch_offset: u128 = batch_index * (1 << LOG2_CHUNK_SIZE) as u128;
        let leaf_index_bfes = aocl_leaf_index.encode();
        let input = [
            item.encode(),
            sender_randomness.encode(),
            receiver_preimage.encode(),
            leaf_index_bfes,
        ]
        .concat();
        let mut sponge = Tip5::init();
        sponge.pad_and_absorb_all(&input);
        sponge
            .sample_indices(1 << LOG2_WINDOW_SIZE, NUM_TRIALS)
            .into_iter()
            .map(|sample_index| sample_index as u128 + batch_offset)
            .collect_vec()
            .try_into()
            .unwrap()
    }

    impl Function for GetSwbfIndices {
        /// Rust model of the snippet.
        ///
        /// Pops the four arguments, samples the relative indices from a
        /// freshly seeded sponge, writes them to a newly allocated `u32` list,
        /// and finally writes the absolute indices to a second, newly
        /// allocated `u128` list, the pointer to which is pushed onto the
        /// stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let item = pop_encodable::<Digest>(stack);
            let sender_randomness = pop_encodable::<Digest>(stack);
            let receiver_preimage = pop_encodable::<Digest>(stack);
            let aocl_leaf_index = pop_encodable::<u64>(stack);

            let mut sponge_seed = [item, sender_randomness, receiver_preimage]
                .map(|d| d.values())
                .concat();
            sponge_seed.extend(aocl_leaf_index.encode());

            let mut sponge = Tip5::init();
            sponge.pad_and_absorb_all(&sponge_seed);

            // Rejection sampling: squeezed elements are consumed in order, and
            // the element with the maximal value is skipped.
            let mut u32_indices = vec![];
            let mut squeezed_elements = vec![];
            while u32_indices.len() != self.num_trials {
                if squeezed_elements.is_empty() {
                    squeezed_elements = sponge.squeeze().into_iter().rev().collect_vec();
                }
                let element = squeezed_elements.pop().unwrap();
                if element != BFieldElement::new(BFieldElement::MAX) {
                    u32_indices.push(element.value() as u32 % self.window_size);
                }
            }

            // The intermediate list of relative indices is written to memory,
            // too, even though only the list of absolute indices is returned.
            let u32_list_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            rust_shadowing_helper_functions::list::list_new(u32_list_pointer, memory);
            rust_shadowing_helper_functions::list::list_set_length(
                u32_list_pointer,
                self.num_trials,
                memory,
            );
            for (i, &index) in u32_indices.iter().enumerate() {
                rust_shadowing_helper_functions::list::list_set(
                    u32_list_pointer,
                    i,
                    vec![BFieldElement::new(u64::from(index))],
                    memory,
                );
            }

            // Compare the derived indices to the reference implementation
            // (copied from the mutator set implementation).
            let indices_from_mutator_set = get_swbf_indices(
                &item,
                &sender_randomness,
                &receiver_preimage,
                aocl_leaf_index,
            );

            let batch_index: u128 = aocl_leaf_index as u128 / (1 << LOG2_BATCH_SIZE) as u128;
            let batch_offset: u128 = batch_index * (1 << LOG2_CHUNK_SIZE) as u128;
            let u128_indices = u32_indices
                .into_iter()
                .map(|x| (x as u128) + batch_offset)
                .collect_vec();

            // Sanity check that this Rust shadowing agrees with the real deal
            assert_eq!(
                indices_from_mutator_set.to_vec(),
                u128_indices,
                "VM-calculated indices must match that from mutator set module"
            );

            let u128_list_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            rust_shadowing_helper_functions::list::list_insert(
                u128_list_pointer,
                u128_indices,
                memory,
            );

            stack.push(u128_list_pointer);
        }

        /// Produces a stack holding, from deepest to shallowest, a random AOCL
        /// leaf index (high limb first), receiver preimage, sender randomness,
        /// and item. Memory is left empty.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _bench_case: Option<crate::snippet_bencher::BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            // The order of the draws determines the generated values.
            let item: Digest = rng.random();
            let sender_randomness: Digest = rng.random();
            let receiver_preimage: Digest = rng.random();
            let aocl_leaf_index: u64 = rng.random();

            let mut stack = empty_stack();
            stack.push(BFieldElement::new(aocl_leaf_index >> 32));
            stack.push(BFieldElement::new(aocl_leaf_index & u64::from(u32::MAX)));

            // A digest's element 0 ends up closest to the top of the stack.
            for digest in [receiver_preimage, sender_randomness, item] {
                stack.extend(digest.values().into_iter().rev());
            }

            FunctionInitialState {
                memory: HashMap::new(),
                stack,
            }
        }
    }

    #[test]
    fn test() {
        // Use the parameters the reference implementation is fixed to.
        ShadowedFunction::new(GetSwbfIndices {
            window_size: 1 << LOG2_WINDOW_SIZE,
            num_trials: NUM_TRIALS,
        })
        .test();
    }
}
326
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark for a window size of `2^20` and 45 trials.
    #[test]
    fn benchmark() {
        let snippet = GetSwbfIndices {
            window_size: 1 << 20,
            num_trials: 45,
        };

        ShadowedFunction::new(snippet).bench();
    }
}