tasm_lib/list/higher_order/
inner_function.rs

1use std::collections::HashMap;
2
3use triton_vm::isa::instruction::AnInstruction;
4use triton_vm::prelude::*;
5
6use crate::prelude::*;
7
/// Panic message used by [`InnerFunction::domain`] and [`InnerFunction::range`]
/// when a [`BasicSnippet`] does not declare exactly one input or exactly one
/// output, respectively.
// NOTE(review): the text only mentions the *input* element even though the
// constant is also used for the output check in `range`.
const MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION: &str = "higher-order functions \
currently only work with *one* input element in inner function. \
Use a tuple data type to circumvent this.";
11
/// A data structure for describing an inner function predicate to filter with,
/// or a function to map with.
///
/// The fields are public, so a value can be built with a struct literal; only
/// [`RawCode::new`] validates the shape of the code.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct RawCode {
    /// The code of the inner function. [`RawCode::new`] requires its first
    /// line to be a label and its last line to be `return`, `recurse`, or
    /// `recurse_or_return`.
    pub function: Vec<LabelledInstruction>,
    /// The inner function's input type, as reported by [`InnerFunction::domain`].
    pub input_type: DataType,
    /// The inner function's output type, as reported by [`InnerFunction::range`].
    pub output_type: DataType,
}
20
21impl RawCode {
22    pub fn new(
23        function: Vec<LabelledInstruction>,
24        input_type: DataType,
25        output_type: DataType,
26    ) -> Self {
27        let is_label = |x: &_| matches!(x, LabelledInstruction::Label(_));
28        let is_instruction = |x: &_| matches!(x, LabelledInstruction::Instruction(_));
29        let labels_and_instructions = function.iter().filter(|i| is_label(i) || is_instruction(i));
30
31        // Verify that 1st line is a label
32        assert!(
33            labels_and_instructions.count() >= 2,
34            "Inner function must have at least two lines: a label and a return or recurse"
35        );
36        assert!(
37            matches!(function[0], LabelledInstruction::Label(_)),
38            "First line of inner function must be label. Got: {}",
39            function[0]
40        );
41        assert!(
42            matches!(
43                function.last().unwrap(),
44                LabelledInstruction::Instruction(AnInstruction::Return)
45                    | LabelledInstruction::Instruction(AnInstruction::Recurse)
46                    | LabelledInstruction::Instruction(AnInstruction::RecurseOrReturn)
47            ),
48            "Last line of inner function must be either return, recurse, or recurse_or_return. Got: {}",
49            function.last().unwrap()
50        );
51
52        Self {
53            function,
54            input_type,
55            output_type,
56        }
57    }
58}
59
60impl RawCode {
61    /// Return the entrypoint, label, of the inner function. Used to make a call to this function.
62    pub fn entrypoint(&self) -> String {
63        let is_label = |x: &_| matches!(x, LabelledInstruction::Label(_));
64        let is_instruction = |x: &_| matches!(x, LabelledInstruction::Instruction(_));
65        let first_label_or_instruction = self
66            .function
67            .iter()
68            .find(|&x| is_label(x) || is_instruction(x));
69        let Some(labelled_instruction) = first_label_or_instruction else {
70            panic!("Inner function must start with a label. Got neither labels nor instructions.")
71        };
72        let LabelledInstruction::Label(label) = labelled_instruction else {
73            panic!("Inner function must start with a label. Got: {labelled_instruction}");
74        };
75
76        label.to_string()
77    }
78
79    /// Returns `Some(code)` iff the raw code is a function that can be inlined
80    ///
81    /// Type hints and breakpoints are stripped.
82    pub fn inlined_body(&self) -> Option<Vec<LabelledInstruction>> {
83        let is_label = |x: &_| matches!(x, LabelledInstruction::Label(_));
84        let is_instruction = |x: &_| matches!(x, LabelledInstruction::Instruction(_));
85        let is_recursive = |x: &_| {
86            matches!(
87                x,
88                LabelledInstruction::Instruction(AnInstruction::Recurse)
89                    | LabelledInstruction::Instruction(AnInstruction::RecurseOrReturn)
90            )
91        };
92
93        if self.function.iter().any(is_recursive) {
94            // recursion needs to be wrapped in a function
95            return None;
96        }
97
98        let mut labels_and_instructions = self
99            .function
100            .iter()
101            .filter(|i| is_label(i) || is_instruction(i));
102
103        let Some(first_thing) = labels_and_instructions.next() else {
104            return Some(triton_asm!());
105        };
106        let LabelledInstruction::Label(_) = first_thing else {
107            panic!("Raw Code must start with a label.")
108        };
109
110        let Some(LabelledInstruction::Instruction(AnInstruction::Return)) =
111            labels_and_instructions.next_back()
112        else {
113            panic!("Raw Code is probably buggy: too short, or doesn't end with `return`.");
114        };
115
116        Some(labels_and_instructions.cloned().collect())
117    }
118}
119
/// An inner function for use with higher-order functions, in one of three
/// representations.
pub enum InnerFunction {
    /// The inner function is given directly as a list of labelled instructions.
    RawCode(RawCode),

    /// The inner function is given as a snippet.
    BasicSnippet(Box<dyn BasicSnippet>),

    /// Only the label and the types of the inner function are known.
    ///
    /// Used when a snippet is declared somewhere else, and it's not the responsibility of
    /// the higher order function to import it. [`InnerFunction::apply`] panics for
    /// this variant.
    NoFunctionBody(NoFunctionBody),
}
128
/// Description of an inner function whose code is not available here: only its
/// label and its input and output types are recorded.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct NoFunctionBody {
    /// The label of the function; returned as-is by [`InnerFunction::entrypoint`].
    pub label_name: String,
    /// The function's input type, as reported by [`InnerFunction::domain`].
    pub input_type: DataType,
    /// The function's output type, as reported by [`InnerFunction::range`].
    pub output_type: DataType,
}
135
136impl InnerFunction {
137    pub fn domain(&self) -> DataType {
138        match self {
139            InnerFunction::RawCode(raw) => raw.input_type.clone(),
140            InnerFunction::NoFunctionBody(f) => f.input_type.clone(),
141            InnerFunction::BasicSnippet(bs) => {
142                let [(ref input, _)] = bs.inputs()[..] else {
143                    panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
144                };
145                input.clone()
146            }
147        }
148    }
149
150    pub fn range(&self) -> DataType {
151        match self {
152            InnerFunction::RawCode(rc) => rc.output_type.clone(),
153            InnerFunction::NoFunctionBody(lnat) => lnat.output_type.clone(),
154            InnerFunction::BasicSnippet(bs) => {
155                let [(ref output, _)] = bs.outputs()[..] else {
156                    panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
157                };
158                output.clone()
159            }
160        }
161    }
162
163    /// Return the entrypoint, label, of the inner function. Used to make a call to this function.
164    pub fn entrypoint(&self) -> String {
165        match self {
166            InnerFunction::RawCode(rc) => rc.entrypoint(),
167            InnerFunction::NoFunctionBody(sn) => sn.label_name.to_owned(),
168            InnerFunction::BasicSnippet(bs) => bs.entrypoint(),
169        }
170    }
171
172    /// Computes the inner function and applies the resulting change to the given stack
173    pub fn apply(
174        &self,
175        stack: &mut Vec<BFieldElement>,
176        memory: &HashMap<BFieldElement, BFieldElement>,
177    ) {
178        match &self {
179            InnerFunction::RawCode(rc) => Self::run_vm(&rc.function, stack, memory),
180            InnerFunction::NoFunctionBody(_lnat) => {
181                panic!("Cannot apply inner function without function body")
182            }
183            InnerFunction::BasicSnippet(bs) => {
184                let mut library = Library::new();
185                let function = bs.annotated_code(&mut library);
186                let imports = library.all_imports();
187                let code = triton_asm!(
188                    {&function}
189                    {&imports}
190                );
191
192                Self::run_vm(&code, stack, memory);
193            }
194        };
195    }
196
197    /// Run the VM for on a given stack and memory to observe how it manipulates the
198    /// stack. This is a helper function for [`apply`](Self::apply), which in some cases
199    /// just grabs the inner function's code and then needs a VM to apply it.
200    fn run_vm(
201        instructions: &[LabelledInstruction],
202        stack: &mut Vec<BFieldElement>,
203        memory: &HashMap<BFieldElement, BFieldElement>,
204    ) {
205        let Some(LabelledInstruction::Label(label)) = instructions.first() else {
206            panic!();
207        };
208        let instructions = triton_asm!(
209            call {label}
210            halt
211            {&instructions}
212        );
213        let program = Program::new(&instructions);
214        let mut vmstate = VMState::new(program, PublicInput::default(), NonDeterminism::default());
215        vmstate.op_stack.stack.clone_from(stack);
216        vmstate.ram.clone_from(memory);
217        vmstate.run().unwrap();
218        *stack = vmstate.op_stack.stack;
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// A breakpoint after the final `return` must not prevent inlining.
    #[test]
    fn breakpoint_does_not_influence_raw_code_inlining() {
        // Built with a struct literal on purpose: `RawCode::new` rejects code
        // whose last line is not `return`, `recurse`, or `recurse_or_return`.
        let raw_code = RawCode {
            function: triton_asm! { my_label: return break },
            input_type: DataType::VoidPointer,
            output_type: DataType::VoidPointer,
        };
        let inlined_code = raw_code.inlined_body().unwrap();
        assert_eq!(triton_asm!(), inlined_code);
    }

    /// Type hints are stripped from the inlined body.
    #[test]
    fn type_hints_do_not_influence_raw_code_inlining() {
        let raw_code = RawCode {
            function: triton_asm! { my_label: hint a = stack[0] hint b = stack[1] return },
            input_type: DataType::VoidPointer,
            output_type: DataType::VoidPointer,
        };
        let inlined_code = raw_code.inlined_body().unwrap();
        assert_eq!(triton_asm!(), inlined_code);
    }

    /// Code ending in `recurse_or_return` is accepted by `RawCode::new`, but
    /// must not be inlined.
    #[test]
    fn allow_raw_code_with_recurse_or_return_instruction() {
        let raw_code = triton_asm!(
            please_help_me:
                hint im_falling = stack[0]
                hint in_love_with_you = stack[1]

                call close_the_door_to_temptation

                return

                close_the_door_to_temptation:
                    hint turn_away_from_me_darling = stack[5]
                    break
                    merkle_step_mem
                    recurse_or_return
        );
        let raw_code = RawCode::new(raw_code, DataType::VoidPointer, DataType::VoidPointer);
        assert!(
            raw_code.inlined_body().is_none(),
            "Disallow inlining of code with `recurse_or_return` instruction"
        );
    }
}
271}