Skip to main content

tasm_lib/list/higher_order/
all.rs

1use itertools::Itertools;
2use triton_vm::prelude::*;
3
4use super::inner_function::InnerFunction;
5use crate::list::get::Get;
6use crate::list::length::Length;
7use crate::prelude::*;
8
/// Runs a predicate over all elements of a list and returns true only if all elements satisfy the
/// predicate.
///
/// The predicate's domain determines the list's element type, and its range must be
/// [`DataType::Bool`] (this is asserted when the snippet's code is generated). An empty list
/// yields `true`.
pub struct All {
    /// The predicate applied to every element of the input list.
    pub f: InnerFunction,
}
14
15impl All {
16    pub fn new(f: InnerFunction) -> Self {
17        Self { f }
18    }
19}
20
21impl BasicSnippet for All {
22    fn parameters(&self) -> Vec<(DataType, String)> {
23        let element_type = self.f.domain();
24        let list_type = DataType::List(Box::new(element_type));
25        vec![(list_type, "*input_list".to_string())]
26    }
27
28    fn return_values(&self) -> Vec<(DataType, String)> {
29        vec![(DataType::Bool, "all_true".to_string())]
30    }
31
32    fn entrypoint(&self) -> String {
33        format!("tasmlib_list_higher_order_u32_all_{}", self.f.entrypoint())
34    }
35
36    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
37        let input_type = self.f.domain();
38        let output_type = self.f.range();
39        assert_eq!(output_type, DataType::Bool);
40
41        let get_length = library.import(Box::new(Length));
42        let list_get = library.import(Box::new(Get::new(input_type)));
43
44        let inner_function_name = match &self.f {
45            InnerFunction::RawCode(rc) => rc.entrypoint(),
46            InnerFunction::NoFunctionBody(_) => todo!(),
47            InnerFunction::BasicSnippet(bs) => {
48                let labelled_instructions = bs.annotated_code(library);
49                library.explicit_import(&bs.entrypoint(), &labelled_instructions)
50            }
51        };
52
53        // If function was supplied as raw instructions, we need to append the inner function to the function
54        // body. Otherwise, `library` handles the imports.
55        let maybe_inner_function_body_raw = match &self.f {
56            InnerFunction::RawCode(rc) => rc.function.iter().map(|x| x.to_string()).join("\n"),
57            InnerFunction::NoFunctionBody(_) => todo!(),
58            InnerFunction::BasicSnippet(_) => Default::default(),
59        };
60        let entrypoint = self.entrypoint();
61        let main_loop = format!("{entrypoint}_loop");
62
63        let result_type_hint = format!("hint all_{}: Boolean = stack[0]", self.f.entrypoint());
64
65        triton_asm!(
66            // BEFORE: _ input_list
67            // AFTER:  _ result
68            {entrypoint}:
69                hint input_list = stack[0]
70                push 1  // _ input_list res
71                {result_type_hint}
72                swap 1  // _ res input_list
73                dup 0   // _ res input_list input_list
74                call {get_length}
75                hint list_item: Index = stack[0]
76                        // _ res input_list len
77
78                call {main_loop}
79                        // _ res input_list 0
80
81                pop 2   // _ res
82                return
83
84            // INVARIANT: _ res input_list index
85            {main_loop}:
86                // test return condition
87                dup 0 push 0 eq
88                        // _ res input_list index index==0
89
90                skiz return
91                        // _ res input_list index
92
93                // decrement index
94                push -1 add
95
96                // body
97
98                // read
99                dup 1 dup 1
100                        // _ res input_list index input_list index
101                call {list_get}
102                        // _ res input_list index [input_elements]
103
104                // compute predicate
105                call {inner_function_name}
106                        // _ res input_list index b
107
108                // accumulate
109                dup 3   // _ res input_list index b res
110                mul     // _ res input_list index (b && res)
111                swap 3  // _ (b && res) input_list index res
112                pop 1   // _ (b && res) input_list index
113
114                recurse
115
116            {maybe_inner_function_body_raw}
117        )
118    }
119}
120
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::arithmetic;
    use crate::empty_stack;
    use crate::list::LIST_METADATA_SIZE;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl All {
        /// Builds an initial VM state with `list_pointer` on top of the stack and a list of
        /// `list_length` elements stored in memory at `list_pointer`.
        ///
        /// If `random` is true, the list is filled via `untyped_insert_random_list` using the
        /// predicate's element size. Otherwise it is filled with the `BFieldElement`s
        /// `0, 1, ..., list_length - 1`.
        fn generate_input_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
            random: bool,
        ) -> InitVmState {
            let mut stack = empty_stack();
            stack.push(list_pointer);

            let mut memory = HashMap::default();
            let input_type = self.f.domain();
            let list_bookkeeping_offset = LIST_METADATA_SIZE;
            let element_index_in_list =
                list_bookkeeping_offset + list_length * input_type.stack_size();
            let element_index = list_pointer + BFieldElement::new(element_index_in_list as u64);
            // Store, at RAM address 0, the address just past the list's metadata and its
            // `list_length` elements.
            // NOTE(review): presumably bookkeeping for a dynamic allocator -- confirm.
            memory.insert(BFieldElement::zero(), element_index);

            if random {
                untyped_insert_random_list(
                    list_pointer,
                    list_length,
                    &mut memory,
                    input_type.stack_size(),
                );
            } else {
                // NOTE(review): this branch inserts one `BFieldElement` per list element,
                // independent of `input_type.stack_size()`. For wider element types (e.g. the
                // `Xfe` predicate used by the benchmark) the stored data likely does not match
                // the element layout the snippet reads -- confirm this is intended.
                rust_shadowing_helper_functions::list::list_insert(
                    list_pointer,
                    (0..list_length as u64)
                        .map(BFieldElement::new)
                        .collect_vec(),
                    &mut memory,
                );
            }

            InitVmState::with_stack_and_memory(stack, memory)
        }
    }

    impl Function for All {
        /// Rust reference implementation.
        ///
        /// Pops the list pointer, applies the predicate to every element, and pushes `1` iff
        /// every predicate result was non-zero. Elements are visited front to back here, while
        /// the generated assembly visits them back to front.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let input_type = self.f.domain();
            let list_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            // For every element: read it, apply the predicate, and fold the result into
            // `satisfied`.
            let list_length =
                rust_shadowing_helper_functions::list::list_get_length(list_pointer, memory)?;
            let mut satisfied = true;
            for i in 0..list_length {
                let input_item = list_get(list_pointer, i, memory, input_type.stack_size())?;
                for bfe in input_item.into_iter().rev() {
                    stack.push(bfe);
                }

                self.f.apply(stack, memory);

                let single_result =
                    stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() != 0;
                satisfied = satisfied && single_result;
            }

            stack.push(BFieldElement::new(satisfied as u64));
            Ok(())
        }

        /// Benchmark cases use a list at address 5 with deterministic contents (10 elements
        /// for the common case, 100 for the worst case). Without a benchmark case, the list
        /// pointer is drawn below 2^20 and the length from {1, 2, 4, 8}, with random contents.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let (stack, memory) = match bench_case {
                Some(BenchmarkCase::CommonCase) => {
                    let list_pointer = BFieldElement::new(5);
                    let list_length = 10;
                    let execution_state =
                        self.generate_input_state(list_pointer, list_length, false);
                    (execution_state.stack, execution_state.nondeterminism.ram)
                }
                Some(BenchmarkCase::WorstCase) => {
                    let list_pointer = BFieldElement::new(5);
                    let list_length = 100;
                    let execution_state =
                        self.generate_input_state(list_pointer, list_length, false);
                    (execution_state.stack, execution_state.nondeterminism.ram)
                }
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let list_pointer = BFieldElement::new(rng.next_u64() % (1 << 20));
                    let list_length = 1 << (rng.next_u32() as usize % 4);
                    let execution_state =
                        self.generate_input_state(list_pointer, list_length, true);
                    (execution_state.stack, execution_state.nondeterminism.ram)
                }
            };

            FunctionInitialState { stack, memory }
        }
    }

    /// Property test: `All` with a `BasicSnippet` predicate matches its Rust shadow.
    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        let inner_function = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
        ShadowedFunction::new(All::new(inner_function)).test();
    }

    /// `All(x < 2^31)` is true for `0..30` and false for a list that straddles 2^31.
    #[macro_rules_attr::apply(test)]
    fn all_lt_test() {
        const TWO_POW_31: u64 = 1u64 << 31;
        let rawcode = RawCode::new(
            triton_asm!(
                less_than_2_pow_31:
                    push 2147483648 // == 2^31
                    swap 1
                    lt
                    return
            ),
            DataType::Bfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        let mut memory = HashMap::new();

        // Should return true
        rust_shadowing_helper_functions::list::list_insert(
            BFieldElement::new(42),
            (0..30).map(BFieldElement::new).collect_vec(),
            &mut memory,
        );
        let input_stack = [empty_stack(), vec![BFieldElement::new(42)]].concat();
        let expected_end_stack_true = [empty_stack(), vec![BFieldElement::one()]].concat();
        let shadowed_snippet = ShadowedFunction::new(snippet);
        let mut nondeterminism = NonDeterminism::default().with_ram(memory);
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack_true),
        );

        // Should return false: the last 10 elements are >= 2^31
        rust_shadowing_helper_functions::list::list_insert(
            BFieldElement::new(42),
            (0..30)
                .map(|x| BFieldElement::new(x + TWO_POW_31 - 20))
                .collect_vec(),
            &mut nondeterminism.ram,
        );
        let expected_end_stack_false = [empty_stack(), vec![BFieldElement::zero()]].concat();
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack_false),
        );
    }

    /// Property test with a raw-code predicate: least significant bit of a `Bfe`'s low half.
    #[macro_rules_attr::apply(test)]
    fn test_with_raw_function_lsb_on_bfe() {
        let rawcode = RawCode::new(
            triton_asm!(
                lsb_bfe:
                split    // _ hi lo
                push 2   // _ hi lo 2
                swap 1   // _ hi 2 lo
                div_mod  // _ hi q r
                swap 2   // _ r q hi
                pop 2    // _ r
                return
            ),
            DataType::Bfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        ShadowedFunction::new(snippet).test();
    }

    /// Property test with a raw-code predicate on `U32` elements: `x == 42`.
    #[macro_rules_attr::apply(test)]
    fn test_with_raw_function_eq_42() {
        let raw_code = RawCode::new(
            triton_asm!(
                eq_42:
                push 42
                eq
                return
            ),
            DataType::U32,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(raw_code));
        ShadowedFunction::new(snippet).test();
    }

    /// Property test with a raw-code predicate on `Xfe` elements: least significant bit of the
    /// low half of the top coefficient.
    #[macro_rules_attr::apply(test)]
    fn test_with_raw_function_lsb_on_xfe() {
        let rawcode = RawCode::new(
            triton_asm!(
                lsb_xfe:
                split    // _ x2 x1 hi lo
                push 2   // _ x2 x1 hi lo 2
                swap 1   // _ x2 x1 hi 2 lo
                div_mod  // _ x2 x1 hi q r
                swap 4   // _ r x1 hi q x2
                pop 4    // _ r
                return
            ),
            DataType::Xfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        ShadowedFunction::new(snippet).test();
    }

    /// Only used for tests. Please don't export this.
    ///
    /// Predicate on an `XFieldElement`: absorbs it (padded to 10 words) into a fresh sponge,
    /// squeezes, and returns the least significant bit of the low half of the top squeezed word.
    #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
    pub(super) struct TestHashXFieldElementLsb;

    impl BasicSnippet for TestHashXFieldElementLsb {
        fn parameters(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "element".to_string())]
        }

        fn return_values(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bool, "bool".to_string())]
        }

        fn entrypoint(&self) -> String {
            "test_hash_xfield_element_lsb".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
            triton_asm!(
            // BEFORE: _ [x: XFieldElement]
            // AFTER:  _ [b: bool]
            {entrypoint}:
                /* Useless additions: ensure that dependencies are accepted inside
                 * the generated code of `all`
                 */
                push 0
                push 0
                call {unused_import}
                pop 1

                push 0
                push 0
                push 0
                push 0
                push 0
                push 0
                push 1  // _ x2 x1 x0 0 0 0 0 0 0 1
                pick 9
                pick 9
                pick 9  // _ 0 0 0 0 0 0 1 x2 x1 x0

                sponge_init
                sponge_absorb
                sponge_squeeze
                        // _ [d; 10]

                split
                push 2
                place 1
                div_mod // _ [d'; 9] d0_hi (d0_lo // 2) (d0_lo % 2)

                place 11
                pop 5
                pop 5
                pop 1   // _ (d0_lo % 2)

                return
            )
        }
    }
}
420
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElementLsb;
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks `All` with the sponge-based `XFieldElement` LSB predicate.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let predicate = TestHashXFieldElementLsb;
        let snippet = All::new(InnerFunction::BasicSnippet(Box::new(predicate)));
        ShadowedFunction::new(snippet).bench();
    }
}
428    fn benchmark() {
429        let inner_function = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
430        ShadowedFunction::new(All::new(inner_function)).bench();
431    }
432}