tasm_lib/list/higher_order/
inner_function.rs1use std::collections::HashMap;
2
3use triton_vm::isa::instruction::AnInstruction;
4use triton_vm::prelude::*;
5
6use crate::prelude::*;
7
/// Panic message used when a [`BasicSnippet`] that is used as an inner function does not have
/// exactly one input or exactly one output.
///
/// The message starts with the same wording as before, so it still matches substring checks on
/// the "input element" phrasing. It now also mentions outputs, because
/// [`InnerFunction::range`] uses it when the number of outputs is wrong.
const MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION: &str = "higher-order functions \
currently only work with *one* input element and *one* output element in inner function. \
Use a tuple data type to circumvent this.";
11
/// Hand-written Triton assembly for an inner function, together with the type of its input
/// and the type of its output.
///
/// [`RawCode::new`] checks that `function` starts with a label and ends with `return`,
/// `recurse`, or `recurse_or_return`. The fields are public, so constructing the struct
/// directly bypasses those checks.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct RawCode {
    /// The function's code; its first label is used as the entrypoint.
    pub function: Vec<LabelledInstruction>,
    /// The function's input type (its domain).
    pub input_type: DataType,
    /// The function's output type (its range).
    pub output_type: DataType,
}
20
21impl RawCode {
22 pub fn new(
23 function: Vec<LabelledInstruction>,
24 input_type: DataType,
25 output_type: DataType,
26 ) -> Self {
27 let is_label = |x: &_| matches!(x, LabelledInstruction::Label(_));
28 let is_instruction = |x: &_| matches!(x, LabelledInstruction::Instruction(_));
29 let labels_and_instructions = function.iter().filter(|i| is_label(i) || is_instruction(i));
30
31 assert!(
33 labels_and_instructions.count() >= 2,
34 "Inner function must have at least two lines: a label and a return or recurse"
35 );
36 assert!(
37 matches!(function[0], LabelledInstruction::Label(_)),
38 "First line of inner function must be label. Got: {}",
39 function[0]
40 );
41 assert!(
42 matches!(
43 function.last().unwrap(),
44 LabelledInstruction::Instruction(AnInstruction::Return)
45 | LabelledInstruction::Instruction(AnInstruction::Recurse)
46 | LabelledInstruction::Instruction(AnInstruction::RecurseOrReturn)
47 ),
48 "Last line of inner function must be either return, recurse, or recurse_or_return. Got: {}",
49 function.last().unwrap()
50 );
51
52 Self {
53 function,
54 input_type,
55 output_type,
56 }
57 }
58}
59
60impl RawCode {
61 pub fn entrypoint(&self) -> String {
63 let is_label = |x: &_| matches!(x, LabelledInstruction::Label(_));
64 let is_instruction = |x: &_| matches!(x, LabelledInstruction::Instruction(_));
65 let first_label_or_instruction = self
66 .function
67 .iter()
68 .find(|&x| is_label(x) || is_instruction(x));
69 let Some(labelled_instruction) = first_label_or_instruction else {
70 panic!("Inner function must start with a label. Got neither labels nor instructions.")
71 };
72 let LabelledInstruction::Label(label) = labelled_instruction else {
73 panic!("Inner function must start with a label. Got: {labelled_instruction}");
74 };
75
76 label.to_string()
77 }
78
79 pub fn inlined_body(&self) -> Option<Vec<LabelledInstruction>> {
83 let is_label = |x: &_| matches!(x, LabelledInstruction::Label(_));
84 let is_instruction = |x: &_| matches!(x, LabelledInstruction::Instruction(_));
85 let is_recursive = |x: &_| {
86 matches!(
87 x,
88 LabelledInstruction::Instruction(AnInstruction::Recurse)
89 | LabelledInstruction::Instruction(AnInstruction::RecurseOrReturn)
90 )
91 };
92
93 if self.function.iter().any(is_recursive) {
94 return None;
96 }
97
98 let mut labels_and_instructions = self
99 .function
100 .iter()
101 .filter(|i| is_label(i) || is_instruction(i));
102
103 let Some(first_thing) = labels_and_instructions.next() else {
104 return Some(triton_asm!());
105 };
106 let LabelledInstruction::Label(_) = first_thing else {
107 panic!("Raw Code must start with a label.")
108 };
109
110 let Some(LabelledInstruction::Instruction(AnInstruction::Return)) =
111 labels_and_instructions.next_back()
112 else {
113 panic!("Raw Code is probably buggy: too short, or doesn't end with `return`.");
114 };
115
116 Some(labels_and_instructions.cloned().collect())
117 }
118}
119
/// The function that a higher-order function applies. [`InnerFunction::domain`] gives its
/// single input type and [`InnerFunction::range`] its single output type.
pub enum InnerFunction {
    /// Hand-written Triton assembly with explicitly stated input and output types.
    RawCode(RawCode),
    /// A snippet whose inputs and outputs must consist of exactly one element each.
    BasicSnippet(Box<dyn BasicSnippet>),

    /// A function known only by its label and types. It has no body, so
    /// [`InnerFunction::apply`] cannot execute it.
    NoFunctionBody(NoFunctionBody),
}
128
/// Signature of an inner function referenced only by its label.
///
/// NOTE(review): the body is presumably defined elsewhere in the final program; this type
/// carries no code.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct NoFunctionBody {
    /// The label that is used as the function's entrypoint.
    pub label_name: String,
    /// The function's input type (its domain).
    pub input_type: DataType,
    /// The function's output type (its range).
    pub output_type: DataType,
}
135
136impl InnerFunction {
137 pub fn domain(&self) -> DataType {
138 match self {
139 InnerFunction::RawCode(raw) => raw.input_type.clone(),
140 InnerFunction::NoFunctionBody(f) => f.input_type.clone(),
141 InnerFunction::BasicSnippet(bs) => {
142 let [(ref input, _)] = bs.inputs()[..] else {
143 panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
144 };
145 input.clone()
146 }
147 }
148 }
149
150 pub fn range(&self) -> DataType {
151 match self {
152 InnerFunction::RawCode(rc) => rc.output_type.clone(),
153 InnerFunction::NoFunctionBody(lnat) => lnat.output_type.clone(),
154 InnerFunction::BasicSnippet(bs) => {
155 let [(ref output, _)] = bs.outputs()[..] else {
156 panic!("{MORE_THAN_ONE_INPUT_OR_OUTPUT_TYPE_IN_INNER_FUNCTION}");
157 };
158 output.clone()
159 }
160 }
161 }
162
163 pub fn entrypoint(&self) -> String {
165 match self {
166 InnerFunction::RawCode(rc) => rc.entrypoint(),
167 InnerFunction::NoFunctionBody(sn) => sn.label_name.to_owned(),
168 InnerFunction::BasicSnippet(bs) => bs.entrypoint(),
169 }
170 }
171
172 pub fn apply(
174 &self,
175 stack: &mut Vec<BFieldElement>,
176 memory: &HashMap<BFieldElement, BFieldElement>,
177 ) {
178 match &self {
179 InnerFunction::RawCode(rc) => Self::run_vm(&rc.function, stack, memory),
180 InnerFunction::NoFunctionBody(_lnat) => {
181 panic!("Cannot apply inner function without function body")
182 }
183 InnerFunction::BasicSnippet(bs) => {
184 let mut library = Library::new();
185 let function = bs.annotated_code(&mut library);
186 let imports = library.all_imports();
187 let code = triton_asm!(
188 {&function}
189 {&imports}
190 );
191
192 Self::run_vm(&code, stack, memory);
193 }
194 };
195 }
196
197 fn run_vm(
201 instructions: &[LabelledInstruction],
202 stack: &mut Vec<BFieldElement>,
203 memory: &HashMap<BFieldElement, BFieldElement>,
204 ) {
205 let Some(LabelledInstruction::Label(label)) = instructions.first() else {
206 panic!();
207 };
208 let instructions = triton_asm!(
209 call {label}
210 halt
211 {&instructions}
212 );
213 let program = Program::new(&instructions);
214 let mut vmstate = VMState::new(program, PublicInput::default(), NonDeterminism::default());
215 vmstate.op_stack.stack.clone_from(stack);
216 vmstate.ram.clone_from(memory);
217 vmstate.run().unwrap();
218 *stack = vmstate.op_stack.stack;
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// A trailing breakpoint after the final `return` must not end up in the inlined body.
    #[test]
    fn breakpoint_does_not_influence_raw_code_inlining() {
        let raw_code = RawCode {
            function: triton_asm! { my_label: return break },
            input_type: DataType::VoidPointer,
            output_type: DataType::VoidPointer,
        };
        let inlined_code = raw_code.inlined_body().unwrap();
        assert_eq!(triton_asm!(), inlined_code);
    }

    /// Type hints between the label and `return` must not end up in the inlined body.
    #[test]
    fn type_hints_do_not_influence_raw_code_inlining() {
        let raw_code = RawCode {
            function: triton_asm! { my_label: hint a = stack[0] hint b = stack[1] return },
            input_type: DataType::VoidPointer,
            output_type: DataType::VoidPointer,
        };
        let inlined_code = raw_code.inlined_body().unwrap();
        assert_eq!(triton_asm!(), inlined_code);
    }

    /// `RawCode::new` accepts code ending in `recurse_or_return`, but such code is never
    /// inlined.
    #[test]
    fn allow_raw_code_with_recurse_or_return_instruction() {
        let raw_code = triton_asm!(
            please_help_me:
                hint im_falling = stack[0]
                hint in_love_with_you = stack[1]

                call close_the_door_to_temptation

                return

            close_the_door_to_temptation:
                hint turn_away_from_me_darling = stack[5]
                break
                merkle_step_mem
                recurse_or_return
        );
        let raw_code = RawCode::new(raw_code, DataType::VoidPointer, DataType::VoidPointer);
        assert!(
            raw_code.inlined_body().is_none(),
            "Disallow inlining of code with `recurse_or_return` instruction"
        );
    }
}