tasm_lib/verifier/xfe_ntt.rs
1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// Snippet performing an in-place number-theoretic transform (NTT) on a list of
/// extension-field elements, given a root of unity `omega` on the stack.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct XfeNtt;
7
8impl BasicSnippet for XfeNtt {
9 fn inputs(&self) -> Vec<(DataType, String)> {
10 vec![
11 (DataType::List(Box::new(DataType::Xfe)), "x".to_owned()),
12 (DataType::Bfe, "omega".to_owned()),
13 ]
14 }
15
16 fn outputs(&self) -> Vec<(DataType, String)> {
17 vec![(DataType::Tuple(vec![]), "result".to_owned())]
18 }
19
20 fn entrypoint(&self) -> String {
21 "tasmlib_verifier_xfe_ntt".to_owned()
22 }
23
24 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
25 let entrypoint = self.entrypoint();
26 let tasm_arithmetic_u32_leadingzeros = library.import(Box::new(
27 crate::arithmetic::u32::leading_zeros::LeadingZeros,
28 ));
29 let tasm_list_length = library.import(Box::new(crate::list::length::Length));
30 const THREE_INV: BFieldElement = BFieldElement::new(12297829379609722881);
31
32 let while_loop_with_bitreverse = format!("{entrypoint}_while_with_bitreverse");
33 let outer_loop = format!("{entrypoint}_while_outer");
34 let middle_loop = format!("{entrypoint}_while_middle");
35 let inner_loop = format!("{entrypoint}_while_inner");
36 let bitreverse_function = format!("{entrypoint}_bitreverse_function");
37 let bitreverse_loop = format!("{entrypoint}_bitreverse_while");
38 let k_lt_r_then_branch = format!("{entrypoint}_k_lt_r_then_branch");
39 // _binop_Lt__LboolR_bool_74_while_loop
40
41 triton_asm!(
42
43 {entrypoint}:
44 // _ *x omega
45
46 dup 1
47 // _ *x omega *x
48
49 call {tasm_list_length}
50 // _ *x omega size
51
52 push 32
53 dup 1
54 call {tasm_arithmetic_u32_leadingzeros}
55 push -1
56 mul
57 add
58 push -1
59 add
60 push 0
61 // _ *x omega size log_2_size k
62
63 call {while_loop_with_bitreverse}
64 // _ *x omega size log_2_size k
65
66 pop 1
67 // _ *x omega size log_2_size
68
69 push 1
70 // _ *x omega size log_2_size m
71
72 push 0
73 // _ *x omega size log_2_size m outer_count
74
75 call {outer_loop}
76 pop 5
77 pop 1
78
79 return
80
81 // Subroutines:
82
83 // Invariant: n l r i
84 {bitreverse_loop}:
85 dup 0
86 dup 3
87 eq
88 skiz
89 return
90 // _ n l r i
91
92 swap 1
93 push 2
94 mul
95 // _ n l i (r * 2)
96
97 dup 3
98 // _ n l i (r * 2) n
99
100 push 1
101 and
102 // _ n l i (r * 2) (n & 1)
103 dup 1
104 dup 1
105 // _ n l i (r * 2) (n & 1) (r * 2) (n & 1)
106 xor
107 // _ n l i (r * 2) (n & 1) ((r * 2) ^ (n & 1))
108
109 swap 2
110 // _ n l i ((r * 2) ^ (n & 1)) (n & 1) (r * 2)
111
112 and
113 // _ n l i ((r * 2) ^ (n & 1)) ((n & 1) && (r * 2))
114
115 add
116 // _ n l i (((r * 2) ^ (n & 1)) + ((n & 1) && (r * 2)))
117 // _ n l i r'
118
119 swap 1
120 // _ n l r' i
121
122 push 2
123 dup 4
124 // _ n l r' i n
125
126 div_mod
127 pop 1
128 // _ n l r' i (n / 2)
129
130 swap 4
131 pop 1
132 // _ (n / 2) l r' i
133
134 push 1
135 add
136 // _ (n / 2) l r' i'
137
138 recurse
139
140 {bitreverse_function}:
141 // _ *x omega size log_2_size k n l
142
143 push 0
144 push 0
145 call {bitreverse_loop}
146 pop 1
147 swap 2
148 pop 2
149 return
150
151 // _ *x omega size log_2_size k rk
152 {k_lt_r_then_branch}:
153
154 dup 5
155 // _ *x omega size log_2_size k rk *x
156
157 dup 0
158 // _ *x omega size log_2_size k rk *x *x
159
160 swap 2
161 // _ *x omega size log_2_size k *x *x rk
162
163 push 3
164 mul
165 push 3
166 add
167 add
168 // _ *x omega size log_2_size k *x *(x[rk] + 2)
169
170 read_mem 3
171 // _ *x omega size log_2_size k *x [x[rk]] *(x[rk] - 1)
172
173 push 1
174 add
175 // _ *x omega size log_2_size k *x [x[rk]] *x[rk]
176
177 dup 5
178 // _ *x omega size log_2_size k *x [x[rk]] *x[rk] k
179
180 push 3
181 mul
182 push 3
183 add
184 // _ *x omega size log_2_size k *x [x[rk]] *x[rk] k_offset
185
186 dup 5
187 add
188 // _ *x omega size log_2_size k *x [x[rk]] *x[rk] *(x[k] + 2)
189
190 read_mem 3
191 // _ *x omega size log_2_size k *x [x[rk]] *x[rk] [x[k]] *(x[k] - 1)
192
193 push 1
194 add
195 // _ *x omega size log_2_size k *x [x[rk]] *x[rk] [x[k]] *x[k]
196
197 swap 4
198 // _ *x omega size log_2_size k *x [x[rk]] *x[k] [x[k]] *x[rk]
199
200 write_mem 3
201 pop 1
202 // _ *x omega size log_2_size k *x [x[rk]] *x[k]
203
204 write_mem 3
205 // _ *x omega size log_2_size k *x *(x[k] +3)
206
207 pop 1
208 // _ *x omega size log_2_size k *x
209
210 return
211
212 // 1st loop, where `bitreverse` is called
213 {while_loop_with_bitreverse}:
214 // _ *x omega size log_2_size k
215
216 dup 0
217 dup 3
218 eq
219 skiz
220 return
221 // _ *x omega size log_2_size k
222
223 dup 0
224 dup 2
225 // _ *x omega size log_2_size k k log_2_size
226 call {bitreverse_function}
227 // _ *x omega size log_2_size k rk
228
229 dup 0
230 dup 2
231 // _ *x omega size log_2_size k rk rk k
232
233 lt
234 // _ *x omega size log_2_size k rk (k < rk)
235
236 skiz
237 call {k_lt_r_then_branch}
238 // _ *x omega size log_2_size k (rk|*x)
239
240 pop 1
241 // _ *x omega size log_2_size k
242
243 push 1
244 add
245 // _ *x omega size log_2_size (k+1)
246
247 recurse
248
249 // Last while-loop, *inner*, `j != m` <-- The busy-loop!
250 {inner_loop}:
251 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k+j]
252
253 dup 1
254 dup 1
255 eq
256 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k + j] (j == m)
257 skiz
258 return
259 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k + j]
260 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx
261
262 dup 0
263 push 2
264 add
265 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx *x[k + j]_last_word
266
267 read_mem 3
268 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k + j]] *x[k + j - 1]_last_word
269
270 dup 10
271 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k + j]] *x[k + j - 1]_last_word (3*m)
272
273 push 3
274 add
275 add
276 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k + j]] *x[k + j + m]_last_word
277
278 read_mem 3
279 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k+j]] [x[k+j+m]] *x[k+j+m-1]_last_word
280
281 pop 1
282 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k+j]] [x[k+j+m]]
283 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v]
284
285 dup 8
286 xb_mul
287 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] (v * w)
288 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v']
289
290 dup 5
291 dup 5
292 dup 5
293 dup 5
294 dup 5
295 dup 5
296 xx_add
297 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v'] [u + v']
298
299 dup 9
300 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v'] [u + v'] *x[k + j]
301
302 write_mem 3
303 pop 1
304 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v']
305
306 push -1
307 xb_mul
308 xx_add
309 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u - v']
310
311 dup 3
312 dup 10
313 add
314 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u - v'] *x[k + j + m]
315
316 write_mem 3
317 pop 1
318 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx
319
320 push 3 add
321 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k + j + 1]
322
323 swap 2
324 dup 4
325 mul
326 swap 2
327 // _ *x omega size log_2_size (3*m) outer_count w_m k (w * w_m) *x[k+m] *x[k + j + 1]
328
329 recurse
330
331 // Last while-loop middle, k < size
332 {middle_loop}:
333 // _ *x omega size log_2_size m outer_count w_m k
334
335 dup 5
336 dup 1
337 lt
338 push 0
339 eq
340 skiz
341 return
342 // _ *x omega size log_2_size m outer_count w_m k
343
344 push 1
345 // _ *x omega size log_2_size m outer_count w_m k w
346
347 dup 8
348 // _ *x omega size log_2_size m outer_count w_m k w *x
349
350 dup 2
351 dup 6
352 add
353 // _ *x omega size log_2_size m outer_count w_m k w *x (k + m)
354
355 push 3
356 mul
357 add
358 push 1
359 add
360 // _ *x omega size log_2_size m outer_count w_m k w *x[k+m]
361
362 dup 9
363 dup 3
364 push 3
365 mul
366 add
367 push 1
368 add
369 // _ *x omega size log_2_size m outer_count w_m k w *x[k+m] *x[k+j]
370
371 // `m` -> `3 * m` for fewer clock cycles in busy-loop
372 swap 6
373 push 3
374 mul
375 swap 6
376 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k+j]
377
378 call {inner_loop}
379 // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k+j]
380
381 // Undo `3*` transformation
382 // `3 * m` -> `m`
383 swap 6
384 push {THREE_INV}
385 mul
386 swap 6
387
388 pop 3
389 // _ *x omega size log_2_size m outer_count w_m k
390
391 dup 3
392 // _ *x omega size log_2_size m outer_count w_m k m
393
394 push 2
395 mul
396 // _ *x omega size log_2_size m outer_count w_m k (m * 2)
397
398 add
399 // _ *x omega size log_2_size m outer_count w_m (k + (m * 2))
400
401 recurse
402
403 // Last while-loop outer
404 {outer_loop}:
405 // _ *x omega size log_2_size m outer_count
406
407 dup 0
408 dup 3
409 eq
410 skiz
411 return
412 // _ *x omega size log_2_size m outer_count
413
414 dup 4
415 // _ *x omega size log_2_size m outer_count omega
416
417 dup 4
418 // _ *x omega size log_2_size m outer_count omega size
419
420 push 2
421 // _ *x omega size log_2_size m outer_count omega size 2
422
423 dup 4
424 mul
425 // _ *x omega size log_2_size m outer_count omega size (2 * m)
426
427 swap 1
428 div_mod
429 pop 1
430 // _ *x omega size log_2_size m outer_count omega (size / (2 * m))
431
432 swap 1
433 pow
434 // _ *x omega size log_2_size m outer_count (omega ** (size / (2 * m)))
435 // _ *x omega size log_2_size m outer_count w_m
436
437 push 0
438 // _ *x omega size log_2_size m outer_count w_m k
439
440 call {middle_loop}
441 // _ *x omega size log_2_size m outer_count w_m k
442
443 swap 3
444 // _ *x omega size log_2_size k outer_count w_m m
445
446 push 2
447 mul
448 // _ *x omega size log_2_size k outer_count w_m (m * 2)
449
450 swap 3
451 // _ *x omega size log_2_size (m * 2) outer_count w_m k
452
453 pop 2
454 // _ *x omega size log_2_size (m * 2) outer_count
455
456 push 1 add
457
458 recurse
459 )
460 }
461}
462
463#[cfg(test)]
464mod tests {
465 use twenty_first::math::ntt::ntt;
466 use twenty_first::math::traits::PrimitiveRootOfUnity;
467
468 use super::*;
469 use crate::empty_stack;
470 use crate::test_helpers::rust_final_state;
471 use crate::test_helpers::tasm_final_state;
472 use crate::test_helpers::verify_stack_equivalence;
473 use crate::test_helpers::verify_stack_growth;
474 use crate::test_prelude::*;
475
476 impl Function for XfeNtt {
477 fn rust_shadow(
478 &self,
479 stack: &mut Vec<BFieldElement>,
480 memory: &mut HashMap<BFieldElement, BFieldElement>,
481 ) {
482 let _root_of_unity = stack.pop().unwrap();
483 let input_pointer = stack.pop().unwrap();
484
485 let mut vector =
486 *Vec::<XFieldElement>::decode_from_memory(memory, input_pointer).unwrap();
487 ntt(&mut vector);
488
489 encode_to_memory(memory, input_pointer, &vector);
490 }
491
492 fn pseudorandom_initial_state(
493 &self,
494 seed: [u8; 32],
495 bench_case: Option<BenchmarkCase>,
496 ) -> FunctionInitialState {
497 let mut rng = StdRng::from_seed(seed);
498 let n = match bench_case {
499 Some(BenchmarkCase::CommonCase) => 256,
500 Some(BenchmarkCase::WorstCase) => 512,
501 None => 1 << rng.random_range(1..=9),
502 };
503 let vector = (0..n).map(|_| rng.random()).collect::<Vec<XFieldElement>>();
504
505 let mut stack = empty_stack();
506 let mut memory = HashMap::new();
507
508 let vector_pointer = BFieldElement::new(100);
509 encode_to_memory(&mut memory, vector_pointer, &vector);
510 stack.push(vector_pointer);
511 stack.push(BFieldElement::primitive_root_of_unity(n as u64).unwrap());
512
513 FunctionInitialState { stack, memory }
514 }
515 }
516
517 #[test]
518 fn test() {
519 let function = ShadowedFunction::new(XfeNtt);
520 let num_states = 5;
521 let mut rng = rand::rng();
522
523 for _ in 0..num_states {
524 let seed: [u8; 32] = rng.random();
525 let FunctionInitialState { stack, memory } =
526 XfeNtt.pseudorandom_initial_state(seed, None);
527 let vector_address = stack[stack.len() - 2];
528
529 let stdin = vec![];
530
531 let init_stack = stack.to_vec();
532 let nondeterminism = NonDeterminism::default().with_ram(memory);
533
534 let rust = rust_final_state(&function, &stack, &stdin, &nondeterminism, &None);
535
536 // run tvm
537 let tasm = tasm_final_state(&function, &stack, &stdin, nondeterminism, &None);
538
539 assert_eq!(
540 rust.public_output, tasm.public_output,
541 "Rust shadowing and VM std out must agree"
542 );
543
544 let len = 16;
545 verify_stack_equivalence(
546 "Rust-shadow",
547 &rust.stack[0..len - 1],
548 "TASM execution",
549 &tasm.op_stack.stack[0..len - 1],
550 );
551 verify_stack_growth(&function, &init_stack, &tasm.op_stack.stack);
552
553 // read out the output vectors and test agreement
554 let rust_result =
555 *Vec::<XFieldElement>::decode_from_memory(&rust.ram, vector_address).unwrap();
556 let tasm_result =
557 *Vec::<XFieldElement>::decode_from_memory(&tasm.ram, vector_address).unwrap();
558 assert_eq!(
559 rust_result,
560 tasm_result,
561 "\nrust: {}\ntasm: {}",
562 rust_result.iter().join(" | "),
563 tasm_result.iter().join(" | ")
564 );
565
566 println!(
567 "tasm stack: {}",
568 tasm.op_stack.stack.iter().skip(16).join(",")
569 );
570 println!("rust stack: {}", rust.stack.iter().skip(16).join(","));
571 }
572 }
573}
574
575#[cfg(test)]
576mod benches {
577 use super::*;
578 use crate::test_prelude::*;
579
580 #[test]
581 fn benchmark() {
582 ShadowedFunction::new(XfeNtt).bench();
583 }
584}