tasm_lib/list/higher_order/
all.rs1use itertools::Itertools;
2use triton_vm::prelude::*;
3
4use super::inner_function::InnerFunction;
5use crate::list::get::Get;
6use crate::list::length::Length;
7use crate::prelude::*;
8
9pub struct All {
12 pub f: InnerFunction,
13}
14
15impl All {
16 pub fn new(f: InnerFunction) -> Self {
17 Self { f }
18 }
19}
20
21impl BasicSnippet for All {
22 fn inputs(&self) -> Vec<(DataType, String)> {
23 let element_type = self.f.domain();
24 let list_type = DataType::List(Box::new(element_type));
25 vec![(list_type, "*input_list".to_string())]
26 }
27
28 fn outputs(&self) -> Vec<(DataType, String)> {
29 vec![(DataType::Bool, "all_true".to_string())]
30 }
31
32 fn entrypoint(&self) -> String {
33 format!("tasmlib_list_higher_order_u32_all_{}", self.f.entrypoint())
34 }
35
36 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
37 let input_type = self.f.domain();
38 let output_type = self.f.range();
39 assert_eq!(output_type, DataType::Bool);
40
41 let get_length = library.import(Box::new(Length));
42 let list_get = library.import(Box::new(Get::new(input_type)));
43
44 let inner_function_name = match &self.f {
45 InnerFunction::RawCode(rc) => rc.entrypoint(),
46 InnerFunction::NoFunctionBody(_) => todo!(),
47 InnerFunction::BasicSnippet(bs) => {
48 let labelled_instructions = bs.annotated_code(library);
49 library.explicit_import(&bs.entrypoint(), &labelled_instructions)
50 }
51 };
52
53 let maybe_inner_function_body_raw = match &self.f {
56 InnerFunction::RawCode(rc) => rc.function.iter().map(|x| x.to_string()).join("\n"),
57 InnerFunction::NoFunctionBody(_) => todo!(),
58 InnerFunction::BasicSnippet(_) => Default::default(),
59 };
60 let entrypoint = self.entrypoint();
61 let main_loop = format!("{entrypoint}_loop");
62
63 let result_type_hint = format!("hint all_{}: Boolean = stack[0]", self.f.entrypoint());
64
65 triton_asm!(
66 {entrypoint}:
69 hint input_list = stack[0]
70 push 1 {result_type_hint}
72 swap 1 dup 0 call {get_length}
75 hint list_item: Index = stack[0]
76 call {main_loop}
79 pop 2 return
83
84 {main_loop}:
86 dup 0 push 0 eq
88 skiz return
91 push -1 add
95
96 dup 1 dup 1
100 call {list_get}
102 call {inner_function_name}
106 dup 3 mul swap 3 pop 1 recurse
115
116 {maybe_inner_function_body_raw}
117 )
118 }
119}
120
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::arithmetic;
    use crate::empty_stack;
    use crate::list::LIST_METADATA_SIZE;
    use crate::list::higher_order::inner_function::RawCode;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl All {
        /// Builds an initial VM state for this snippet.
        ///
        /// The stack holds `list_pointer` as the only argument. Memory holds a list of
        /// `list_length` elements of the inner function's domain type at `list_pointer`.
        ///
        /// With `random` set, the element data is pseudorandom. Otherwise, data word `j`
        /// (counted across all elements) holds the value `j`. This keeps benchmark inputs
        /// deterministic.
        fn generate_input_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
            random: bool,
        ) -> InitVmState {
            let mut stack = empty_stack();
            stack.push(list_pointer);

            let mut memory = HashMap::default();
            let element_size = self.f.domain().stack_size();
            let num_data_words = list_length * element_size;

            // Address 0 holds the first address past the list's element data.
            // NOTE(review): presumably legacy dynamic-allocator bookkeeping; confirm it is
            // still needed.
            let end_of_list =
                list_pointer + BFieldElement::new((LIST_METADATA_SIZE + num_data_words) as u64);
            memory.insert(BFieldElement::zero(), end_of_list);

            if random {
                untyped_insert_random_list(list_pointer, list_length, &mut memory, element_size);
            } else {
                // `list_insert` with single-word elements sets up the list's length metadata,
                // but it only fills the first `list_length` data words. Elements of a wider
                // domain type (e.g. `Xfe`) span `element_size` words each. Write every data
                // word explicitly so that no element is left partially uninitialized. For
                // single-word elements this rewrites exactly what `list_insert` wrote.
                rust_shadowing_helper_functions::list::list_insert(
                    list_pointer,
                    (0..list_length as u64)
                        .map(BFieldElement::new)
                        .collect_vec(),
                    &mut memory,
                );
                for word_index in 0..num_data_words {
                    let address = list_pointer
                        + BFieldElement::new((LIST_METADATA_SIZE + word_index) as u64);
                    memory.insert(address, BFieldElement::new(word_index as u64));
                }
            }

            InitVmState::with_stack_and_memory(stack, memory)
        }
    }

    impl Function for All {
        /// Rust reference implementation of the generated TASM.
        ///
        /// Like the TASM loop, this applies the inner function to every element, from the last
        /// element to the first, and never short-circuits. Inner functions that touch memory
        /// are therefore shadowed in the same order as they run in the VM.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let element_size = self.f.domain().stack_size();
            let list_pointer = stack.pop().unwrap();
            let list_length =
                rust_shadowing_helper_functions::list::list_get_length(list_pointer, memory);

            let mut satisfied = true;
            for i in (0..list_length).rev() {
                let input_item = list_get(list_pointer, i, memory, element_size);
                stack.extend(input_item.into_iter().rev());

                self.f.apply(stack, memory);

                let single_result = stack.pop().unwrap().value() != 0;
                satisfied &= single_result;
            }

            stack.push(BFieldElement::new(satisfied as u64));
        }

        /// Benchmark cases use a fixed pointer and deterministic data of length 10 (common)
        /// or 100 (worst). Otherwise, a seeded pseudorandom pointer below 2^20 and a length
        /// in {1, 2, 4, 8} are used, with pseudorandom data.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let (list_pointer, list_length, random) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (BFieldElement::new(5), 10, false),
                Some(BenchmarkCase::WorstCase) => (BFieldElement::new(5), 100, false),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let list_pointer = BFieldElement::new(rng.next_u64() % (1 << 20));
                    let list_length = 1 << (rng.next_u32() as usize % 4);
                    (list_pointer, list_length, true)
                }
            };

            let execution_state = self.generate_input_state(list_pointer, list_length, random);
            FunctionInitialState {
                stack: execution_state.stack,
                memory: execution_state.nondeterminism.ram,
            }
        }
    }

    #[test]
    fn rust_shadow() {
        let inner_function = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
        ShadowedFunction::new(All::new(inner_function)).test();
    }

    #[test]
    fn all_lt_test() {
        const TWO_POW_31: u64 = 1u64 << 31;

        // Predicate `x < 2^31`; the pushed constant equals `TWO_POW_31`.
        let rawcode = RawCode::new(
            triton_asm!(
                less_than_2_pow_31:
                    // _ x
                    push 2147483648
                    swap 1
                    // _ 2^31 x
                    lt
                    // _ (x < 2^31)
                    return
            ),
            DataType::Bfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        let mut memory = HashMap::new();

        // Every element of 0..30 is below 2^31, so the result must be `true`.
        rust_shadowing_helper_functions::list::list_insert(
            BFieldElement::new(42),
            (0..30).map(BFieldElement::new).collect_vec(),
            &mut memory,
        );
        let input_stack = [empty_stack(), vec![BFieldElement::new(42)]].concat();
        let expected_end_stack_true = [empty_stack(), vec![BFieldElement::one()]].concat();
        let shadowed_snippet = ShadowedFunction::new(snippet);
        let mut nondeterminism = NonDeterminism::default().with_ram(memory);
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack_true),
        );

        // Overwrite the list so that its last elements reach or exceed 2^31: result is `false`.
        rust_shadowing_helper_functions::list::list_insert(
            BFieldElement::new(42),
            (0..30)
                .map(|x| BFieldElement::new(x + TWO_POW_31 - 20))
                .collect_vec(),
            &mut nondeterminism.ram,
        );
        let expected_end_stack_false = [empty_stack(), vec![BFieldElement::zero()]].concat();
        test_rust_equivalence_given_complete_state(
            &shadowed_snippet,
            &input_stack,
            &[],
            &nondeterminism,
            &None,
            Some(&expected_end_stack_false),
        );
    }

    #[test]
    fn test_with_raw_function_lsb_on_bfe() {
        let rawcode = RawCode::new(
            triton_asm!(
                lsb_bfe:
                    // _ x
                    split
                    // _ hi lo
                    push 2
                    swap 1
                    // _ hi 2 lo
                    div_mod
                    // _ hi (lo / 2) (lo % 2)
                    swap 2
                    pop 2
                    // _ (lo % 2)
                    return
            ),
            DataType::Bfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        ShadowedFunction::new(snippet).test();
    }

    #[test]
    fn test_with_raw_function_eq_42() {
        let raw_code = RawCode::new(
            triton_asm!(
                eq_42:
                    push 42
                    eq
                    return
            ),
            DataType::U32,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(raw_code));
        ShadowedFunction::new(snippet).test();
    }

    #[test]
    fn test_with_raw_function_lsb_on_xfe() {
        let rawcode = RawCode::new(
            triton_asm!(
                lsb_xfe:
                    // _ x2 x1 x0
                    split
                    // _ x2 x1 hi lo
                    push 2
                    swap 1
                    // _ x2 x1 hi 2 lo
                    div_mod
                    // _ x2 x1 hi (lo / 2) (lo % 2)
                    swap 4
                    pop 4
                    // _ (lo % 2)
                    return
            ),
            DataType::Xfe,
            DataType::Bool,
        );
        let snippet = All::new(InnerFunction::RawCode(rawcode));
        ShadowedFunction::new(snippet).test();
    }

    /// Test predicate on an `Xfe`.
    ///
    /// It absorbs the element into a freshly initialized sponge and squeezes. The result is the
    /// least significant bit of the low 32 bits of the squeezed word on top of the stack. It
    /// also imports and calls `SafeAdd` on dummy operands, which exercises library imports from
    /// within an inner function.
    #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
    pub(super) struct TestHashXFieldElementLsb;

    impl BasicSnippet for TestHashXFieldElementLsb {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "element".to_string())]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bool, "bool".to_string())]
        }

        fn entrypoint(&self) -> String {
            "test_hash_xfield_element_lsb".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let entrypoint = self.entrypoint();
            let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
            triton_asm!(
                // BEFORE: _ x2 x1 x0
                // AFTER:  _ lsb
                {entrypoint}:
                    // call an imported snippet on dummy operands, then discard the result
                    push 0
                    push 0
                    call {unused_import}
                    pop 1
                    // _ x2 x1 x0

                    // pad to ten words: six zeros, a one, then the element on top
                    push 0
                    push 0
                    push 0
                    push 0
                    push 0
                    push 0
                    push 1
                    pick 9
                    pick 9
                    pick 9
                    // _ 0 0 0 0 0 0 1 x2 x1 x0

                    sponge_init
                    sponge_absorb
                    sponge_squeeze
                    // _ [squeezed; 10]

                    split
                    // _ [squeezed; 9] hi lo
                    push 2
                    place 1
                    // _ [squeezed; 9] hi 2 lo
                    div_mod
                    // _ [squeezed; 9] hi (lo / 2) (lo % 2)
                    place 11
                    // _ (lo % 2) [squeezed; 9] hi (lo / 2)
                    pop 5
                    pop 5
                    pop 1
                    // _ (lo % 2)
                    return
            )
        }
    }
}
418
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElementLsb;
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks `All` using the sponge-based `TestHashXFieldElementLsb` predicate on a list
    /// of `Xfe` elements.
    #[test]
    fn benchmark() {
        let snippet = All::new(InnerFunction::BasicSnippet(Box::new(
            TestHashXFieldElementLsb,
        )));
        let shadowed = ShadowedFunction::new(snippet);
        shadowed.bench();
    }
}