tasm_lib/list/higher_order/
all.rs1use itertools::Itertools;
2use triton_vm::prelude::*;
3
4use super::inner_function::InnerFunction;
5use crate::list::get::Get;
6use crate::list::length::Length;
7use crate::prelude::*;
8
9pub struct All {
12 pub f: InnerFunction,
13}
14
15impl All {
16 pub fn new(f: InnerFunction) -> Self {
17 Self { f }
18 }
19}
20
21impl BasicSnippet for All {
22 fn parameters(&self) -> Vec<(DataType, String)> {
23 let element_type = self.f.domain();
24 let list_type = DataType::List(Box::new(element_type));
25 vec![(list_type, "*input_list".to_string())]
26 }
27
28 fn return_values(&self) -> Vec<(DataType, String)> {
29 vec![(DataType::Bool, "all_true".to_string())]
30 }
31
32 fn entrypoint(&self) -> String {
33 format!("tasmlib_list_higher_order_u32_all_{}", self.f.entrypoint())
34 }
35
36 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
37 let input_type = self.f.domain();
38 let output_type = self.f.range();
39 assert_eq!(output_type, DataType::Bool);
40
41 let get_length = library.import(Box::new(Length));
42 let list_get = library.import(Box::new(Get::new(input_type)));
43
44 let inner_function_name = match &self.f {
45 InnerFunction::RawCode(rc) => rc.entrypoint(),
46 InnerFunction::NoFunctionBody(_) => todo!(),
47 InnerFunction::BasicSnippet(bs) => {
48 let labelled_instructions = bs.annotated_code(library);
49 library.explicit_import(&bs.entrypoint(), &labelled_instructions)
50 }
51 };
52
53 let maybe_inner_function_body_raw = match &self.f {
56 InnerFunction::RawCode(rc) => rc.function.iter().map(|x| x.to_string()).join("\n"),
57 InnerFunction::NoFunctionBody(_) => todo!(),
58 InnerFunction::BasicSnippet(_) => Default::default(),
59 };
60 let entrypoint = self.entrypoint();
61 let main_loop = format!("{entrypoint}_loop");
62
63 let result_type_hint = format!("hint all_{}: Boolean = stack[0]", self.f.entrypoint());
64
65 triton_asm!(
66 {entrypoint}:
69 hint input_list = stack[0]
70 push 1 {result_type_hint}
72 swap 1 dup 0 call {get_length}
75 hint list_item: Index = stack[0]
76 call {main_loop}
79 pop 2 return
83
84 {main_loop}:
86 dup 0 push 0 eq
88 skiz return
91 push -1 add
95
96 dup 1 dup 1
100 call {list_get}
102 call {inner_function_name}
106 dup 3 mul swap 3 pop 1 recurse
115
116 {maybe_inner_function_body_raw}
117 )
118 }
119}
120
121#[cfg(test)]
122mod tests {
123 use num::One;
124 use num::Zero;
125
126 use super::*;
127 use crate::arithmetic;
128 use crate::empty_stack;
129 use crate::list::LIST_METADATA_SIZE;
130 use crate::list::higher_order::inner_function::RawCode;
131 use crate::rust_shadowing_helper_functions;
132 use crate::rust_shadowing_helper_functions::list::list_get;
133 use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
134 use crate::test_helpers::test_rust_equivalence_given_complete_state;
135 use crate::test_prelude::*;
136
137 impl All {
138 fn generate_input_state(
139 &self,
140 list_pointer: BFieldElement,
141 list_length: usize,
142 random: bool,
143 ) -> InitVmState {
144 let mut stack = empty_stack();
145 stack.push(list_pointer);
146
147 let mut memory = HashMap::default();
148 let input_type = self.f.domain();
149 let list_bookkeeping_offset = LIST_METADATA_SIZE;
150 let element_index_in_list =
151 list_bookkeeping_offset + list_length * input_type.stack_size();
152 let element_index = list_pointer + BFieldElement::new(element_index_in_list as u64);
153 memory.insert(BFieldElement::zero(), element_index);
154
155 if random {
156 untyped_insert_random_list(
157 list_pointer,
158 list_length,
159 &mut memory,
160 input_type.stack_size(),
161 );
162 } else {
163 rust_shadowing_helper_functions::list::list_insert(
164 list_pointer,
165 (0..list_length as u64)
166 .map(BFieldElement::new)
167 .collect_vec(),
168 &mut memory,
169 );
170 }
171
172 InitVmState::with_stack_and_memory(stack, memory)
173 }
174 }
175
176 impl Function for All {
177 fn rust_shadow(
178 &self,
179 stack: &mut Vec<BFieldElement>,
180 memory: &mut HashMap<BFieldElement, BFieldElement>,
181 ) -> Result<(), RustShadowError> {
182 let input_type = self.f.domain();
183 let list_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
184
185 let list_length =
187 rust_shadowing_helper_functions::list::list_get_length(list_pointer, memory)?;
188 let mut satisfied = true;
189 for i in 0..list_length {
190 let input_item = list_get(list_pointer, i, memory, input_type.stack_size())?;
191 for bfe in input_item.into_iter().rev() {
192 stack.push(bfe);
193 }
194
195 self.f.apply(stack, memory);
196
197 let single_result =
198 stack.pop().ok_or(RustShadowError::StackUnderflow)?.value() != 0;
199 satisfied = satisfied && single_result;
200 }
201
202 stack.push(BFieldElement::new(satisfied as u64));
203 Ok(())
204 }
205
206 fn pseudorandom_initial_state(
207 &self,
208 seed: [u8; 32],
209 bench_case: Option<BenchmarkCase>,
210 ) -> FunctionInitialState {
211 let (stack, memory) = match bench_case {
212 Some(BenchmarkCase::CommonCase) => {
213 let list_pointer = BFieldElement::new(5);
214 let list_length = 10;
215 let execution_state =
216 self.generate_input_state(list_pointer, list_length, false);
217 (execution_state.stack, execution_state.nondeterminism.ram)
218 }
219 Some(BenchmarkCase::WorstCase) => {
220 let list_pointer = BFieldElement::new(5);
221 let list_length = 100;
222 let execution_state =
223 self.generate_input_state(list_pointer, list_length, false);
224 (execution_state.stack, execution_state.nondeterminism.ram)
225 }
226 None => {
227 let mut rng = StdRng::from_seed(seed);
228 let list_pointer = BFieldElement::new(rng.next_u64() % (1 << 20));
229 let list_length = 1 << (rng.next_u32() as usize % 4);
230 let execution_state =
231 self.generate_input_state(list_pointer, list_length, true);
232 (execution_state.stack, execution_state.nondeterminism.ram)
233 }
234 };
235
236 FunctionInitialState { stack, memory }
237 }
238 }
239
240 #[macro_rules_attr::apply(test)]
241 fn rust_shadow() {
242 let inner_function = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
243 ShadowedFunction::new(All::new(inner_function)).test();
244 }
245
246 #[macro_rules_attr::apply(test)]
247 fn all_lt_test() {
248 const TWO_POW_31: u64 = 1u64 << 31;
249 let rawcode = RawCode::new(
250 triton_asm!(
251 less_than_2_pow_31:
252 push 2147483648 swap 1
254 lt
255 return
256 ),
257 DataType::Bfe,
258 DataType::Bool,
259 );
260 let snippet = All::new(InnerFunction::RawCode(rawcode));
261 let mut memory = HashMap::new();
262
263 rust_shadowing_helper_functions::list::list_insert(
265 BFieldElement::new(42),
266 (0..30).map(BFieldElement::new).collect_vec(),
267 &mut memory,
268 );
269 let input_stack = [empty_stack(), vec![BFieldElement::new(42)]].concat();
270 let expected_end_stack_true = [empty_stack(), vec![BFieldElement::one()]].concat();
271 let shadowed_snippet = ShadowedFunction::new(snippet);
272 let mut nondeterminism = NonDeterminism::default().with_ram(memory);
273 test_rust_equivalence_given_complete_state(
274 &shadowed_snippet,
275 &input_stack,
276 &[],
277 &nondeterminism,
278 &None,
279 Some(&expected_end_stack_true),
280 );
281
282 rust_shadowing_helper_functions::list::list_insert(
284 BFieldElement::new(42),
285 (0..30)
286 .map(|x| BFieldElement::new(x + TWO_POW_31 - 20))
287 .collect_vec(),
288 &mut nondeterminism.ram,
289 );
290 let expected_end_stack_false = [empty_stack(), vec![BFieldElement::zero()]].concat();
291 test_rust_equivalence_given_complete_state(
292 &shadowed_snippet,
293 &input_stack,
294 &[],
295 &nondeterminism,
296 &None,
297 Some(&expected_end_stack_false),
298 );
299 }
300
301 #[macro_rules_attr::apply(test)]
302 fn test_with_raw_function_lsb_on_bfe() {
303 let rawcode = RawCode::new(
304 triton_asm!(
305 lsb_bfe:
306 split push 2 swap 1 div_mod swap 2 pop 2 return
313 ),
314 DataType::Bfe,
315 DataType::Bool,
316 );
317 let snippet = All::new(InnerFunction::RawCode(rawcode));
318 ShadowedFunction::new(snippet).test();
319 }
320
321 #[macro_rules_attr::apply(test)]
322 fn test_with_raw_function_eq_42() {
323 let raw_code = RawCode::new(
324 triton_asm!(
325 eq_42:
326 push 42
327 eq
328 return
329 ),
330 DataType::U32,
331 DataType::Bool,
332 );
333 let snippet = All::new(InnerFunction::RawCode(raw_code));
334 ShadowedFunction::new(snippet).test();
335 }
336
337 #[macro_rules_attr::apply(test)]
338 fn test_with_raw_function_lsb_on_xfe() {
339 let rawcode = RawCode::new(
340 triton_asm!(
341 lsb_xfe:
342 split push 2 swap 1 div_mod swap 4 pop 4 return
349 ),
350 DataType::Xfe,
351 DataType::Bool,
352 );
353 let snippet = All::new(InnerFunction::RawCode(rawcode));
354 ShadowedFunction::new(snippet).test();
355 }
356
357 #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
359 pub(super) struct TestHashXFieldElementLsb;
360
361 impl BasicSnippet for TestHashXFieldElementLsb {
362 fn parameters(&self) -> Vec<(DataType, String)> {
363 vec![(DataType::Xfe, "element".to_string())]
364 }
365
366 fn return_values(&self) -> Vec<(DataType, String)> {
367 vec![(DataType::Bool, "bool".to_string())]
368 }
369
370 fn entrypoint(&self) -> String {
371 "test_hash_xfield_element_lsb".to_string()
372 }
373
374 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
375 let entrypoint = self.entrypoint();
376 let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));
377 triton_asm!(
378 {entrypoint}:
381 push 0
385 push 0
386 call {unused_import}
387 pop 1
388
389 push 0
390 push 0
391 push 0
392 push 0
393 push 0
394 push 0
395 push 1 pick 9
397 pick 9
398 pick 9 sponge_init
401 sponge_absorb
402 sponge_squeeze
403 split
406 push 2
407 place 1
408 div_mod place 11
411 pop 5
412 pop 5
413 pop 1 return
416 )
417 }
418 }
419}
420
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElementLsb;
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks [`All`] with a hash-based `BasicSnippet` predicate over `Xfe` elements.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = All::new(InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb)));
        ShadowedFunction::new(snippet).bench();
    }
}