Skip to main content

tasm_lib/list/higher_order/
map.rs

1use itertools::Itertools;
2use strum::EnumCount;
3use tasm_lib::list::higher_order::inner_function::InnerFunction;
4use tasm_lib::structure::tasm_object::DEFAULT_MAX_DYN_FIELD_SIZE;
5use triton_vm::isa::op_stack::OpStackElement;
6use triton_vm::prelude::*;
7
8use crate::list::new::New;
9use crate::list::push::Push;
10use crate::prelude::*;
11
/// Panic message for a [`BasicSnippet`] inner function that takes more or fewer
/// than exactly one parameter.
const INNER_FN_INCORRECT_NUM_INPUTS: &str = "Inner function in `map` only works with *one* input. \
                                             Use a tuple as a workaround.";
/// Panic message for an inner function whose dynamically sized input type is
/// not a 2-tuple ending in a `BFieldElement`. This matches the check in
/// `ChainMap::new`, which accepts `[_, DataType::Bfe]`.
const INNER_FN_INCORRECT_INPUT_DYN_LEN: &str = "An input type of dynamic length to `map`'s inner \
                                                function must be a tuple of form `(_, bfe)`.";
16
17/// Applies a given function to every element of a list, and collects the new
18/// elements into a new list.
19///
20/// Mapping over multiple input lists into one output list, effectively chaining
21/// inputs before applying the map, is possible with [`ChainMap`]. See there for
22/// extended documentation.
23pub type Map = ChainMap<1>;
24
25/// Applies a given function `f` to every element of all given lists, collecting
26/// the new elements into a new list.
27///
28/// The given function `f` must produce elements of a type for which the encoded
29/// length is [statically known][len]. The input type may have either
30/// statically or dynamically known length:
31/// - In the static case, the entire element is placed on the stack before
32///   passing control to `f`.
33/// - In the dynamic case, a memory pointer to the encoded element and the
34///   item's length is placed on the stack before passing control to `f`. The
35///   input list **must** be encoded according to [`BFieldCodec`]. Otherwise,
36///   behavior of `ChainMap` is undefined!
37///
38/// The stack layout is independent of the list currently being processed. This
39/// allows the [`InnerFunction`] `f` to use runtime parameters from the stack.
40/// Note that the chain map requires a certain number of stack registers for
41/// internal purposes. This number can be accessed through
42/// [`ChainMap::NUM_INTERNAL_REGISTERS`]. As mentioned above, the stack layout
43/// upon starting execution of `f` depends on the input type's
44/// [static length][len]. In the static case, the stack layout is:
45///
46/// ```txt
47/// // _ <accessible> [_; ChainMap::<N>::NUM_INTERNAL_REGISTERS] [input_element; len]
48/// ```
49///
50/// In the case of input elements with a dynamic length, the stack layout is:
51///
52/// ```txt
53/// // _ <accessible> [_; ChainMap::<N>::NUM_INTERNAL_REGISTERS] *elem_i elem_i_len
54/// ```
55///
56/// [len]: BFieldCodec::static_length
57pub struct ChainMap<const NUM_INPUT_LISTS: usize> {
58    f: InnerFunction,
59}
60
61impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
62    /// The number of registers required internally. See [`ChainMap`] for additional
63    /// details.
64    pub const NUM_INTERNAL_REGISTERS: usize = {
65        assert!(NUM_INPUT_LISTS <= Self::MAX_NUM_INPUT_LISTS);
66
67        3 + NUM_INPUT_LISTS
68    };
69
70    /// Need access to all lists, plus a little wiggle room.
71    const MAX_NUM_INPUT_LISTS: usize = OpStackElement::COUNT - 1;
72
73    /// # Panics
74    ///
75    /// - if the input type has [static length] _and_ takes up
76    ///   [`OpStackElement::COUNT`] or more words
77    /// - if the input type has dynamic length and is _anything but_ a tuple
78    ///   `(_, `[`BFieldElement`][bfe]`)`
79    /// - if the output type takes up [`OpStackElement::COUNT`]` - 1` or more words
80    /// - if the output type does not have a [static length][len]
81    ///
82    /// [len]: BFieldCodec::static_length
83    /// [bfe]: DataType::Bfe
84    pub fn new(f: InnerFunction) -> Self {
85        let domain = f.domain();
86        if let Some(input_len) = domain.static_length() {
87            // need instruction `place {input_type.stack_size()}`
88            assert!(input_len < OpStackElement::COUNT);
89        } else {
90            let DataType::Tuple(tuple) = domain else {
91                panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
92            };
93            let [_, DataType::Bfe] = tuple[..] else {
94                panic!("{INNER_FN_INCORRECT_INPUT_DYN_LEN}");
95            };
96        }
97
98        // need instruction `pick {output_type.stack_size() + 1}`
99        let output_len = f
100            .range()
101            .static_length()
102            .expect("output type's encoding length must be static");
103        assert!(output_len + 1 < OpStackElement::COUNT);
104
105        Self { f }
106    }
107}
108
109impl<const NUM_INPUT_LISTS: usize> BasicSnippet for ChainMap<NUM_INPUT_LISTS> {
110    fn parameters(&self) -> Vec<(DataType, String)> {
111        let list_type = DataType::List(Box::new(self.f.domain()));
112
113        (0..NUM_INPUT_LISTS)
114            .map(|i| (list_type.clone(), format!("*input_list_{i}")))
115            .collect_vec()
116    }
117
118    fn return_values(&self) -> Vec<(DataType, String)> {
119        let list_type = DataType::List(Box::new(self.f.range()));
120        vec![(list_type, "*output_list".to_string())]
121    }
122
123    fn entrypoint(&self) -> String {
124        let maybe_chain_surely_map = if NUM_INPUT_LISTS == 1 {
125            "map".to_string()
126        } else {
127            format!("chain_map_{NUM_INPUT_LISTS}")
128        };
129
130        let f_label = self.f.entrypoint();
131        format!("tasmlib_list_higher_order_u32_{maybe_chain_surely_map}_{f_label}")
132    }
133
134    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
135        if self.f.domain().static_length().is_some() {
136            self.code_for_static_len_input_type(library)
137        } else {
138            self.code_for_dyn_len_input_type(library)
139        }
140    }
141}
142
143struct DecomposedInnerFunction<'body> {
144    exec_or_call: Vec<LabelledInstruction>,
145    fn_body: Option<&'body [LabelledInstruction]>,
146}
147
148impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
149    fn code_for_static_len_input_type(&self, library: &mut Library) -> Vec<LabelledInstruction> {
150        let input_type = self.f.domain();
151        let output_type = self.f.range();
152        assert!(input_type.static_length().is_some());
153
154        let new_list = library.import(Box::new(New));
155        let inner_fn = self.decompose_inner_fn(library);
156
157        let entrypoint = self.entrypoint();
158        let main_loop_fn = format!("{entrypoint}_loop");
159
160        let mul_elem_size = |n| match n {
161            1 => triton_asm!(),
162            n => triton_asm!(push {n} mul),
163        };
164        let adjust_output_list_pointer = match output_type.stack_size() {
165            0 | 1 => triton_asm!(),
166            n => triton_asm!(addi {-(n as i32 - 1)}),
167        };
168
169        let main_loop_body = triton_asm! {
170            // INVARIANT: _ *end_condition_in_list *output_elem *input_elem
171
172            /* maybe return */
173            // not using `recurse_or_return` to have more room for parameters
174            // that might live on the stack, to and used by the inner function
175            dup 2 dup 1 eq
176            skiz return
177
178            /* read */
179            {&input_type.read_value_from_memory_leave_pointer()}
180            place {input_type.stack_size()}
181                        // _ *end_condition_in_list *output_elem *prev_input_elem [input_elem]
182
183            /* map */
184            {&inner_fn.exec_or_call}
185                        // _ *end_condition_in_list *output_elem *prev_input_elem [output_elem]
186
187            /* write */
188            pick {output_type.stack_size() + 1}
189            {&output_type.write_value_to_memory_leave_pointer()}
190            addi {-2 * output_type.stack_size() as i32}
191            place 1     // _ *end_condition_in_list *prev_output_elem *prev_input_elem
192
193            recurse
194        };
195
196        let map_one_list = triton_asm! {
197            // BEFORE: _ [fill; M]   [*in_list; N-M]   *out_list
198            // AFTER:  _ [fill; M+1] [*in_list; N-M-1] *out_list
199
200            /* read list lengths */
201            read_mem 1
202            addi 1      // _ [_; M] [_; N-M-1] *in_list out_list_len *out_list
203
204            pick 2
205            read_mem 1
206            addi 1      // _ [_; M] [_; N-M-1] out_list_len *out_list in_list_len *in_list
207
208            /* prepare in_list pointer for main loop */
209            dup 1
210            {&mul_elem_size(input_type.stack_size())}
211            dup 1
212            add         // _ [_; M] [_; N-M-1] out_list_len *out_list in_list_len *in_list *in_list_first_elem_last_word
213
214            /* update out_list's len */
215            pick 2
216            pick 4      // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word in_list_len out_list_len
217            add         // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word new_out_list_len
218
219            dup 0
220            pick 4
221            write_mem 1
222            addi -1     // _ [_; M] [_; N-M-1] *in_list *in_list_first_elem_last_word new_out_list_len *out_list
223
224            /* store *out_list for next iterations */
225            dup 0
226            place 4     // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word new_out_list_len *out_list
227
228            /* prepare out_list pointer for main loop */
229            pick 1
230            {&mul_elem_size(output_type.stack_size())}
231            add         // _ [_; M] [_; N-M-1] *out_list *in_list *in_list_first_elem_last_word *out_list_last_elem_last_word
232
233            {&adjust_output_list_pointer}
234            place 1     // _ [_; M] [_; N-M-1] *out_list *in_list *out_list_last_elem_first_word *in_list_first_elem_last_word
235
236            call {main_loop_fn}
237                        hint used_list: Pointer = stack[2]
238                        // _ [_; M] [_; N-M-1] *out_list fill garbage fill
239
240            /* clean up */
241            pop 2
242            place {NUM_INPUT_LISTS}
243                        // _ [_; M+1] [_; N-M-1] *out_list
244        };
245        let map_all_lists = vec![map_one_list; NUM_INPUT_LISTS].concat();
246
247        triton_asm! {
248            // BEFORE: _ [*in_list; N]
249            // AFTER:  _ *out_list
250            {entrypoint}:
251                call {new_list}
252                    hint chain_map_output_list: Pointer = stack[0]
253                {&map_all_lists}
254                place {NUM_INPUT_LISTS}
255                {&Self::pop_input_lists()}
256                return
257            {main_loop_fn}:
258                {&main_loop_body}
259            {&inner_fn.fn_body.unwrap_or_default()}
260        }
261    }
262
263    fn code_for_dyn_len_input_type(&self, library: &mut Library) -> Vec<LabelledInstruction> {
264        let input_type = self.f.domain();
265        let output_type = self.f.range();
266        assert!(input_type.static_length().is_none());
267
268        let new_list = library.import(Box::new(New));
269        let push = library.import(Box::new(Push::new(output_type.clone())));
270        let inner_fn = self.decompose_inner_fn(library);
271
272        let entrypoint = self.entrypoint();
273        let main_loop_fn = format!("{entrypoint}_loop");
274
275        let main_loop_body = triton_asm! {
276            //                ⬐ for Self::NUM_INTERNAL_REGISTERS
277            // BEFORE:    _ fill 0           in_list_len *out_list *in_list[0]_si
278            // INVARIANT: _ fill i           in_list_len *out_list *in_list[i]_si
279            // AFTER:     _ fill in_list_len in_list_len *out_list garbage
280
281            /* maybe return */
282            dup 3
283            dup 3
284            eq
285            skiz return
286
287            /* read field size */
288            read_mem 1  hint item_len = stack[0]
289            addi 2      // _ fill i in_list_len *out_list elem_len *in_list[i]
290
291            /* check field size is reasonable */
292            push {DEFAULT_MAX_DYN_FIELD_SIZE}
293                        hint default_max_dyn_field_size = stack[0]
294            dup 2       // _ fill i in_list_len *out_list l[i]_len *in_list[i] max l[i]_len
295            lt
296            assert      // _ fill i in_list_len *out_list l[i]_len *in_list[i]
297
298            /* advance item iterator */
299            dup 1
300            dup 1
301            add
302            place 2
303
304            /* prepare for inner function */
305            place 1     // _ fill i in_list_len *out_list *in_list[i+1]_si *in_list[i] l[i]_len
306
307            /* map */
308            {&inner_fn.exec_or_call}
309                        // _ fill i in_list_len *out_list *in_list[i]_si [out_elem]
310
311            /* write */
312            dup {output_type.stack_size() + 1}
313            place {output_type.stack_size()}
314            call {push}
315                        // _ fill i in_list_len *out_list *in_list[i]_si
316
317            /* advance i */
318            pick 3
319            addi 1
320            place 3
321                        // _ fill (i+i) in_list_len *out_list *in_list[i+1]_si
322
323            recurse
324        };
325
326        let map_one_list = triton_asm! {
327            // BEFORE: _ [fill; M]   [*in_list; N-M]   *out_list
328            // AFTER:  _ [fill; M+1] [*in_list; N-M-1] *out_list
329
330            /* read in_list length */
331            pick 1
332            read_mem 1  hint in_list_len = stack[1]
333            addi 2      // _ [_; M] [_; N-M-1] *out_list in_list_len *in_list[0]_si
334
335
336            /* setup for main loop */
337            pick 2
338            place 1
339            push 0      hint filler = stack[0]
340            place 3
341            push 0      hint index = stack[0]
342            place 3
343                        // _ [_; M] [_; N-M-1] fill 0 in_list_len *out_list *in_list[0]_si
344
345            call {main_loop_fn}
346
347            /* clean up */
348            pick 1
349            place 4
350            pop 3
351            place {NUM_INPUT_LISTS}
352                        // _ [_; M+1] [_; N-M-1] *out_list
353        };
354        let map_all_lists = vec![map_one_list; NUM_INPUT_LISTS].concat();
355
356        triton_asm! {
357            // BEFORE: _ [*in_list; N]
358            // AFTER:  _ *out_list
359            {entrypoint}:
360                call {new_list}
361                    hint chain_map_output_list: Pointer = stack[0]
362                {&map_all_lists}
363                place {NUM_INPUT_LISTS}
364                {&Self::pop_input_lists()}
365                return
366            {main_loop_fn}:
367                {&main_loop_body}
368            {&inner_fn.fn_body.unwrap_or_default()}
369        }
370    }
371
372    fn decompose_inner_fn(&self, library: &mut Library) -> DecomposedInnerFunction<'_> {
373        let exec_or_call = match &self.f {
374            InnerFunction::RawCode(code) => {
375                // Inlining saves two clock cycles per iteration. If the function cannot be
376                // inlined, it needs to be appended to the function body.
377                code.inlined_body()
378                    .unwrap_or(triton_asm!(call {code.entrypoint()}))
379            }
380            InnerFunction::BasicSnippet(snippet) => {
381                assert_eq!(
382                    1,
383                    snippet.parameters().len(),
384                    "{INNER_FN_INCORRECT_NUM_INPUTS}"
385                );
386                let labelled_instructions = snippet.annotated_code(library);
387                let label = library.explicit_import(&snippet.entrypoint(), &labelled_instructions);
388                triton_asm!(call { label })
389            }
390            InnerFunction::NoFunctionBody(lnat) => {
391                triton_asm!(call { lnat.label_name })
392            }
393        };
394
395        let fn_body = if let InnerFunction::RawCode(c) = &self.f {
396            c.inlined_body().is_none().then_some(c.function.as_slice())
397        } else {
398            None
399        };
400
401        DecomposedInnerFunction {
402            exec_or_call,
403            fn_body,
404        }
405    }
406
407    fn pop_input_lists() -> Vec<LabelledInstruction> {
408        match NUM_INPUT_LISTS {
409            0 => triton_asm!(),
410            i @ 1..=5 => triton_asm!(pop { i }),
411            i @ 6..=10 => triton_asm!(pop 5 pop { i - 5 }),
412            i @ 11..=15 => triton_asm!(pop 5 pop 5 pop { i - 10 }),
413            _ => unreachable!("see compile time checks for `NUM_INPUT_LISTS`"),
414        }
415    }
416}
417
418#[cfg(test)]
419mod tests {
420    use itertools::Itertools;
421
422    use super::*;
423    use crate::arithmetic;
424    use crate::list::higher_order::inner_function::InnerFunction;
425    use crate::list::higher_order::inner_function::RawCode;
426    use crate::neptune::mutator_set::get_swbf_indices::u32_to_u128_add_another_u128;
427    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
428    use crate::rust_shadowing_helper_functions::list::list_get;
429    use crate::rust_shadowing_helper_functions::list::list_get_length;
430    use crate::rust_shadowing_helper_functions::list::list_pointer_to_elem_pointer;
431    use crate::rust_shadowing_helper_functions::list::list_set;
432    use crate::rust_shadowing_helper_functions::list::list_set_length;
433    use crate::test_helpers::test_rust_equivalence_given_execution_state;
434    use crate::test_prelude::*;
435
436    impl<const NUM_INPUT_LISTS: usize> ChainMap<NUM_INPUT_LISTS> {
437        fn init_state(
438            &self,
439            environment_args: impl IntoIterator<Item = BFieldElement>,
440            list_lengths: [u16; NUM_INPUT_LISTS],
441            seed: <StdRng as SeedableRng>::Seed,
442        ) -> FunctionInitialState {
443            let input_type = self.f.domain();
444            let mut stack = self.init_stack_for_isolated_run();
445            let mut memory = HashMap::default();
446            let mut rng = StdRng::from_seed(seed);
447
448            stack.extend(environment_args);
449
450            for list_length in list_lengths {
451                let list_length = usize::from(list_length);
452                let list = input_type.random_list(&mut rng, list_length);
453                let list_pointer = dynamic_allocator(&mut memory);
454                let indexed_list = list
455                    .into_iter()
456                    .enumerate()
457                    .map(|(i, v)| (list_pointer + bfe!(i), v));
458
459                memory.extend(indexed_list);
460                stack.push(list_pointer);
461            }
462
463            FunctionInitialState { stack, memory }
464        }
465    }
466
467    impl<const NUM_INPUT_LISTS: usize> Function for ChainMap<NUM_INPUT_LISTS> {
468        fn rust_shadow(
469            &self,
470            stack: &mut Vec<BFieldElement>,
471            memory: &mut HashMap<BFieldElement, BFieldElement>,
472        ) -> Result<(), RustShadowError> {
473            let input_type = self.f.domain();
474            let output_type = self.f.range();
475
476            New.rust_shadow(stack, memory)?;
477            let output_list_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
478
479            let input_list_pointers: Vec<_> = (0..NUM_INPUT_LISTS)
480                .map(|_| stack.pop().ok_or(RustShadowError::StackUnderflow))
481                .try_collect()?;
482
483            // the inner function _must not_ rely on these elements
484            let buffer = (0..Self::NUM_INTERNAL_REGISTERS).map(|_| rand::random::<BFieldElement>());
485            stack.extend(buffer);
486
487            let mut total_output_len = 0;
488            for input_list_pointer in input_list_pointers {
489                let input_list_len = list_get_length(input_list_pointer, memory)?;
490                let output_list_len = list_get_length(output_list_pointer, memory)?;
491                let new_output_list_len = output_list_len + input_list_len;
492                list_set_length(output_list_pointer, new_output_list_len, memory);
493
494                for i in (0..input_list_len).rev() {
495                    if input_type.static_length().is_some() {
496                        let elem =
497                            list_get(input_list_pointer, i, memory, input_type.stack_size())?;
498                        stack.extend(elem.into_iter().rev());
499                    } else {
500                        let (len, ptr) = list_pointer_to_elem_pointer(
501                            input_list_pointer,
502                            i,
503                            memory,
504                            &input_type,
505                        )?;
506                        stack.push(ptr);
507                        stack.push(bfe!(len));
508                    };
509                    self.f.apply(stack, memory);
510                    let elem: Vec<_> = (0..output_type.stack_size())
511                        .map(|_| stack.pop().ok_or(RustShadowError::StackUnderflow))
512                        .try_collect()?;
513                    list_set(output_list_pointer, total_output_len + i, elem, memory)?;
514                }
515
516                total_output_len += input_list_len;
517            }
518
519            for _ in 0..Self::NUM_INTERNAL_REGISTERS {
520                stack.pop();
521            }
522
523            stack.push(output_list_pointer);
524            Ok(())
525        }
526
527        fn pseudorandom_initial_state(
528            &self,
529            seed: [u8; 32],
530            bench: Option<BenchmarkCase>,
531        ) -> FunctionInitialState {
532            let mut rng = StdRng::from_seed(seed);
533            let environment_args = rng.random::<[BFieldElement; OpStackElement::COUNT]>();
534
535            let list_lengths = match bench {
536                None => rng.random::<[u8; NUM_INPUT_LISTS]>(),
537                Some(BenchmarkCase::CommonCase) => [10; NUM_INPUT_LISTS],
538                Some(BenchmarkCase::WorstCase) => [100; NUM_INPUT_LISTS],
539            };
540            let list_lengths = list_lengths.map(Into::into);
541
542            self.init_state(environment_args, list_lengths, rng.random())
543        }
544    }
545
546    #[derive(Debug, Clone)]
547    pub(crate) struct TestHashXFieldElement;
548
549    impl BasicSnippet for TestHashXFieldElement {
550        fn parameters(&self) -> Vec<(DataType, String)> {
551            vec![(DataType::Xfe, "element".to_string())]
552        }
553
554        fn return_values(&self) -> Vec<(DataType, String)> {
555            vec![(DataType::Digest, "digest".to_string())]
556        }
557
558        fn entrypoint(&self) -> String {
559            "test_hash_xfield_element".to_string()
560        }
561
562        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
563            let entrypoint = self.entrypoint();
564            let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
565            triton_asm!(
566                // BEFORE: _ x2 x1 x0
567                // AFTER:  _ d4 d3 d2 d1 d0
568                {entrypoint}:
569                    push 0 push 0
570                    push 0 push 0
571                    push 0 push 0   // _ x2 x1 x0 0 0 0 0 0 0
572                    push 1          // _ x2 x1 x0 0 0 0 0 0 0 1
573                    pick 9          // _ x1 x0 0 0 0 0 0 0 1 x2
574                    pick 9          // _ x0 0 0 0 0 0 0 1 x2 x1
575                    pick 9          // _ 0 0 0 0 0 0 1 x2 x1 x0
576
577                    // Useless additions, to ensure that imports are accepted inside the
578                    // map-generated code
579                    push 0
580                    push 0
581                    call {unused_import}
582                    pop 1
583
584                    sponge_init
585                    sponge_absorb
586                    sponge_squeeze // _ d9 d8 d7 d6 d5 d4 d3 d2 d1 d0
587                    pick 9 pick 9  // _ d7 d6 d5 d4 d3 d2 d1 d0 d9 d8
588                    pick 9 pick 9  // _ d5 d4 d3 d2 d1 d0 d9 d8 d7 d6
589                    pick 9 pop 5   // _ d4 d3 d2 d1 d0
590                    return
591            )
592        }
593    }
594
595    /// `InnerFunction` does not implement `Clone`, and it is hard (impossible?) to
596    /// teach it. Hence, take a function `f` to generate the `InnerFunction`.
597    ///
598    /// Theoretically, this _could_ mean that `f` produces a different
599    /// `InnerFunction` each time it is called, as every `Fn()` is `FnMut()`.
600    /// Since this is a test helper, how about it doesn't. 😊
601    fn test_chain_map_with_different_num_input_lists(f: impl Fn() -> InnerFunction) {
602        ShadowedFunction::new(ChainMap::<0>::new(f())).test();
603        ShadowedFunction::new(ChainMap::<1>::new(f())).test();
604        ShadowedFunction::new(ChainMap::<2>::new(f())).test();
605        ShadowedFunction::new(ChainMap::<3>::new(f())).test();
606        ShadowedFunction::new(ChainMap::<4>::new(f())).test();
607        ShadowedFunction::new(ChainMap::<5>::new(f())).test();
608
609        ShadowedFunction::new(ChainMap::<7>::new(f())).test();
610        ShadowedFunction::new(ChainMap::<11>::new(f())).test();
611        ShadowedFunction::new(ChainMap::<15>::new(f())).test();
612    }
613
614    #[macro_rules_attr::apply(test)]
615    fn test_with_raw_function_identity_on_bfe() {
616        let f = || {
617            InnerFunction::RawCode(RawCode::new(
618                triton_asm!(identity_bfe: return),
619                DataType::Bfe,
620                DataType::Bfe,
621            ))
622        };
623        test_chain_map_with_different_num_input_lists(f);
624    }
625
626    #[macro_rules_attr::apply(test)]
627    fn test_with_raw_function_bfe_lift() {
628        let f = || {
629            InnerFunction::RawCode(RawCode::new(
630                triton_asm!(bfe_lift: push 0 push 0 pick 2 return),
631                DataType::Bfe,
632                DataType::Xfe,
633            ))
634        };
635        test_chain_map_with_different_num_input_lists(f);
636    }
637
638    #[macro_rules_attr::apply(test)]
639    fn test_with_raw_function_xfe_get_coeff_0() {
640        let f = || {
641            InnerFunction::RawCode(RawCode::new(
642                triton_asm!(get_0: place 2 pop 2 return),
643                DataType::Xfe,
644                DataType::Bfe,
645            ))
646        };
647        test_chain_map_with_different_num_input_lists(f);
648    }
649
650    #[macro_rules_attr::apply(test)]
651    fn test_with_raw_function_square_on_bfe() {
652        let f = || {
653            InnerFunction::RawCode(RawCode::new(
654                triton_asm!(square_bfe: dup 0 mul return),
655                DataType::Bfe,
656                DataType::Bfe,
657            ))
658        };
659        test_chain_map_with_different_num_input_lists(f);
660    }
661
662    #[macro_rules_attr::apply(test)]
663    fn test_with_raw_function_square_plus_n_on_bfe() {
664        // Inner function calculates `|x| -> x*x + n`, where `x` is the list element,
665        // and `n` is the same value for all elements.
666        fn test_case<const N: usize>() {
667            let raw_code = InnerFunction::RawCode(RawCode::new(
668                triton_asm!(square_plus_n_bfe: dup 0 mul dup {5 + N} add return),
669                DataType::Bfe,
670                DataType::Bfe,
671            ));
672            ShadowedFunction::new(ChainMap::<N>::new(raw_code)).test();
673        }
674
675        test_case::<0>();
676        test_case::<1>();
677        test_case::<2>();
678        test_case::<3>();
679        test_case::<4>();
680        test_case::<5>();
681        test_case::<7>();
682        test_case::<9>();
683        test_case::<10>();
684    }
685
686    #[macro_rules_attr::apply(test)]
687    fn test_with_raw_function_square_on_xfe() {
688        let f = || {
689            InnerFunction::RawCode(RawCode::new(
690                triton_asm!(square_xfe: dup 2 dup 2 dup 2 xx_mul return),
691                DataType::Xfe,
692                DataType::Xfe,
693            ))
694        };
695        test_chain_map_with_different_num_input_lists(f);
696    }
697
698    #[macro_rules_attr::apply(test)]
699    fn test_with_raw_function_xfe_to_digest() {
700        let f = || {
701            InnerFunction::RawCode(RawCode::new(
702                triton_asm!(xfe_to_digest: push 0 push 0 return),
703                DataType::Xfe,
704                DataType::Digest,
705            ))
706        };
707        test_chain_map_with_different_num_input_lists(f);
708    }
709
710    #[macro_rules_attr::apply(test)]
711    fn test_with_raw_function_digest_to_xfe() {
712        let f = || {
713            InnerFunction::RawCode(RawCode::new(
714                triton_asm!(xfe_to_digest: pop 2 return),
715                DataType::Digest,
716                DataType::Xfe,
717            ))
718        };
719        test_chain_map_with_different_num_input_lists(f);
720    }
721
722    #[macro_rules_attr::apply(test)]
723    fn test_with_raw_function_square_on_xfe_plus_another_xfe() {
724        fn test_case<const N: usize>() {
725            let offset = ChainMap::<{ N }>::NUM_INTERNAL_REGISTERS;
726            let raw_code = InnerFunction::RawCode(RawCode::new(
727                triton_asm!(
728                    square_xfe_plus_another_xfe:
729                        dup 2 dup 2 dup 2 xx_mul
730                        dup {5 + offset}
731                        dup {5 + offset}
732                        dup {5 + offset}
733                        xx_add
734                        return
735                ),
736                DataType::Xfe,
737                DataType::Xfe,
738            ));
739            ShadowedFunction::new(ChainMap::<N>::new(raw_code)).test();
740        }
741
742        test_case::<0>();
743        test_case::<1>();
744        test_case::<2>();
745        test_case::<3>();
746        test_case::<5>();
747        test_case::<4>();
748        test_case::<6>();
749        test_case::<7>();
750    }
751
752    #[macro_rules_attr::apply(test)]
753    fn test_u32_list_to_unit_list() {
754        let f = || {
755            InnerFunction::RawCode(RawCode::new(
756                triton_asm!(remove_elements: pop 1 return),
757                DataType::U32,
758                DataType::Tuple(vec![]),
759            ))
760        };
761        test_chain_map_with_different_num_input_lists(f);
762    }
763
764    #[macro_rules_attr::apply(test)]
765    fn test_u32_list_to_u64_list() {
766        let f = || {
767            InnerFunction::RawCode(RawCode::new(
768                triton_asm!(duplicate_u32: dup 0 return),
769                DataType::U32,
770                DataType::U64,
771            ))
772        };
773        test_chain_map_with_different_num_input_lists(f);
774    }
775
776    #[macro_rules_attr::apply(test)]
777    fn test_u32_list_to_u128_list_plus_x() {
778        // this code only works with 1 input list
779        let raw_code = InnerFunction::RawCode(u32_to_u128_add_another_u128());
780        let snippet = Map::new(raw_code);
781        let encoded_u128 = rand::random::<u128>().encode();
782        let input_list_len = rand::rng().random_range(0u16..200);
783        let initial_state = snippet.init_state(encoded_u128, [input_list_len], rand::random());
784        test_rust_equivalence_given_execution_state(
785            &ShadowedFunction::new(snippet),
786            initial_state.into(),
787        );
788    }
789
790    #[macro_rules_attr::apply(proptest(cases = 10))]
791    fn num_internal_registers_is_correct(#[strategy(arb())] guard: BFieldElement) {
792        fn test_case<const N: usize>(guard: BFieldElement) {
793            let offset = ChainMap::<{ N }>::NUM_INTERNAL_REGISTERS;
794            let raw_code = InnerFunction::RawCode(RawCode::new(
795                triton_asm! { check_env: dup {offset} push {guard} eq assert return },
796                DataType::Tuple(vec![]),
797                DataType::Tuple(vec![]),
798            ));
799            let snippet = ChainMap::<N>::new(raw_code);
800            let initial_state = snippet.init_state(vec![guard], [1; N], rand::random());
801            test_rust_equivalence_given_execution_state(
802                &ShadowedFunction::new(snippet),
803                initial_state.into(),
804            );
805        }
806
807        test_case::<0>(guard);
808        test_case::<1>(guard);
809        test_case::<2>(guard);
810        test_case::<3>(guard);
811        test_case::<4>(guard);
812        test_case::<5>(guard);
813        test_case::<6>(guard);
814        test_case::<7>(guard);
815        test_case::<8>(guard);
816        test_case::<9>(guard);
817        test_case::<10>(guard);
818        test_case::<11>(guard);
819        test_case::<12>(guard);
820    }
821
822    #[macro_rules_attr::apply(test)]
823    fn mapping_over_dynamic_length_items_works() {
824        let f = || {
825            let list_type = DataType::List(Box::new(DataType::Bfe));
826            InnerFunction::RawCode(RawCode::new(
827                triton_asm!(just_forty_twos: pop 2 push 42 return),
828                DataType::Tuple(vec![list_type, DataType::Bfe]),
829                DataType::Bfe,
830            ))
831        };
832        assert!(f().domain().static_length().is_none());
833
834        test_chain_map_with_different_num_input_lists(f);
835    }
836
837    #[macro_rules_attr::apply(test)]
838    fn mapping_over_list_of_lists_writing_their_lengths_works() {
839        let f = || {
840            let list_type = DataType::List(Box::new(DataType::Bfe));
841            InnerFunction::RawCode(RawCode::new(
842                triton_asm!(write_list_length: pop 1 read_mem 1 pop 1 return),
843                DataType::Tuple(vec![list_type, DataType::Bfe]),
844                DataType::Bfe,
845            ))
846        };
847        assert!(f().domain().static_length().is_none());
848
849        test_chain_map_with_different_num_input_lists(f);
850    }
851}
852
853#[cfg(test)]
854mod benches {
855    use super::tests::TestHashXFieldElement;
856    use super::*;
857    use crate::list::higher_order::inner_function::InnerFunction;
858    use crate::list::higher_order::inner_function::RawCode;
859    use crate::test_prelude::*;
860
861    #[macro_rules_attr::apply(test)]
862    fn map_benchmark() {
863        let f = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElement));
864        ShadowedFunction::new(Map::new(f)).bench();
865    }
866
867    #[macro_rules_attr::apply(test)]
868    fn map_with_dyn_items_benchmark() {
869        let list_type = DataType::List(Box::new(DataType::Bfe));
870        let f = InnerFunction::RawCode(RawCode::new(
871            triton_asm!(dyn_length_elements: pop 2 push 42 return),
872            DataType::Tuple(vec![list_type, DataType::Bfe]),
873            DataType::Bfe,
874        ));
875        ShadowedFunction::new(Map::new(f)).bench();
876    }
877}