tasm_lib/list/higher_order/map.rs

1use itertools::Itertools;
2use strum::EnumCount;
3use tasm_lib::list::higher_order::inner_function::InnerFunction;
4use tasm_lib::structure::tasm_object::DEFAULT_MAX_DYN_FIELD_SIZE;
5use triton_vm::isa::op_stack::OpStackElement;
6use triton_vm::prelude::*;
7
8use crate::list::new::New;
9use crate::list::push::Push;
10use crate::prelude::*;
11
/// Panic message for inner functions that take more than one parameter.
const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* input. \
                                             Use a tuple as a workaround.";
/// Panic message for inner functions whose input type has dynamic length but is
/// not a pair whose second element is a [`BFieldElement`]. The tuple form stated
/// here must agree with the pattern checked in [`ChainMap::new`].
const INNER_FN_INCORRECT_INPUT_DYN_LEN: &str = "An input type of dynamic length to `map`s inner \
                                                function must be a tuple of form `(_, bfe)`.";
16
17/// Applies a given function to every element of a list, and collects the new
18/// elements into a new list.
19///
20/// Mapping over multiple input lists into one output list, effectively chaining
21/// inputs before applying the map, is possible with [`ChainMap`]. See there for
22/// extended documentation.
23pub type Map = ChainMap<1>;
24
25/// Applies a given function `f` to every element of all given lists, collecting
26/// the new elements into a new list.
27///
28/// The given function `f` must produce elements of a type for which the encoded
29/// length is [statically known][len]. The input type may have either
30/// statically or dynamically known length:
31/// - In the static case, the entire element is placed on the stack before
32///   passing control to `f`.
33/// - In the dynamic case, a memory pointer to the encoded element and the
34///   item's length is placed on the stack before passing control to `f`. The
35///   input list **must** be encoded according to [`BFieldCodec`]. Otherwise,
36///   behavior of `ChainMap` is undefined!
37///
38/// The stack layout is independent of the list currently being processed. This
39/// allows the [`InnerFunction`] `f` to use runtime parameters from the stack.
40/// Note that the chain map requires a certain number of stack registers for
41/// internal purposes. This number can be accessed through
42/// [`ChainMap::NUM_INTERNAL_REGISTERS`]. As mentioned above, the stack layout
43/// upon starting execution of `f` depends on the input type's
44/// [static length][len]. In the static case, the stack layout is:
45///
46/// ```txt
47/// // _ <accessible> [_; ChainMap::<N>::NUM_INTERNAL_REGISTERS] [input_element; len]
48/// ```
49///
50/// In the case of input elements with a dynamic length, the stack layout is:
51///
52/// ```txt
53/// // _ <accessible> [_; ChainMap::<N>::NUM_INTERNAL_REGISTERS] *elem_i elem_i_len
54/// ```
55///
56/// [len]: BFieldCodec::static_length
57pub struct ChainMap<const NUM_INPUT_LISTS: usize> {
58    f: InnerFunction,
59}
60
61impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
62    /// The number of registers required internally. See [`ChainMap`] for additional
63    /// details.
64    pub const NUM_INTERNAL_REGISTERS: usize = {
65        assert!(NUM_INPUT_LISTS <= Self::MAX_NUM_INPUT_LISTS);
66
67        3 + NUM_INPUT_LISTS
68    };
69
70    /// Need access to all lists, plus a little wiggle room.
71    const MAX_NUM_INPUT_LISTS: usize = OpStackElement::COUNT - 1;
72
73    /// # Panics
74    ///
75    /// - if the input type has [static length] _and_ takes up
76    ///   [`OpStackElement::COUNT`] or more words
77    /// - if the input type has dynamic length and is _anything but_ a tuple
78    ///   `(_, `[`BFieldElement`][bfe]`)`
79    /// - if the output type takes up [`OpStackElement::COUNT`]` - 1` or more words
80    /// - if the output type does not have a [static length][len]
81    ///
82    /// [len]: BFieldCodec::static_length
83    /// [bfe]: DataType::Bfe
84    pub fn new(f: InnerFunction) -> Self {
85        let domain = f.domain();
86        if let Some(input_len) = domain.static_length() {
87            // need instruction `place {input_type.stack_size()}`
88            assert!(input_len < OpStackElement::COUNT);
89        } else {
90            let DataType::Tuple(tuple) = domain else {
91                panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
92            };
93            let [_, DataType::Bfe] = tuple[..] else {
94                panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
95            };
96        }
97
98        // need instruction `pick {output_type.stack_size() + 1}`
99        let output_len = f
100            .range()
101            .static_length()
102            .expect("output type's encoding length must be static");
103        assert!(output_len + 1 < OpStackElement::COUNT);
104
105        Self { f }
106    }
107}
108
109impl<const NUM_INPUT_LISTS: usize> BasicSnippet for ChainMap<NUM_INPUT_LISTS> {
110    fn parameters(&self) -> Vec<(DataType, String)> {
111        let list_type = DataType::List(Box::new(self.f.domain()));
112
113        (0..NUM_INPUT_LISTS)
114            .map(|i| (list_type.clone(), format!("*input_list_{i}")))
115            .collect_vec()
116    }
117
118    fn return_values(&self) -> Vec<(DataType, String)> {
119        let list_type = DataType::List(Box::new(self.f.range()));
120        vec![(list_type, "*output_list".to_string())]
121    }
122
123    fn entrypoint(&self) -> String {
124        let maybe_chain_surely_map = if NUM_INPUT_LISTS == 1 {
125            "map".to_string()
126        } else {
127            format!("chain_map_{NUM_INPUT_LISTS}")
128        };
129
130        let f_label = self.f.entrypoint();
131        format!("tasmlib_list_higher_order_u32_{maybe_chain_surely_map}_{f_label}")
132    }
133
134    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
135        if self.f.domain().static_length().is_some() {
136            self.code_for_static_len_input_type(library)
137        } else {
138            self.code_for_dyn_len_input_type(library)
139        }
140    }
141}
142
143struct DecomposedInnerFunction<'body> {
144    exec_or_call: Vec<LabelledInstruction>,
145    fn_body: Option<&'body [LabelledInstruction]>,
146}
147
148impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
149    fn code_for_static_len_input_type(&self, library: &mut Library) -> Vec<LabelledInstruction> {
150        let input_type = self.f.domain();
151        let output_type = self.f.range();
152        assert!(input_type.static_length().is_some());
153
154        let new_list = library.import(Box::new(New));
155        let inner_fn = self.decompose_inner_fn(library);
156
157        let entrypoint = self.entrypoint();
158        let main_loop_fn = format!("{entrypoint}_loop");
159
160        let mul_elem_size = |n| match n {
161            1 => triton_asm!(),
162            n => triton_asm!(push {n} mul),
163        };
164        let adjust_output_list_pointer = match output_type.stack_size() {
165            0 | 1 => triton_asm!(),
166            n => triton_asm!(addi {-(n as i32 - 1)}),
167        };
168
169        let main_loop_body = triton_asm! {
170            // INVARIANT: _ *end_condition_in_list *output_elem *input_elem
171
172            /* maybe return */
173            // not using `recurse_or_return` to have more room for parameters
174            // that might live on the stack, to and used by the inner function
175            dup 2 dup 1 eq
176            skiz return
177
178            /* read */
179            {&input_type.read_value_from_memory_leave_pointer()}
180            place {input_type.stack_size()}
181                        // _ *end_condition_in_list *output_elem *prev_input_elem [input_elem]
182
183            /* map */
184            {&inner_fn.exec_or_call}
185                        // _ *end_condition_in_list *output_elem *prev_input_elem [output_elem]
186
187            /* write */
188            pick {output_type.stack_size() + 1}
189            {&output_type.write_value_to_memory_leave_pointer()}
190            addi {-2 * output_type.stack_size() as i32}
191            place 1     // _ *end_condition_in_list *prev_output_elem *prev_input_elem
192
193            recurse
194        };
195
196        let map_one_list = triton_asm! {
197            // BEFORE: _ [fill; M]   [*in_list; N-M]   *out_list
198            // AFTER:  _ [fill; M+1] [*in_list; N-M-1] *out_list
199
200            /* read list lengths */
201            read_mem 1
202            addi 1      // _ [_; M] [_; N-M-1] *in_list out_list_len *out_list
203
204            pick 2
205            read_mem 1
206            addi 1      // _ [_; M] [_; N-M-1] out_list_len *out_list in_list_len *in_list
207
208            /* prepare in_list pointer for main loop */
209            dup 1
210            {&mul_elem_size(input_type.stack_size())}
211            dup 1
212            add         // _ [_; M] [_; N-M-1] out_list_len *out_list in_list_len *in_list *in_list_first_elem_last_word
213
214            /* update out_list's len */
215            pick 2
216            pick 4      // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word in_list_len out_list_len
217            add         // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word new_out_list_len
218
219            dup 0
220            pick 4
221            write_mem 1
222            addi -1     // _ [_; M] [_; N-M-1] *in_list *in_list_first_elem_last_word new_out_list_len *out_list
223
224            /* store *out_list for next iterations */
225            dup 0
226            place 4     // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word new_out_list_len *out_list
227
228            /* prepare out_list pointer for main loop */
229            pick 1
230            {&mul_elem_size(output_type.stack_size())}
231            add         // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word *out_list_last_elem_last_word
232
233            {&adjust_output_list_pointer}
234            place 1     // _ [_; M] [_; N-M-1] *out_list *in_list *out_list_last_elem_first_word *in_list_first_elem_last_word
235
236            call {main_loop_fn}
237                        hint used_list: Pointer = stack[2]
238                        // _ [_; M] [_; N-M-1] *out_list fill garbage fill
239
240            /* clean up */
241            pop 2
242            place {NUM_INPUT_LISTS}
243                        // _ [_; M+1] [_; N-M-1] *out_list
244        };
245        let map_all_lists = vec![map_one_list; NUM_INPUT_LISTS].concat();
246
247        triton_asm! {
248            // BEFORE: _ [*in_list; N]
249            // AFTER:  _ *out_list
250            {entrypoint}:
251                call {new_list}
252                    hint chain_map_output_list: Pointer = stack[0]
253                {&map_all_lists}
254                place {NUM_INPUT_LISTS}
255                {&Self::pop_input_lists()}
256                return
257            {main_loop_fn}:
258                {&main_loop_body}
259            {&inner_fn.fn_body.unwrap_or_default()}
260        }
261    }
262
263    fn code_for_dyn_len_input_type(&self, library: &mut Library) -> Vec<LabelledInstruction> {
264        let input_type = self.f.domain();
265        let output_type = self.f.range();
266        assert!(input_type.static_length().is_none());
267
268        let new_list = library.import(Box::new(New));
269        let push = library.import(Box::new(Push::new(output_type.clone())));
270        let inner_fn = self.decompose_inner_fn(library);
271
272        let entrypoint = self.entrypoint();
273        let main_loop_fn = format!("{entrypoint}_loop");
274
275        let main_loop_body = triton_asm! {
276            //                ⬐ for Self::NUM_INTERNAL_REGISTERS
277            // BEFORE:    _ fill 0           in_list_len *out_list *in_list[0]_si
278            // INVARIANT: _ fill i           in_list_len *out_list *in_list[i]_si
279            // AFTER:     _ fill in_list_len in_list_len *out_list garbage
280
281            /* maybe return */
282            dup 3
283            dup 3
284            eq
285            skiz return
286
287            /* read field size */
288            read_mem 1  hint item_len = stack[0]
289            addi 2      // _ fill i in_list_len *out_list elem_len *in_list[i]
290
291            /* check field size is reasonable */
292            push {DEFAULT_MAX_DYN_FIELD_SIZE}
293                        hint default_max_dyn_field_size = stack[0]
294            dup 2       // _ fill i in_list_len *out_list l[i]_len *in_list[i] max l[i]_len
295            lt
296            assert      // _ fill i in_list_len *out_list l[i]_len *in_list[i]
297
298            /* advance item iterator */
299            dup 1
300            dup 1
301            add
302            place 2
303
304            /* prepare for inner function */
305            place 1     // _ fill i in_list_len *out_list *in_list[i+1]_si *in_list[i] l[i]_len
306
307            /* map */
308            {&inner_fn.exec_or_call}
309                        // _ fill i in_list_len *out_list *in_list[i]_si [out_elem]
310
311            /* write */
312            dup {output_type.stack_size() + 1}
313            place {output_type.stack_size()}
314            call {push}
315                        // _ fill i in_list_len *out_list *in_list[i]_si
316
317            /* advance i */
318            pick 3
319            addi 1
320            place 3
321                        // _ fill (i+i) in_list_len *out_list *in_list[i+1]_si
322
323            recurse
324        };
325
326        let map_one_list = triton_asm! {
327            // BEFORE: _ [fill; M]   [*in_list; N-M]   *out_list
328            // AFTER:  _ [fill; M+1] [*in_list; N-M-1] *out_list
329
330            /* read in_list length */
331            pick 1
332            read_mem 1  hint in_list_len = stack[1]
333            addi 2      // _ [_; M] [_; N-M-1] *out_list in_list_len *in_list[0]_si
334
335
336            /* setup for main loop */
337            pick 2
338            place 1
339            push 0      hint filler = stack[0]
340            place 3
341            push 0      hint index = stack[0]
342            place 3
343                        // _ [_; M] [_; N-M-1] fill 0 in_list_len *out_list *in_list[0]_si
344
345            call {main_loop_fn}
346
347            /* clean up */
348            pick 1
349            place 4
350            pop 3
351            place {NUM_INPUT_LISTS}
352                        // _ [_; M+1] [_; N-M-1] *out_list
353        };
354        let map_all_lists = vec![map_one_list; NUM_INPUT_LISTS].concat();
355
356        triton_asm! {
357            // BEFORE: _ [*in_list; N]
358            // AFTER:  _ *out_list
359            {entrypoint}:
360                call {new_list}
361                    hint chain_map_output_list: Pointer = stack[0]
362                {&map_all_lists}
363                place {NUM_INPUT_LISTS}
364                {&Self::pop_input_lists()}
365                return
366            {main_loop_fn}:
367                {&main_loop_body}
368            {&inner_fn.fn_body.unwrap_or_default()}
369        }
370    }
371
372    fn decompose_inner_fn(&self, library: &mut Library) -> DecomposedInnerFunction<'_> {
373        let exec_or_call = match &self.f {
374            InnerFunction::RawCode(code) => {
375                // Inlining saves two clock cycles per iteration. If the function cannot be
376                // inlined, it needs to be appended to the function body.
377                code.inlined_body()
378                    .unwrap_or(triton_asm!(call {code.entrypoint()}))
379            }
380            InnerFunction::BasicSnippet(snippet) => {
381                assert_eq!(
382                    1,
383                    snippet.parameters().len(),
384                    "{INNER_FN_INCORRECT_NUM_INPUTS}"
385                );
386                let labelled_instructions = snippet.annotated_code(library);
387                let label = library.explicit_import(&snippet.entrypoint(), &labelled_instructions);
388                triton_asm!(call { label })
389            }
390            InnerFunction::NoFunctionBody(lnat) => {
391                triton_asm!(call { lnat.label_name })
392            }
393        };
394
395        let fn_body = if let InnerFunction::RawCode(c) = &self.f {
396            c.inlined_body().is_none().then_some(c.function.as_slice())
397        } else {
398            None
399        };
400
401        DecomposedInnerFunction {
402            exec_or_call,
403            fn_body,
404        }
405    }
406
407    fn pop_input_lists() -> Vec<LabelledInstruction> {
408        match NUM_INPUT_LISTS {
409            0 => triton_asm!(),
410            i @ 1..=5 => triton_asm!(pop { i }),
411            i @ 6..=10 => triton_asm!(pop 5 pop { i - 5 }),
412            i @ 11..=15 => triton_asm!(pop 5 pop 5 pop { i - 10 }),
413            _ => unreachable!("see compile time checks for `NUM_INPUT_LISTS`"),
414        }
415    }
416}
417
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;
    use crate::arithmetic;
    use crate::list::higher_order::inner_function::InnerFunction;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::neptune::mutator_set::get_swbf_indices::u32_to_u128_add_another_u128;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::list_get_length;
    use crate::rust_shadowing_helper_functions::list::list_pointer_to_elem_pointer;
    use crate::rust_shadowing_helper_functions::list::list_set;
    use crate::rust_shadowing_helper_functions::list::list_set_length;
    use crate::test_helpers::test_rust_equivalence_given_execution_state;
    use crate::test_prelude::*;

    impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
        /// Builds an initial state for running this snippet in isolation.
        ///
        /// The `environment_args` are pushed onto the stack first; they end up below
        /// the snippet's internal registers, where the inner function can access
        /// them. Then, for each entry of `list_lengths`, a random list of the inner
        /// function's domain type and of that length is generated (deterministically
        /// from `seed`). Its words are stored in memory starting at a freshly
        /// allocated pointer, and that pointer is pushed onto the stack.
        fn init_state(
            &self,
            environment_args: impl IntoIterator<Item = BFieldElement>,
            list_lengths: [u16; NUM_INPUT_LISTS],
            seed: <StdRng as SeedableRng>::Seed,
        ) -> FunctionInitialState {
            let input_type = self.f.domain();
            let mut stack = self.init_stack_for_isolated_run();
            let mut memory = HashMap::default();
            let mut rng = StdRng::from_seed(seed);

            stack.extend(environment_args);

            for list_length in list_lengths {
                let list_length = usize::from(list_length);
                let list = input_type.random_list(&mut rng, list_length);
                let list_pointer = dynamic_allocator(&mut memory);
                let indexed_list = list
                    .into_iter()
                    .enumerate()
                    .map(|(i, v)| (list_pointer + bfe!(i), v));

                memory.extend(indexed_list);
                stack.push(list_pointer);
            }

            FunctionInitialState { stack, memory }
        }
    }

    impl<const NUM_INPUT_LISTS: usize> Function for ChainMap<NUM_INPUT_LISTS> {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let input_type = self.f.domain();
            let output_type = self.f.range();

            New.rust_shadow(stack, memory);
            let output_list_pointer = stack.pop().unwrap();

            // the topmost list pointer is processed first
            let input_list_pointers = (0..NUM_INPUT_LISTS)
                .map(|_| stack.pop().unwrap())
                .collect_vec();

            // the inner function _must not_ rely on these elements
            let buffer = (0..Self::NUM_INTERNAL_REGISTERS).map(|_| rand::random::<BFieldElement>());
            stack.extend(buffer);

            let mut total_output_len = 0;
            for input_list_pointer in input_list_pointers {
                let input_list_len = list_get_length(input_list_pointer, memory);
                let output_list_len = list_get_length(output_list_pointer, memory);
                let new_output_list_len = output_list_len + input_list_len;
                list_set_length(output_list_pointer, new_output_list_len, memory);

                // elements are mapped from last to first
                for i in (0..input_list_len).rev() {
                    if input_type.static_length().is_some() {
                        let elem = list_get(input_list_pointer, i, memory, input_type.stack_size());
                        stack.extend(elem.into_iter().rev());
                    } else {
                        let (len, ptr) = list_pointer_to_elem_pointer(
                            input_list_pointer,
                            i,
                            memory,
                            &input_type,
                        );
                        stack.push(ptr);
                        stack.push(bfe!(len));
                    };
                    self.f.apply(stack, memory);
                    let elem = (0..output_type.stack_size())
                        .map(|_| stack.pop().unwrap())
                        .collect();
                    list_set(output_list_pointer, total_output_len + i, elem, memory);
                }

                total_output_len += input_list_len;
            }

            for _ in 0..Self::NUM_INTERNAL_REGISTERS {
                stack.pop();
            }

            stack.push(output_list_pointer);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let environment_args = rng.random::<[BFieldElement; OpStackElement::COUNT]>();

            let list_lengths = match bench {
                None => rng.random::<[u8; NUM_INPUT_LISTS]>(),
                Some(BenchmarkCase::CommonCase) => [10; NUM_INPUT_LISTS],
                Some(BenchmarkCase::WorstCase) => [100; NUM_INPUT_LISTS],
            };
            let list_lengths = list_lengths.map(Into::into);

            self.init_state(environment_args, list_lengths, rng.random())
        }
    }

    /// Test snippet that absorbs an extension field element into a freshly
    /// initialized sponge and returns the first five squeezed elements as a digest.
    /// It also imports an unrelated snippet to check that imports of an inner
    /// function are honored by `map`.
    #[derive(Debug, Clone)]
    pub(crate) struct TestHashXFieldElement;

    impl BasicSnippet for TestHashXFieldElement {
        fn parameters(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "element".to_string())]
        }

        fn return_values(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Digest, "digest".to_string())]
        }

        fn entrypoint(&self) -> String {
            "test_hash_xfield_element".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
            triton_asm!(
                // BEFORE: _ x2 x1 x0
                // AFTER:  _ d4 d3 d2 d1 d0
                {entrypoint}:
                    push 0 push 0
                    push 0 push 0
                    push 0 push 0   // _ x2 x1 x0 0 0 0 0 0 0
                    push 1          // _ x2 x1 x0 0 0 0 0 0 0 1
                    pick 9          // _ x1 x0 0 0 0 0 0 0 1 x2
                    pick 9          // _ x0 0 0 0 0 0 0 1 x2 x1
                    pick 9          // _ 0 0 0 0 0 0 1 x2 x1 x0

                    // Useless additions, to ensure that imports are accepted inside the
                    // map-generated code
                    push 0
                    push 0
                    call {unused_import}
                    pop 1

                    sponge_init
                    sponge_absorb
                    sponge_squeeze // _ d9 d8 d7 d6 d5 d4 d3 d2 d1 d0
                    pick 9 pick 9  // _ d7 d6 d5 d4 d3 d2 d1 d0 d9 d8
                    pick 9 pick 9  // _ d5 d4 d3 d2 d1 d0 d9 d8 d7 d6
                    pick 9 pop 5   // _ d4 d3 d2 d1 d0
                    return
            )
        }
    }

    /// `InnerFunction` does not implement `Clone`, and it is hard (impossible?) to
    /// teach it. Hence, take a function `f` to generate the `InnerFunction`.
    ///
    /// Theoretically, this _could_ mean that `f` produces a different
    /// `InnerFunction` each time it is called, as every `Fn()` is `FnMut()`.
    /// Since this is a test helper, how about it doesn't. 😊
    fn test_chain_map_with_different_num_input_lists(f: impl Fn() -> InnerFunction) {
        ShadowedFunction::new(ChainMap::<0>::new(f())).test();
        ShadowedFunction::new(ChainMap::<1>::new(f())).test();
        ShadowedFunction::new(ChainMap::<2>::new(f())).test();
        ShadowedFunction::new(ChainMap::<3>::new(f())).test();
        ShadowedFunction::new(ChainMap::<4>::new(f())).test();
        ShadowedFunction::new(ChainMap::<5>::new(f())).test();

        ShadowedFunction::new(ChainMap::<7>::new(f())).test();
        ShadowedFunction::new(ChainMap::<11>::new(f())).test();
        ShadowedFunction::new(ChainMap::<15>::new(f())).test();
    }

    #[test]
    fn test_with_raw_function_identity_on_bfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(identity_bfe: return),
                DataType::Bfe,
                DataType::Bfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_bfe_lift() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(bfe_lift: push 0 push 0 pick 2 return),
                DataType::Bfe,
                DataType::Xfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_xfe_get_coeff_0() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(get_0: place 2 pop 2 return),
                DataType::Xfe,
                DataType::Bfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_square_on_bfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(square_bfe: dup 0 mul return),
                DataType::Bfe,
                DataType::Bfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_square_plus_n_on_bfe() {
        // Inner function calculates `|x| -> x*x + n`, where `x` is the list element,
        // and `n` is the same value for all elements.
        fn test_case<const N: usize>() {
            let raw_code = InnerFunction::RawCode(RawCode::new(
                triton_asm!(square_plus_n_bfe: dup 0 mul dup {5 + N} add return),
                DataType::Bfe,
                DataType::Bfe,
            ));
            ShadowedFunction::new(ChainMap::<N>::new(raw_code)).test();
        }

        test_case::<0>();
        test_case::<1>();
        test_case::<2>();
        test_case::<3>();
        test_case::<4>();
        test_case::<5>();
        test_case::<7>();
        test_case::<9>();
        test_case::<10>();
    }

    #[test]
    fn test_with_raw_function_square_on_xfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(square_xfe: dup 2 dup 2 dup 2 xx_mul return),
                DataType::Xfe,
                DataType::Xfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_xfe_to_digest() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(xfe_to_digest: push 0 push 0 return),
                DataType::Xfe,
                DataType::Digest,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_digest_to_xfe() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(digest_to_xfe: pop 2 return),
                DataType::Digest,
                DataType::Xfe,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_with_raw_function_square_on_xfe_plus_another_xfe() {
        fn test_case<const N: usize>() {
            let offset = ChainMap::<{ N }>::NUM_INTERNAL_REGISTERS;
            let raw_code = InnerFunction::RawCode(RawCode::new(
                triton_asm!(
                    square_xfe_plus_another_xfe:
                        dup 2 dup 2 dup 2 xx_mul
                        dup {5 + offset}
                        dup {5 + offset}
                        dup {5 + offset}
                        xx_add
                        return
                ),
                DataType::Xfe,
                DataType::Xfe,
            ));
            ShadowedFunction::new(ChainMap::<N>::new(raw_code)).test();
        }

        test_case::<0>();
        test_case::<1>();
        test_case::<2>();
        test_case::<3>();
        test_case::<4>();
        test_case::<5>();
        test_case::<6>();
        test_case::<7>();
    }

    #[test]
    fn test_u32_list_to_unit_list() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(remove_elements: pop 1 return),
                DataType::U32,
                DataType::Tuple(vec![]),
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_u32_list_to_u64_list() {
        let f = || {
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(duplicate_u32: dup 0 return),
                DataType::U32,
                DataType::U64,
            ))
        };
        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn test_u32_list_to_u128_list_plus_x() {
        // this code only works with 1 input list
        let raw_code = InnerFunction::RawCode(u32_to_u128_add_another_u128());
        let snippet = Map::new(raw_code);
        let encoded_u128 = rand::random::<u128>().encode();
        let input_list_len = rand::rng().random_range(0u16..200);
        let initial_state = snippet.init_state(encoded_u128, [input_list_len], rand::random());
        test_rust_equivalence_given_execution_state(
            &ShadowedFunction::new(snippet),
            initial_state.into(),
        );
    }

    #[proptest(cases = 10)]
    fn num_internal_registers_is_correct(#[strategy(arb())] guard: BFieldElement) {
        fn test_case<const N: usize>(guard: BFieldElement) {
            let offset = ChainMap::<{ N }>::NUM_INTERNAL_REGISTERS;
            let raw_code = InnerFunction::RawCode(RawCode::new(
                triton_asm! { check_env: dup {offset} push {guard} eq assert return },
                DataType::Tuple(vec![]),
                DataType::Tuple(vec![]),
            ));
            let snippet = ChainMap::<N>::new(raw_code);
            let initial_state = snippet.init_state(vec![guard], [1; N], rand::random());
            test_rust_equivalence_given_execution_state(
                &ShadowedFunction::new(snippet),
                initial_state.into(),
            );
        }

        test_case::<0>(guard);
        test_case::<1>(guard);
        test_case::<2>(guard);
        test_case::<3>(guard);
        test_case::<4>(guard);
        test_case::<5>(guard);
        test_case::<6>(guard);
        test_case::<7>(guard);
        test_case::<8>(guard);
        test_case::<9>(guard);
        test_case::<10>(guard);
        test_case::<11>(guard);
        test_case::<12>(guard);
    }

    #[test]
    fn mapping_over_dynamic_length_items_works() {
        let f = || {
            let list_type = DataType::List(Box::new(DataType::Bfe));
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(just_forty_twos: pop 2 push 42 return),
                DataType::Tuple(vec![list_type, DataType::Bfe]),
                DataType::Bfe,
            ))
        };
        assert!(f().domain().static_length().is_none());

        test_chain_map_with_different_num_input_lists(f);
    }

    #[test]
    fn mapping_over_list_of_lists_writing_their_lengths_works() {
        let f = || {
            let list_type = DataType::List(Box::new(DataType::Bfe));
            InnerFunction::RawCode(RawCode::new(
                triton_asm!(write_list_length: pop 1 read_mem 1 pop 1 return),
                DataType::Tuple(vec![list_type, DataType::Bfe]),
                DataType::Bfe,
            ))
        };
        assert!(f().domain().static_length().is_none());

        test_chain_map_with_different_num_input_lists(f);
    }
}
850
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElement;
    use super::*;
    use crate::list::higher_order::inner_function::InnerFunction;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::test_prelude::*;

    /// Benchmarks [`Map`] whose inner function is a [`BasicSnippet`].
    #[test]
    fn map_benchmark() {
        let snippet = Map::new(InnerFunction::BasicSnippet(Box::new(TestHashXFieldElement)));
        ShadowedFunction::new(snippet).bench();
    }

    /// Benchmarks [`Map`] over elements of dynamic length, with raw code as the
    /// inner function.
    #[test]
    fn map_with_dyn_items_benchmark() {
        let element_type = DataType::Tuple(vec![
            DataType::List(Box::new(DataType::Bfe)),
            DataType::Bfe,
        ]);
        let code = triton_asm!(dyn_length_elements: pop 2 push 42 return);
        let raw_code = RawCode::new(code, element_type, DataType::Bfe);
        ShadowedFunction::new(Map::new(InnerFunction::RawCode(raw_code))).bench();
    }
}