tasm_lib/list/higher_order/all.rs

1use itertools::Itertools;
2use triton_vm::prelude::*;
3
4use super::inner_function::InnerFunction;
5use crate::list::get::Get;
6use crate::list::length::Length;
7use crate::prelude::*;
8
9/// Runs a predicate over all elements of a list and returns true only if all elements satisfy the
10/// predicate.
11pub struct All {
12    pub f: InnerFunction,
13}
14
15impl All {
16    pub fn new(f: InnerFunction) -> Self {
17        Self { f }
18    }
19}
20
21impl BasicSnippet for All {
22    fn inputs(&self) -> Vec<(DataType, String)> {
23        let element_type = self.f.domain();
24        let list_type = DataType::List(Box::new(element_type));
25        vec![(list_type, "*input_list".to_string())]
26    }
27
28    fn outputs(&self) -> Vec<(DataType, String)> {
29        vec![(DataType::Bool, "all_true".to_string())]
30    }
31
32    fn entrypoint(&self) -> String {
33        format!("tasmlib_list_higher_order_u32_all_{}", self.f.entrypoint())
34    }
35
36    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
37        let input_type = self.f.domain();
38        let output_type = self.f.range();
39        assert_eq!(output_type, DataType::Bool);
40
41        let get_length = library.import(Box::new(Length));
42        let list_get = library.import(Box::new(Get::new(input_type)));
43
44        let inner_function_name = match &self.f {
45            InnerFunction::RawCode(rc) => rc.entrypoint(),
46            InnerFunction::NoFunctionBody(_) => todo!(),
47            InnerFunction::BasicSnippet(bs) => {
48                let labelled_instructions = bs.annotated_code(library);
49                library.explicit_import(&bs.entrypoint(), &labelled_instructions)
50            }
51        };
52
53        // If function was supplied as raw instructions, we need to append the inner function to the function
54        // body. Otherwise, `library` handles the imports.
55        let maybe_inner_function_body_raw = match &self.f {
56            InnerFunction::RawCode(rc) => rc.function.iter().map(|x| x.to_string()).join("\n"),
57            InnerFunction::NoFunctionBody(_) => todo!(),
58            InnerFunction::BasicSnippet(_) => Default::default(),
59        };
60        let entrypoint = self.entrypoint();
61        let main_loop = format!("{entrypoint}_loop");
62
63        let result_type_hint = format!("hint all_{}: Boolean = stack[0]", self.f.entrypoint());
64
65        triton_asm!(
66            // BEFORE: _ input_list
67            // AFTER:  _ result
68            {entrypoint}:
69                hint input_list = stack[0]
70                push 1  // _ input_list res
71                {result_type_hint}
72                swap 1  // _ res input_list
73                dup 0   // _ res input_list input_list
74                call {get_length}
75                hint list_item: Index = stack[0]
76                        // _ res input_list len
77
78                call {main_loop}
79                        // _ res input_list 0
80
81                pop 2   // _ res
82                return
83
84            // INVARIANT: _ res input_list index
85            {main_loop}:
86                // test return condition
87                dup 0 push 0 eq
88                        // _ res input_list index index==0
89
90                skiz return
91                        // _ res input_list index
92
93                // decrement index
94                push -1 add
95
96                // body
97
98                // read
99                dup 1 dup 1
100                        // _ res input_list index input_list index
101                call {list_get}
102                        // _ res input_list index [input_elements]
103
104                // compute predicate
105                call {inner_function_name}
106                        // _ res input_list index b
107
108                // accumulate
109                dup 3   // _ res input_list index b res
110                mul     // _ res input_list index (b && res)
111                swap 3  // _ (b && res) input_list index res
112                pop 1   // _ (b && res) input_list index
113
114                recurse
115
116            {maybe_inner_function_body_raw}
117        )
118    }
119}
120
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::arithmetic;
    use crate::empty_stack;
    use crate::list::LIST_METADATA_SIZE;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl All {
        /// Builds an initial VM state with `list_pointer` on top of the stack and a list of
        /// `list_length` elements stored at `list_pointer`.
        ///
        /// With `random == true`, the elements are random and sized by the inner function's
        /// domain type. Otherwise, the list holds the `BFieldElement`s `0, 1, …, list_length - 1`.
        ///
        /// NOTE(review): the non-random branch always writes single `BFieldElement`s, regardless
        /// of `self.f.domain()`. That only matches the element layout for domain types of stack
        /// size 1. Confirm this is acceptable for the benchmark cases.
        fn generate_input_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
            random: bool,
        ) -> InitVmState {
            let mut stack = empty_stack();
            stack.push(list_pointer);

            let mut memory = HashMap::default();
            let input_type = self.f.domain();
            let list_bookkeeping_offset = LIST_METADATA_SIZE;
            let element_index_in_list =
                list_bookkeeping_offset + list_length * input_type.stack_size();
            let element_index = list_pointer + BFieldElement::new(element_index_in_list as u64);
            // Stores the address one past the list's last element at RAM address 0.
            // NOTE(review): presumably legacy dynamic-allocator state; confirm it is still needed.
            memory.insert(BFieldElement::zero(), element_index);

            if random {
                untyped_insert_random_list(
                    list_pointer,
                    list_length,
                    &mut memory,
                    input_type.stack_size(),
                );
            } else {
                rust_shadowing_helper_functions::list::list_insert(
                    list_pointer,
                    (0..list_length as u64)
                        .map(BFieldElement::new)
                        .collect_vec(),
                    &mut memory,
                );
            }

            InitVmState::with_stack_and_memory(stack, memory)
        }
    }

    impl Function for All {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let input_type = self.f.domain();
            let list_pointer = stack.pop().unwrap();

            // Apply the predicate to every element and AND the results together. An empty list
            // leaves `satisfied` at its initial `true`.
            let list_length =
                rust_shadowing_helper_functions::list::list_get_length(list_pointer, memory);
            let mut satisfied = true;
            for i in 0..list_length {
                let input_item = list_get(list_pointer, i, memory, input_type.stack_size());
                for bfe in input_item.into_iter().rev() {
                    stack.push(bfe);
                }

                self.f.apply(stack, memory);

                let single_result = stack.pop().unwrap().value() != 0;
                satisfied = satisfied && single_result;
            }

            stack.push(BFieldElement::new(satisfied as u64));
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let (stack, memory) = match bench_case {
                Some(BenchmarkCase::CommonCase) => {
                    let list_pointer = BFieldElement::new(5);
                    let list_length = 10;
                    let execution_state =
                        self.generate_input_state(list_pointer, list_length, false);
                    (execution_state.stack, execution_state.nondeterminism.ram)
                }
                Some(BenchmarkCase::WorstCase) => {
                    let list_pointer = BFieldElement::new(5);
                    let list_length = 100;
                    let execution_state =
                        self.generate_input_state(list_pointer, list_length, false);
                    (execution_state.stack, execution_state.nondeterminism.ram)
                }
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let list_pointer = BFieldElement::new(rng.next_u64() % (1 << 20));
                    // lengths 1, 2, 4, or 8; the empty list is covered by a dedicated test
                    let list_length = 1 << (rng.next_u32() as usize % 4);
                    let execution_state =
                        self.generate_input_state(list_pointer, list_length, true);
                    (execution_state.stack, execution_state.nondeterminism.ram)
                }
            };

            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn rust_shadow() {
        let inner_function = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
        ShadowedFunction::new(All::new(inner_function)).test();
    }

    #[test]
    fn all_lt_test() {
        const TWO_POW_31: u64 = 1u64 << 31;
        let rawcode = RawCode::new(
            triton_asm!(
                less_than_2_pow_31:
                    push 2147483648 // == 2^31
                    swap 1
                    lt
                    return
            ),
            DataType::Bfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        let mut memory = HashMap::new();

        // Should return true
        rust_shadowing_helper_functions::list::list_insert(
            BFieldElement::new(42),
            (0..30).map(BFieldElement::new).collect_vec(),
            &mut memory,
        );
        let input_stack = [empty_stack(), vec![BFieldElement::new(42)]].concat();
        let expected_end_stack_true = [empty_stack(), vec![BFieldElement::one()]].concat();
        let shadowed_snippet = ShadowedFunction::new(snippet);
        let mut nondeterminism = NonDeterminism::default().with_ram(memory);
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack_true),
        );

        // Should return false
        rust_shadowing_helper_functions::list::list_insert(
            BFieldElement::new(42),
            (0..30)
                .map(|x| BFieldElement::new(x + TWO_POW_31 - 20))
                .collect_vec(),
            &mut nondeterminism.ram,
        );
        let expected_end_stack_false = [empty_stack(), vec![BFieldElement::zero()]].concat();
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack_false),
        );
    }

    /// `all` over an empty list is vacuously true, even for a predicate that is almost never
    /// satisfied.
    #[test]
    fn all_on_empty_list_is_true() {
        let raw_code = RawCode::new(
            triton_asm!(
                eq_42:
                push 42
                eq
                return
            ),
            DataType::U32,
            DataType::Bool,
        );
        let shadowed_snippet = ShadowedFunction::new(All::new(InnerFunction::RawCode(raw_code)));

        let list_pointer = BFieldElement::new(42);
        let mut memory = HashMap::new();
        rust_shadowing_helper_functions::list::list_insert(
            list_pointer,
            Vec::<BFieldElement>::new(),
            &mut memory,
        );
        let input_stack = [empty_stack(), vec![list_pointer]].concat();
        let expected_end_stack = [empty_stack(), vec![BFieldElement::one()]].concat();
        let nondeterminism = NonDeterminism::default().with_ram(memory);
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack),
        );
    }

    #[test]
    fn test_with_raw_function_lsb_on_bfe() {
        let rawcode = RawCode::new(
            triton_asm!(
                lsb_bfe:
                split    // _ hi lo
                push 2   // _ hi lo 2
                swap 1   // _ hi 2 lo
                div_mod  // _ hi q r
                swap 2   // _ r q hi
                pop 2    // _ r
                return
            ),
            DataType::Bfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        ShadowedFunction::new(snippet).test();
    }

    #[test]
    fn test_with_raw_function_eq_42() {
        let raw_code = RawCode::new(
            triton_asm!(
                eq_42:
                push 42
                eq
                return
            ),
            DataType::U32,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(raw_code));
        ShadowedFunction::new(snippet).test();
    }

    #[test]
    fn test_with_raw_function_lsb_on_xfe() {
        let rawcode = RawCode::new(
            triton_asm!(
                lsb_xfe:
                split    // _ x2 x1 hi lo
                push 2   // _ x2 x1 hi lo 2
                swap 1   // _ x2 x1 hi 2 lo
                div_mod  // _ x2 x1 hi q r
                swap 4   // _ r x1 q hi x2
                pop 4    // _ r x1 q hi
                return
            ),
            DataType::Xfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        ShadowedFunction::new(snippet).test();
    }

    /// Only used for tests. Please don't export this.
    #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
    pub(super) struct TestHashXFieldElementLsb;

    impl BasicSnippet for TestHashXFieldElementLsb {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "element".to_string())]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bool, "bool".to_string())]
        }

        fn entrypoint(&self) -> String {
            "test_hash_xfield_element_lsb".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
            triton_asm!(
            // BEFORE: _ [x: XFieldElement]
            // AFTER:  _ [b: bool]
            {entrypoint}:
                /* Useless additions: ensure that dependencies are accepted inside
                 * the generated code of `all`
                 */
                push 0
                push 0
                call {unused_import}
                pop 1

                push 0
                push 0
                push 0
                push 0
                push 0
                push 0
                push 1  // _ x2 x1 x0 0 0 0 0 0 0 1
                pick 9
                pick 9
                pick 9  // _ 0 0 0 0 0 0 1 x2 x1 x0

                sponge_init
                sponge_absorb
                sponge_squeeze
                        // _ [d; 10]

                split
                push 2
                place 1
                div_mod // _ [d'; 9] d0_hi (d0_lo // 2) (d0_lo % 2)

                place 11
                pop 5
                pop 5
                pop 1   // _ (d0_lo % 2)

                return
            )
        }
    }
}
418
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElementLsb;
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks `All` using the hashing-based test predicate `TestHashXFieldElementLsb`.
    #[test]
    fn benchmark() {
        let predicate = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
        let snippet = All::new(predicate);
        ShadowedFunction::new(snippet).bench();
    }
}
430}