// tasm_lib/list/higher_order/map.rs

1use itertools::Itertools;
2use strum::EnumCount;
3use tasm_lib::list::higher_order::inner_function::InnerFunction;
4use tasm_lib::structure::tasm_object::DEFAULT_MAX_DYN_FIELD_SIZE;
5use triton_vm::isa::op_stack::OpStackElement;
6use triton_vm::prelude::*;
7
8use crate::list::new::New;
9use crate::list::push::Push;
10use crate::prelude::*;
11
/// Panic message for an inner function (given as a snippet) that does not take
/// exactly one input.
const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* input. \
                                             Use a tuple as a workaround.";

/// Panic message for an inner function whose input type has dynamic length but
/// is not a pair whose *second* component is a `bfe`. The map places an element
/// pointer followed by the element's length on the stack, so the length (a
/// single `bfe`) must be the second tuple component.
const INNER_FN_INCORRECT_INPUT_DYN_LEN: &str = "An input type of dynamic length to `map`s inner \
                                                function must be a tuple of form `(_, bfe)`.";
16
17/// Applies a given function to every element of a list, and collects the new
18/// elements into a new list.
19///
20/// Mapping over multiple input lists into one output list, effectively chaining
21/// inputs before applying the map, is possible with [`ChainMap`]. See there for
22/// extended documentation.
23pub type Map = ChainMap<1>;
24
25/// Applies a given function `f` to every element of all given lists, collecting
26/// the new elements into a new list.
27///
28/// The given function `f` must produce elements of a type for which the encoded
29/// length is [statically known][len]. The input type may have either
30/// statically or dynamically known length:
31/// - In the static case, the entire element is placed on the stack before
32///   passing control to `f`.
33/// - In the dynamic case, a memory pointer to the encoded element and the
34///   item's length is placed on the stack before passing control to `f`. The
35///   input list **must** be encoded according to [`BFieldCodec`]. Otherwise,
36///   behavior of `ChainMap` is undefined!
37///
38/// The stack layout is independent of the list currently being processed. This
39/// allows the [`InnerFunction`] `f` to use runtime parameters from the stack.
40/// Note that the chain map requires a certain number of stack registers for
41/// internal purposes. This number can be accessed through
42/// [`ChainMap::NUM_INTERNAL_REGISTERS`]. As mentioned above, the stack layout
43/// upon starting execution of `f` depends on the input type's
44/// [static length][len]. In the static case, the stack layout is:
45///
46/// ```txt
47/// // _ <accessible> [_; ChainMap::<N>::NUM_INTERNAL_REGISTERS] [input_element; len]
48/// ```
49///
50/// In the case of input elements with a dynamic length, the stack layout is:
51///
52/// ```txt
53/// // _ <accessible> [_; ChainMap::<N>::NUM_INTERNAL_REGISTERS] *elem_i elem_i_len
54/// ```
55///
56/// [len]: BFieldCodec::static_length
57pub struct ChainMap<const NUM_INPUT_LISTS: usize> {
58    f: InnerFunction,
59}
60
61impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
62    /// The number of registers required internally. See [`ChainMap`] for additional
63    /// details.
64    pub const NUM_INTERNAL_REGISTERS: usize = {
65        assert!(NUM_INPUT_LISTS <= Self::MAX_NUM_INPUT_LISTS);
66
67        3 + NUM_INPUT_LISTS
68    };
69
70    /// Need access to all lists, plus a little wiggle room.
71    const MAX_NUM_INPUT_LISTS: usize = OpStackElement::COUNT - 1;
72
73    /// # Panics
74    ///
75    /// - if the input type has [static length] _and_ takes up
76    ///   [`OpStackElement::COUNT`] or more words
77    /// - if the input type has dynamic length and is _anything but_ a tuple
78    ///   `(_, `[`BFieldElement`][bfe]`)`
79    /// - if the output type takes up [`OpStackElement::COUNT`]` - 1` or more words
80    /// - if the output type does not have a [static length][len]
81    ///
82    /// [len]: BFieldCodec::static_length
83    /// [bfe]: DataType::Bfe
84    pub fn new(f: InnerFunction) -> Self {
85        let domain = f.domain();
86        if let Some(input_len) = domain.static_length() {
87            // need instruction `place {input_type.stack_size()}`
88            assert!(input_len < OpStackElement::COUNT);
89        } else {
90            let DataType::Tuple(tuple) = domain else {
91                panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
92            };
93            let [_, DataType::Bfe] = tuple[..] else {
94                panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
95            };
96        }
97
98        // need instruction `pick {output_type.stack_size() + 1}`
99        let output_len = f
100            .range()
101            .static_length()
102            .expect("output type's encoding length must be static");
103        assert!(output_len + 1 < OpStackElement::COUNT);
104
105        Self { f }
106    }
107}
108
109impl<const NUM_INPUT_LISTS: usize> BasicSnippet for ChainMap<NUM_INPUT_LISTS> {
110    fn inputs(&self) -> Vec<(DataType, String)> {
111        let list_type = DataType::List(Box::new(self.f.domain()));
112
113        (0..NUM_INPUT_LISTS)
114            .map(|i| (list_type.clone(), format!("*input_list_{i}")))
115            .collect_vec()
116    }
117
118    fn outputs(&self) -> Vec<(DataType, String)> {
119        let list_type = DataType::List(Box::new(self.f.range()));
120        vec![(list_type, "*output_list".to_string())]
121    }
122
123    fn entrypoint(&self) -> String {
124        let maybe_chain_surely_map = if NUM_INPUT_LISTS == 1 {
125            "map".to_string()
126        } else {
127            format!("chain_map_{NUM_INPUT_LISTS}")
128        };
129
130        let f_label = self.f.entrypoint();
131        format!("tasmlib_list_higher_order_u32_{maybe_chain_surely_map}_{f_label}")
132    }
133
134    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
135        if self.f.domain().static_length().is_some() {
136            self.code_for_static_len_input_type(library)
137        } else {
138            self.code_for_dyn_len_input_type(library)
139        }
140    }
141}
142
143struct DecomposedInnerFunction<'body> {
144    exec_or_call: Vec<LabelledInstruction>,
145    fn_body: Option<&'body [LabelledInstruction]>,
146}
147
148impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
149    fn code_for_static_len_input_type(&self, library: &mut Library) -> Vec<LabelledInstruction> {
150        let input_type = self.f.domain();
151        let output_type = self.f.range();
152        assert!(input_type.static_length().is_some());
153
154        let new_list = library.import(Box::new(New));
155        let inner_fn = self.decompose_inner_fn(library);
156
157        let entrypoint = self.entrypoint();
158        let main_loop_fn = format!("{entrypoint}_loop");
159
160        let mul_elem_size = |n| match n {
161            1 => triton_asm!(),
162            n => triton_asm!(push {n} mul),
163        };
164        let adjust_output_list_pointer = match output_type.stack_size() {
165            0 | 1 => triton_asm!(),
166            n => triton_asm!(addi {-(n as i32 - 1)}),
167        };
168
169        let main_loop_body = triton_asm! {
170            // INVARIANT: _ *end_condition_in_list *output_elem *input_elem
171
172            /* maybe return */
173            // not using `recurse_or_return` to have more room for parameters
174            // that might live on the stack, to and used by the inner function
175            dup 2 dup 1 eq
176            skiz return
177
178            /* read */
179            {&input_type.read_value_from_memory_leave_pointer()}
180            place {input_type.stack_size()}
181                        // _ *end_condition_in_list *output_elem *prev_input_elem [input_elem]
182
183            /* map */
184            {&inner_fn.exec_or_call}
185                        // _ *end_condition_in_list *output_elem *prev_input_elem [output_elem]
186
187            /* write */
188            pick {output_type.stack_size() + 1}
189            {&output_type.write_value_to_memory_leave_pointer()}
190            addi {-2 * output_type.stack_size() as i32}
191            place 1     // _ *end_condition_in_list *prev_output_elem *prev_input_elem
192
193            recurse
194        };
195
196        let map_one_list = triton_asm! {
197            // BEFORE: _ [fill; M]   [*in_list; N-M]   *out_list
198            // AFTER:  _ [fill; M+1] [*in_list; N-M-1] *out_list
199
200            /* read list lengths */
201            read_mem 1
202            addi 1      // _ [_; M] [_; N-M-1] *in_list out_list_len *out_list
203
204            pick 2
205            read_mem 1
206            addi 1      // _ [_; M] [_; N-M-1] out_list_len *out_list in_list_len *in_list
207
208            /* prepare in_list pointer for main loop */
209            dup 1
210            {&mul_elem_size(input_type.stack_size())}
211            dup 1
212            add         // _ [_; M] [_; N-M-1] out_list_len *out_list in_list_len *in_list *in_list_first_elem_last_word
213
214            /* update out_list's len */
215            pick 2
216            pick 4      // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word in_list_len out_list_len
217            add         // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word new_out_list_len
218
219            dup 0
220            pick 4
221            write_mem 1
222            addi -1     // _ [_; M] [_; N-M-1] *in_list *in_list_first_elem_last_word new_out_list_len *out_list
223
224            /* store *out_list for next iterations */
225            dup 0
226            place 4     // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word new_out_list_len *out_list
227
228            /* prepare out_list pointer for main loop */
229            pick 1
230            {&mul_elem_size(output_type.stack_size())}
231            add         // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word *out_list_last_elem_last_word
232
233            {&adjust_output_list_pointer}
234            place 1     // _ [_; M] [_; N-M-1] *out_list *in_list *out_list_last_elem_first_word *in_list_first_elem_last_word
235
236            call {main_loop_fn}
237                        hint used_list: Pointer = stack[2]
238                        // _ [_; M] [_; N-M-1] *out_list fill garbage fill
239
240            /* clean up */
241            pop 2
242            place {NUM_INPUT_LISTS}
243                        // _ [_; M+1] [_; N-M-1] *out_list
244        };
245        let map_all_lists = vec![map_one_list; NUM_INPUT_LISTS].concat();
246
247        triton_asm! {
248            // BEFORE: _ [*in_list; N]
249            // AFTER:  _ *out_list
250            {entrypoint}:
251                call {new_list}
252                    hint chain_map_output_list: Pointer = stack[0]
253                {&map_all_lists}
254                place {NUM_INPUT_LISTS}
255                {&Self::pop_input_lists()}
256                return
257            {main_loop_fn}:
258                {&main_loop_body}
259            {&inner_fn.fn_body.unwrap_or_default()}
260        }
261    }
262
263    fn code_for_dyn_len_input_type(&self, library: &mut Library) -> Vec<LabelledInstruction> {
264        let input_type = self.f.domain();
265        let output_type = self.f.range();
266        assert!(input_type.static_length().is_none());
267
268        let new_list = library.import(Box::new(New));
269        let push = library.import(Box::new(Push::new(output_type.clone())));
270        let inner_fn = self.decompose_inner_fn(library);
271
272        let entrypoint = self.entrypoint();
273        let main_loop_fn = format!("{entrypoint}_loop");
274
275        let main_loop_body = triton_asm! {
276            //                ⬐ for Self::NUM_INTERNAL_REGISTERS
277            // BEFORE:    _ fill 0           in_list_len *out_list *in_list[0]_si
278            // INVARIANT: _ fill i           in_list_len *out_list *in_list[i]_si
279            // AFTER:     _ fill in_list_len in_list_len *out_list garbage
280
281            /* maybe return */
282            dup 3
283            dup 3
284            eq
285            skiz return
286
287            /* read field size */
288            read_mem 1  hint item_len = stack[0]
289            addi 2      // _ fill i in_list_len *out_list elem_len *in_list[i]
290
291            /* check field size is reasonable */
292            push {DEFAULT_MAX_DYN_FIELD_SIZE}
293                        hint default_max_dyn_field_size = stack[0]
294            dup 2       // _ fill i in_list_len *out_list l[i]_len *in_list[i] max l[i]_len
295            lt
296            assert      // _ fill i in_list_len *out_list l[i]_len *in_list[i]
297
298            /* advance item iterator */
299            dup 1
300            dup 1
301            add
302            place 2
303
304            /* prepare for inner function */
305            place 1     // _ fill i in_list_len *out_list *in_list[i+1]_si *in_list[i] l[i]_len
306
307            /* map */
308            {&inner_fn.exec_or_call}
309                        // _ fill i in_list_len *out_list *in_list[i]_si [out_elem]
310
311            /* write */
312            dup {output_type.stack_size() + 1}
313            place {output_type.stack_size()}
314            call {push}
315                        // _ fill i in_list_len *out_list *in_list[i]_si
316
317            /* advance i */
318            pick 3
319            addi 1
320            place 3
321                        // _ fill (i+i) in_list_len *out_list *in_list[i+1]_si
322
323            recurse
324        };
325
326        let map_one_list = triton_asm! {
327            // BEFORE: _ [fill; M]   [*in_list; N-M]   *out_list
328            // AFTER:  _ [fill; M+1] [*in_list; N-M-1] *out_list
329
330            /* read in_list length */
331            pick 1
332            read_mem 1  hint in_list_len = stack[1]
333            addi 2      // _ [_; M] [_; N-M-1] *out_list in_list_len *in_list[0]_si
334
335
336            /* setup for main loop */
337            pick 2
338            place 1
339            push 0      hint filler = stack[0]
340            place 3
341            push 0      hint index = stack[0]
342            place 3
343                        // _ [_; M] [_; N-M-1] fill 0 in_list_len *out_list *in_list[0]_si
344
345            call {main_loop_fn}
346
347            /* clean up */
348            pick 1
349            place 4
350            pop 3
351            place {NUM_INPUT_LISTS}
352                        // _ [_; M+1] [_; N-M-1] *out_list
353        };
354        let map_all_lists = vec![map_one_list; NUM_INPUT_LISTS].concat();
355
356        triton_asm! {
357            // BEFORE: _ [*in_list; N]
358            // AFTER:  _ *out_list
359            {entrypoint}:
360                call {new_list}
361                    hint chain_map_output_list: Pointer = stack[0]
362                {&map_all_lists}
363                place {NUM_INPUT_LISTS}
364                {&Self::pop_input_lists()}
365                return
366            {main_loop_fn}:
367                {&main_loop_body}
368            {&inner_fn.fn_body.unwrap_or_default()}
369        }
370    }
371
372    fn decompose_inner_fn(&self, library: &mut Library) -> DecomposedInnerFunction<'_> {
373        let exec_or_call = match &self.f {
374            InnerFunction::RawCode(code) => {
375                // Inlining saves two clock cycles per iteration. If the function cannot be
376                // inlined, it needs to be appended to the function body.
377                code.inlined_body()
378                    .unwrap_or(triton_asm!(call {code.entrypoint()}))
379            }
380            InnerFunction::BasicSnippet(snippet) => {
381                assert_eq!(1, snippet.inputs().len(), "{INNER_FN_INCORRECT_NUM_INPUTS}");
382                let labelled_instructions = snippet.annotated_code(library);
383                let label = library.explicit_import(&snippet.entrypoint(), &labelled_instructions);
384                triton_asm!(call { label })
385            }
386            InnerFunction::NoFunctionBody(lnat) => {
387                triton_asm!(call { lnat.label_name })
388            }
389        };
390
391        let fn_body = if let InnerFunction::RawCode(c) = &self.f {
392            c.inlined_body().is_none().then_some(c.function.as_slice())
393        } else {
394            None
395        };
396
397        DecomposedInnerFunction {
398            exec_or_call,
399            fn_body,
400        }
401    }
402
403    fn pop_input_lists() -> Vec<LabelledInstruction> {
404        match NUM_INPUT_LISTS {
405            0 => triton_asm!(),
406            i @ 1..=5 => triton_asm!(pop { i }),
407            i @ 6..=10 => triton_asm!(pop 5 pop { i - 5 }),
408            i @ 11..=15 => triton_asm!(pop 5 pop 5 pop { i - 10 }),
409            _ => unreachable!("see compile time checks for `NUM_INPUT_LISTS`"),
410        }
411    }
412}
413
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;
    use crate::arithmetic;
    use crate::list::higher_order::inner_function::InnerFunction;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::neptune::mutator_set::get_swbf_indices::u32_to_u128_add_another_u128;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::list_get_length;
    use crate::rust_shadowing_helper_functions::list::list_pointer_to_elem_pointer;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::rust_shadowing_helper_functions::list::list_set_length;
    use crate::test_helpers::test_rust_equivalence_given_execution_state;
    use crate::test_prelude::*;

    impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
        /// Builds an initial state: `environment_args` on the stack, followed by
        /// pointers to `NUM_INPUT_LISTS` freshly allocated lists of random
        /// elements with the given lengths.
        fn init_state(
            &self,
            environment_args: impl IntoIterator<Item = BFieldElement>,
            list_lengths: [u16; NUM_INPUT_LISTS],
            seed: <StdRng as SeedableRng>::Seed,
        ) -> FunctionInitialState {
            let input_type = self.f.domain();
            let mut stack = self.init_stack_for_isolated_run();
            let mut memory = HashMap::default();
            let mut rng = StdRng::from_seed(seed);

            stack.extend(environment_args);

            for list_length in list_lengths {
                let list_length = usize::from(list_length);
                let list = input_type.random_list(&mut rng, list_length);
                let list_pointer = dynamic_allocator(&mut memory);
                let indexed_list = list
                    .into_iter()
                    .enumerate()
                    .map(|(i, v)| (list_pointer + bfe!(i), v));

                memory.extend(indexed_list);
                stack.push(list_pointer);
            }

            FunctionInitialState { stack, memory }
        }
    }

    impl<const NUM_INPUT_LISTS: usize> Function for ChainMap<NUM_INPUT_LISTS> {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let input_type = self.f.domain();
            let output_type = self.f.range();
            let input_has_static_len = input_type.static_length().is_some();

            New.rust_shadow(stack, memory);
            let output_list_pointer = stack.pop().unwrap();

            let input_list_pointers = (0..NUM_INPUT_LISTS)
                .map(|_| stack.pop().unwrap())
                .collect_vec();

            // the inner function _must not_ rely on these elements
            let buffer = (0..Self::NUM_INTERNAL_REGISTERS).map(|_| rand::random::<BFieldElement>());
            stack.extend(buffer);

            let mut total_output_len = 0;
            for input_list_pointer in input_list_pointers {
                let input_list_len = list_get_length(input_list_pointer, memory);
                let output_list_len = list_get_length(output_list_pointer, memory);
                let new_output_list_len = output_list_len + input_list_len;
                list_set_length(output_list_pointer, new_output_list_len, memory);

                for i in (0..input_list_len).rev() {
                    if input_has_static_len {
                        let elem = list_get(input_list_pointer, i, memory, input_type.stack_size());
                        stack.extend(elem.into_iter().rev());
                    } else {
                        let (len, ptr) = list_pointer_to_elem_pointer(
                            input_list_pointer,
                            i,
                            memory,
                            &input_type,
                        );
                        stack.push(ptr);
                        stack.push(bfe!(len));
                    };
                    self.f.apply(stack, memory);
                    let elem = (0..output_type.stack_size())
                        .map(|_| stack.pop().unwrap())
                        .collect();
                    list_set(output_list_pointer, total_output_len + i, elem, memory);
                }

                total_output_len += input_list_len;
            }

            for _ in 0..Self::NUM_INTERNAL_REGISTERS {
                stack.pop();
            }

            stack.push(output_list_pointer);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let environment_args = rng.random::<[BFieldElement; OpStackElement::COUNT]>();

            let list_lengths = match bench {
                None => rng.random::<[u8; NUM_INPUT_LISTS]>(),
                Some(BenchmarkCase::CommonCase) => [10; NUM_INPUT_LISTS],
                Some(BenchmarkCase::WorstCase) => [100; NUM_INPUT_LISTS],
            };
            let list_lengths = list_lengths.map(Into::into);

            self.init_state(environment_args, list_lengths, rng.random())
        }
    }

    /// Hashes an extension field element into a digest. Also calls an otherwise
    /// unneeded imported snippet, to check that imports work inside the
    /// map-generated code.
    #[derive(Debug, Clone)]
    pub(crate) struct TestHashXFieldElement;

    impl BasicSnippet for TestHashXFieldElement {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "element".to_string())]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Digest, "digest".to_string())]
        }

        fn entrypoint(&self) -> String {
            "test_hash_xfield_element".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
            triton_asm!(
                // BEFORE: _ x2 x1 x0
                // AFTER:  _ d4 d3 d2 d1 d0
                {entrypoint}:
                    push 0 push 0
                    push 0 push 0
                    push 0 push 0   // _ x2 x1 x0 0 0 0 0 0 0
                    push 1          // _ x2 x1 x0 0 0 0 0 0 0 1
                    pick 9          // _ x1 x0 0 0 0 0 0 0 1 x2
                    pick 9          // _ x0 0 0 0 0 0 0 1 x2 x1
                    pick 9          // _ 0 0 0 0 0 0 1 x2 x1 x0

                    // Useless additions, to ensure that imports are accepted inside the
                    // map-generated code
                    push 0
                    push 0
                    call {unused_import}
                    pop 1

                    sponge_init
                    sponge_absorb
                    sponge_squeeze // _ d9 d8 d7 d6 d5 d4 d3 d2 d1 d0
                    pick 9 pick 9  // _ d7 d6 d5 d4 d3 d2 d1 d0 d9 d8
                    pick 9 pick 9  // _ d5 d4 d3 d2 d1 d0 d9 d8 d7 d6
                    pick 9 pop 5   // _ d4 d3 d2 d1 d0
                    return
            )
        }
    }

    /// `InnerFunction` does not implement `Clone`, and it is hard (impossible?) to
    /// teach it. Hence, take a function `f` to generate the `InnerFunction`.
    ///
    /// Theoretically, this _could_ mean that `f` produces a different
    /// `InnerFunction` each time it is called, as every `Fn()` is `FnMut()`.
    /// Since this is a test helper, how about it doesn't. 😊
    fn test_chain_map_with_different_num_input_lists(f: impl Fn() -> InnerFunction) {
        ShadowedFunction::new(ChainMap::<0>::new(f())).test();
        ShadowedFunction::new(ChainMap::<1>::new(f())).test();
        ShadowedFunction::new(ChainMap::<2>::new(f())).test();
        ShadowedFunction::new(ChainMap::<3>::new(f())).test();
        ShadowedFunction::new(ChainMap::<4>::new(f())).test();
        ShadowedFunction::new(ChainMap::<5>::new(f())).test();

        ShadowedFunction::new(ChainMap::<7>::new(f())).test();
        ShadowedFunction::new(ChainMap::<11>::new(f())).test();
        ShadowedFunction::new(ChainMap::<15>::new(f())).test();
    }

    #[test]
    fn test_with_raw_function_identity_on_bfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(identity_bfe: return),
                DataType::Bfe,
                DataType::Bfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_bfe_lift() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(bfe_lift: push 0 push 0 pick 2 return),
                DataType::Bfe,
                DataType::Xfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_xfe_get_coeff_0() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(get_0: place 2 pop 2 return),
                DataType::Xfe,
                DataType::Bfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_square_on_bfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(square_bfe: dup 0 mul return),
                DataType::Bfe,
                DataType::Bfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_square_plus_n_on_bfe() {
        // Inner function calculates `|x| -> x*x + n`, where `x` is the list element,
        // and `n` is the same value for all elements.
        fn test_case<const N: usize>() {
            let raw_code = InnerFunction::RawCode(RawCode::new(
                triton_asm!(square_plus_n_bfe: dup 0 mul dup {5 + N} add return),
                DataType::Bfe,
                DataType::Bfe,
            ));
            ShadowedFunction::new(ChainMap::<N>::new(raw_code)).test();
        }

        test_case::<0>();
        test_case::<1>();
        test_case::<2>();
        test_case::<3>();
        test_case::<4>();
        test_case::<5>();
        test_case::<7>();
        test_case::<9>();
        test_case::<10>();
    }

    #[test]
    fn test_with_raw_function_square_on_xfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(square_xfe: dup 2 dup 2 dup 2 xx_mul return),
                DataType::Xfe,
                DataType::Xfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_xfe_to_digest() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(xfe_to_digest: push 0 push 0 return),
                DataType::Xfe,
                DataType::Digest,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_digest_to_xfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(digest_to_xfe: pop 2 return),
                DataType::Digest,
                DataType::Xfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_square_on_xfe_plus_another_xfe() {
        fn test_case<const N: usize>() {
            let offset = ChainMap::<{ N }>::NUM_INTERNAL_REGISTERS;
            let raw_code = InnerFunction::RawCode(RawCode::new(
                triton_asm!(
                    square_xfe_plus_another_xfe:
                        dup 2 dup 2 dup 2 xx_mul
                        dup {5 + offset}
                        dup {5 + offset}
                        dup {5 + offset}
                        xx_add
                        return
                ),
                DataType::Xfe,
                DataType::Xfe,
            ));
            ShadowedFunction::new(ChainMap::<N>::new(raw_code)).test();
        }

        test_case::<0>();
        test_case::<1>();
        test_case::<2>();
        test_case::<3>();
        test_case::<4>();
        test_case::<5>();
        test_case::<6>();
        test_case::<7>();
    }

    #[test]
    fn test_u32_list_to_unit_list() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(remove_elements: pop 1 return),
                DataType::U32,
                DataType::Tuple(vec![]),
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_u32_list_to_u64_list() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(duplicate_u32: dup 0 return),
                DataType::U32,
                DataType::U64,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_u32_list_to_u128_list_plus_x() {
        // this code only works with 1 input list
        let raw_code = InnerFunction::RawCode(u32_to_u128_add_another_u128());
        let snippet = Map::new(raw_code);
        let encoded_u128 = rand::random::<u128>().encode();
        let input_list_len = rand::rng().random_range(0u16..200);
        let initial_state = snippet.init_state(encoded_u128, [input_list_len], rand::random());
        test_rust_equivalence_given_execution_state(
            &ShadowedFunction::new(snippet),
            initial_state.into(),
        );
    }

    #[proptest(cases = 10)]
    fn num_internal_registers_is_correct(#[strategy(arb())] guard: BFieldElement) {
        fn test_case<const N: usize>(guard: BFieldElement) {
            let offset = ChainMap::<{ N }>::NUM_INTERNAL_REGISTERS;
            let raw_code = InnerFunction::RawCode(RawCode::new(
                triton_asm! { check_env: dup {offset} push {guard} eq assert return },
                DataType::Tuple(vec![]),
                DataType::Tuple(vec![]),
            ));
            let snippet = ChainMap::<N>::new(raw_code);
            let initial_state = snippet.init_state(vec![guard], [1; N], rand::random());
            test_rust_equivalence_given_execution_state(
                &ShadowedFunction::new(snippet),
                initial_state.into(),
            );
        }

        test_case::<0>(guard);
        test_case::<1>(guard);
        test_case::<2>(guard);
        test_case::<3>(guard);
        test_case::<4>(guard);
        test_case::<5>(guard);
        test_case::<6>(guard);
        test_case::<7>(guard);
        test_case::<8>(guard);
        test_case::<9>(guard);
        test_case::<10>(guard);
        test_case::<11>(guard);
        test_case::<12>(guard);
    }

    #[test]
    fn mapping_over_dynamic_length_items_works() {
        let f = || {
            let list_type = DataType::List(Box::new(DataType::Bfe));
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(just_forty_twos: pop 2 push 42 return),
                DataType::Tuple(vec![list_type, DataType::Bfe]),
                DataType::Bfe,
            ))
        };
        assert!(f().domain().static_length().is_none());

        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn mapping_over_list_of_lists_writing_their_lengths_works() {
        let f = || {
            let list_type = DataType::List(Box::new(DataType::Bfe));
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(write_list_length: pop 1 read_mem 1 pop 1 return),
                DataType::Tuple(vec![list_type, DataType::Bfe]),
                DataType::Bfe,
            ))
        };
        assert!(f().domain().static_length().is_none());

        test_chain_map_with_different_num_input_lists(f);
    }
}
846
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElement;
    use super::*;
    use crate::list::higher_order::inner_function::InnerFunction;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::test_prelude::*;

    /// Benchmarks `map` with a full-fledged snippet as the inner function.
    #[test]
    fn map_benchmark() {
        let hash_xfe = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElement));
        let map = Map::new(hash_xfe);
        ShadowedFunction::new(map).bench();
    }

    /// Benchmarks `map` over input elements of dynamic length.
    #[test]
    fn map_with_dyn_items_benchmark() {
        let bfe_list = DataType::List(Box::new(DataType::Bfe));
        let domain = DataType::Tuple(vec![bfe_list, DataType::Bfe]);
        let raw_code = RawCode::new(
            triton_asm!(dyn_length_elements: pop 2 push 42 return),
            domain,
            DataType::Bfe,
        );
        let map = Map::new(InnerFunction::RawCode(raw_code));
        ShadowedFunction::new(map).bench();
    }
}