tasm_lib/verifier/xfe_ntt.rs

1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// Snippet that applies an in-place number-theoretic transform (NTT) to a list of
/// extension-field elements, using a root of unity supplied on the stack.
///
/// This is a unit struct: all behavior lives in its `BasicSnippet` implementation.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct XfeNtt;
7
impl BasicSnippet for XfeNtt {
    /// `x`: pointer to a list of extension-field elements; transformed in place.
    /// `omega`: base-field root of unity used to derive the per-round twiddle factors.
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::List(Box::new(DataType::Xfe)), "x".to_owned()),
            (DataType::Bfe, "omega".to_owned()),
        ]
    }

    /// Nothing is left on the stack; the result is written back into the list `x`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Tuple(vec![]), "result".to_owned())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_verifier_xfe_ntt".to_owned()
    }

    /// Emits an iterative, in-place radix-2 NTT over the list `x`.
    ///
    /// 1. Bit-reversal permutation: for every `k < size`, swap `x[k]` with
    ///    `x[bitreverse(k)]` when `k < bitreverse(k)`.
    /// 2. `log_2_size` butterfly rounds with half-block size `m = 1, 2, 4, …`.
    ///    Each round uses twiddle step `w_m = omega^(size / (2 * m))`. For every
    ///    block start `k` (step `2 * m`) and `j in 0..m`, with `w = w_m^j`:
    ///    `x[k+j], x[k+j+m] <- u + w·v, u - w·v` where `u = x[k+j]`, `v = x[k+j+m]`.
    ///
    /// Stack: `_ *x omega` -> `_`.
    ///
    /// Memory layout assumed by the pointer arithmetic: the list's length sits at `*x`,
    /// and element `i` occupies the three words `*x + 1 + 3i ..= *x + 3 + 3i`.
    ///
    /// Preconditions (not checked by the emitted code):
    /// - `size` should be a power of two. `log_2_size` is computed as
    ///   `32 - leading_zeros(size) - 1`, which equals `log2(size)` only for powers of two.
    /// - NOTE(review): `omega` is presumably meant to be a primitive `size`-th root of
    ///   unity (the tests pass `primitive_root_of_unity(size)`); nothing verifies it.
    /// - NOTE(review): a zero-length list is not special-cased — confirm callers never
    ///   pass one.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let tasm_arithmetic_u32_leadingzeros = library.import(Box::new(
            crate::arithmetic::u32::leading_zeros::LeadingZeros,
        ));
        let tasm_list_length = library.import(Box::new(crate::list::length::Length));
        // Multiplicative inverse of 3 in the base field. It turns the `3 * m` loop bound
        // that the middle loop sets up for the inner loop back into `m`.
        const THREE_INV: BFieldElement = BFieldElement::new(12297829379609722881);

        let while_loop_with_bitreverse = format!("{entrypoint}_while_with_bitreverse");
        let outer_loop = format!("{entrypoint}_while_outer");
        let middle_loop = format!("{entrypoint}_while_middle");
        let inner_loop = format!("{entrypoint}_while_inner");
        let bitreverse_function = format!("{entrypoint}_bitreverse_function");
        let bitreverse_loop = format!("{entrypoint}_bitreverse_while");
        let k_lt_r_then_branch = format!("{entrypoint}_k_lt_r_then_branch");

        triton_asm!(

        {entrypoint}:
            // _ *x omega

            dup 1
            // _ *x omega *x

            call {tasm_list_length}
            // _ *x omega size

            // log_2_size = 32 - leading_zeros(size) - 1
            push 32
            dup 1
            call {tasm_arithmetic_u32_leadingzeros}
            push -1
            mul
            add
            push -1
            add
            push 0
            // _ *x omega size log_2_size k

            // Phase 1: bit-reversal permutation of `x`.
            call {while_loop_with_bitreverse}
            // _ *x omega size log_2_size k      (k == size)

            pop 1
            // _ *x omega size log_2_size

            // Phase 2: butterfly rounds, starting with half-block size m = 1.
            push 1
            // _ *x omega size log_2_size m

            push 0
            // _ *x omega size log_2_size m outer_count

            call {outer_loop}
            // _ *x omega size log_2_size m outer_count      (outer_count == log_2_size)

            pop 5
            pop 1
            // _

            return

        // Subroutines:

        // Bit-reversal loop. Each iteration moves the low bit of `n` into `r`
        // (r' = 2r + (n & 1), n' = n / 2). After `l` iterations, `r` holds the
        // low `l` bits of the original `n` in reversed order.
        // Invariant: _ n l r i   (i = bits processed so far; returns when i == l)
        {bitreverse_loop}:
            dup 0
            dup 3
            eq
            skiz
            return
            // _ n l r i

            swap 1
            push 2
            mul
            // _ n l i (r * 2)

            dup 3
            // _ n l i (r * 2) n

            push 1
            and
            // _ n l i (r * 2) (n & 1)

            // (r * 2) is even, so the xor below equals (r * 2) + (n & 1)
            // and the and-term is always 0; their sum is therefore (r * 2) + (n & 1).
            dup 1
            dup 1
            // _ n l i (r * 2) (n & 1) (r * 2) (n & 1)
            xor
            // _ n l i (r * 2) (n & 1) ((r * 2) ^ (n & 1))

            swap 2
            // _ n l i ((r * 2) ^ (n & 1)) (n & 1) (r * 2)

            and
            // _ n l i ((r * 2) ^ (n & 1)) ((n & 1) & (r * 2))

            add
            // _ n l i (((r * 2) ^ (n & 1)) + ((n & 1) & (r * 2)))
            // _ n l i r'

            swap 1
            // _ n l r' i

            push 2
            dup 4
            // _ n l r' i 2 n

            div_mod
            pop 1
            // _ n l r' i (n / 2)

            swap 4
            pop 1
            // _ (n / 2) l r' i

            push 1
            add
            // _ (n / 2) l r' (i + 1)

            recurse

        // Computes rk = bitreverse(k) over `log_2_size` bits.
        {bitreverse_function}:
            // _ *x omega size log_2_size k n l      (n = k, l = log_2_size)

            push 0
            push 0
            // _ *x omega size log_2_size k n l r i  (r = 0, i = 0)

            call {bitreverse_loop}
            // _ *x omega size log_2_size k n' l rk l

            pop 1
            swap 2
            pop 2
            // _ *x omega size log_2_size k rk

            return

        // Swaps list elements x[k] and x[rk]. The caller only calls this when k < rk,
        // so each pair is swapped exactly once. Returns with *x where rk used to be;
        // the caller pops that slot either way.
        {k_lt_r_then_branch}:
            // _ *x omega size log_2_size k rk

            dup 5
            // _ *x omega size log_2_size k rk *x

            dup 0
            // _ *x omega size log_2_size k rk *x *x

            swap 2
            // _ *x omega size log_2_size k *x *x rk

            push 3
            mul
            push 3
            add
            add
            // _ *x omega size log_2_size k *x (*x[rk] + 2)      (last word of x[rk])

            read_mem 3
            // _ *x omega size log_2_size k *x [x[rk]] (*x[rk] - 1)

            push 1
            add
            // _ *x omega size log_2_size k *x [x[rk]] *x[rk]

            dup 5
            // _ *x omega size log_2_size k *x [x[rk]] *x[rk] k

            push 3
            mul
            push 3
            add
            // _ *x omega size log_2_size k *x [x[rk]] *x[rk] k_offset

            dup 5
            add
            // _ *x omega size log_2_size k *x [x[rk]] *x[rk] (*x[k] + 2)

            read_mem 3
            // _ *x omega size log_2_size k *x [x[rk]] *x[rk] [x[k]] (*x[k] - 1)

            push 1
            add
            // _ *x omega size log_2_size k *x [x[rk]] *x[rk] [x[k]] *x[k]

            swap 4
            // _ *x omega size log_2_size k *x [x[rk]] *x[k] [x[k]] *x[rk]

            write_mem 3
            pop 1
            // _ *x omega size log_2_size k *x [x[rk]] *x[k]      (old x[k] now at *x[rk])

            write_mem 3
            // _ *x omega size log_2_size k *x (*x[k] + 3)       (old x[rk] now at *x[k])

            pop 1
            // _ *x omega size log_2_size k *x

            return

        // Phase 1 loop: for k in 0..size, swap x[k] and x[bitreverse(k)] when k < rk.
        // Returns when k == size.
        {while_loop_with_bitreverse}:
            // _ *x omega size log_2_size k

            dup 0
            dup 3
            eq
            skiz
            return
            // _ *x omega size log_2_size k

            dup 0
            dup 2
            // _ *x omega size log_2_size k k log_2_size

            call {bitreverse_function}
            // _ *x omega size log_2_size k rk

            dup 0
            dup 2
            // _ *x omega size log_2_size k rk rk k

            lt
            // _ *x omega size log_2_size k rk (k < rk)

            skiz
            call {k_lt_r_then_branch}
            // _ *x omega size log_2_size k (rk|*x)

            pop 1
            // _ *x omega size log_2_size k

            push 1
            add
            // _ *x omega size log_2_size (k+1)

            recurse

        // Innermost loop (the hot path). For j in 0..m, one butterfly on
        // (x[k+j], x[k+j+m]):
        //   u = x[k+j],  v' = w * x[k+j+m]
        //   x[k+j] <- u + v',  x[k+j+m] <- u - v',  w <- w * w_m
        // `j` is not kept on the stack. The loop advances the pointer *x[k+j] by one
        // element (3 words) per iteration and stops when it reaches *x[k+m]. The stack
        // holds 3*m (the word distance from x[k+j] to x[k+j+m]) instead of m, which
        // saves multiplications per iteration.
        {inner_loop}:
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k+j]

            dup 1
            dup 1
            eq
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k + j] (j == m)
            skiz
            return
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k + j]
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx

            dup 0
            push 2
            add
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx *x[k + j]_last_word

            read_mem 3
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k + j]] *x[k + j - 1]_last_word

            dup 10
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k + j]] *x[k + j - 1]_last_word (3*m)

            push 3
            add
            add
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k + j]] *x[k + j + m]_last_word

            read_mem 3
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k+j]] [x[k+j+m]] *x[k+j+m-1]_last_word

            pop 1
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [x[k+j]] [x[k+j+m]]
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u]      [v]

            dup 8
            xb_mul
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] (v * w)
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v']

            // Duplicate [u] [v'] (six words) so both the sum and difference can be formed.
            dup 5
            dup 5
            dup 5
            dup 5
            dup 5
            dup 5
            xx_add
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v'] [u + v']

            dup 9
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v'] [u + v'] *x[k + j]

            write_mem 3
            pop 1
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u] [v']

            push -1
            xb_mul
            xx_add
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u - v']

            dup 3
            dup 10
            add
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx [u - v'] *x[k + j + m]

            write_mem 3
            pop 1
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *xx

            push 3 add
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k + j + 1]

            swap 2
            dup 4
            mul
            swap 2
            // _ *x omega size log_2_size (3*m) outer_count w_m k (w * w_m) *x[k+m] *x[k + j + 1]

            recurse

        // Middle loop: for k in (0..size).step_by(2 * m), run the inner loop on the
        // block starting at k, with the twiddle factor w reset to 1.
        {middle_loop}:
            // _ *x omega size log_2_size m outer_count w_m k

            // Return unless k < size.
            dup 5
            dup 1
            lt
            push 0
            eq
            skiz
            return
            // _ *x omega size log_2_size m outer_count w_m k

            push 1
            // _ *x omega size log_2_size m outer_count w_m k w

            dup 8
            // _ *x omega size log_2_size m outer_count w_m k w *x

            dup 2
            dup 6
            add
            // _ *x omega size log_2_size m outer_count w_m k w *x (k + m)

            push 3
            mul
            add
            push 1
            add
            // _ *x omega size log_2_size m outer_count w_m k w *x[k+m]

            dup 9
            dup 3
            push 3
            mul
            add
            push 1
            add
            // _ *x omega size log_2_size m outer_count w_m k w *x[k+m] *x[k+j]      (j = 0)

            // `m` -> `3 * m` for fewer clock cycles in busy-loop
            swap 6
            push 3
            mul
            swap 6
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k+j]

            call {inner_loop}
            // _ *x omega size log_2_size (3*m) outer_count w_m k w *x[k+m] *x[k+j]      (j == m)

            // Undo `3*` transformation
            // `3 * m` -> `m`
            swap 6
            push {THREE_INV}
            mul
            swap 6

            pop 3
            // _ *x omega size log_2_size m outer_count w_m k

            dup 3
            // _ *x omega size log_2_size m outer_count w_m k m

            push 2
            mul
            // _ *x omega size log_2_size m outer_count w_m k (m * 2)

            add
            // _ *x omega size log_2_size m outer_count w_m (k + (m * 2))

            recurse

        // Outer loop: one butterfly round per iteration, `log_2_size` rounds in total.
        // Round `outer_count` uses half-block size m (1, 2, 4, ...) and
        // twiddle step w_m = omega^(size / (2 * m)).
        {outer_loop}:
            // _ *x omega size log_2_size m outer_count

            dup 0
            dup 3
            eq
            skiz
            return
            // _ *x omega size log_2_size m outer_count

            dup 4
            // _ *x omega size log_2_size m outer_count omega

            dup 4
            // _ *x omega size log_2_size m outer_count omega size

            push 2
            // _ *x omega size log_2_size m outer_count omega size 2

            dup 4
            mul
            // _ *x omega size log_2_size m outer_count omega size (2 * m)

            swap 1
            div_mod
            pop 1
            // _ *x omega size log_2_size m outer_count omega (size / (2 * m))

            swap 1
            pow
            // _ *x omega size log_2_size m outer_count (omega ** (size / (2 * m)))
            // _ *x omega size log_2_size m outer_count w_m

            push 0
            // _ *x omega size log_2_size m outer_count w_m k

            call {middle_loop}
            // _ *x omega size log_2_size m outer_count w_m k

            swap 3
            // _ *x omega size log_2_size k outer_count w_m m

            push 2
            mul
            // _ *x omega size log_2_size k outer_count w_m (m * 2)

            swap 3
            // _ *x omega size log_2_size (m * 2) outer_count w_m k

            pop 2
            // _ *x omega size log_2_size (m * 2) outer_count

            push 1 add
            // _ *x omega size log_2_size (m * 2) (outer_count + 1)

            recurse
        )
    }
}
462
#[cfg(test)]
mod tests {
    use twenty_first::math::ntt::ntt;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::empty_stack;
    use crate::test_helpers::rust_final_state;
    use crate::test_helpers::tasm_final_state;
    use crate::test_helpers::verify_stack_equivalence;
    use crate::test_helpers::verify_stack_growth;
    use crate::test_prelude::*;

    impl Function for XfeNtt {
        /// Rust reference implementation of the snippet.
        ///
        /// Pops `omega` and `*x`, decodes the list at `*x`, transforms it with
        /// `twenty_first`'s `ntt`, and re-encodes it at the same address.
        ///
        /// NOTE(review): the popped `omega` is not used. `ntt` is called without a
        /// root, so this shadow only agrees with the TASM when `omega` equals the root
        /// `ntt` uses internally. `pseudorandom_initial_state` supplies
        /// `primitive_root_of_unity(n)`, presumably that same root.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let _root_of_unity = stack.pop().unwrap();
            let input_pointer = stack.pop().unwrap();

            let mut vector =
                *Vec::<XFieldElement>::decode_from_memory(memory, input_pointer).unwrap();
            ntt(&mut vector);

            encode_to_memory(memory, input_pointer, &vector);
        }

        /// Builds a random list of `n` XFEs at address 100 and pushes `*x` and `omega`.
        ///
        /// `n` is 256 for the common benchmark case and 512 for the worst case.
        /// Otherwise it is a random power of two in `2..=512`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let n = match bench_case {
                Some(BenchmarkCase::CommonCase) => 256,
                Some(BenchmarkCase::WorstCase) => 512,
                None => 1 << rng.random_range(1..=9),
            };
            let vector = (0..n).map(|_| rng.random()).collect::<Vec<XFieldElement>>();

            let mut stack = empty_stack();
            let mut memory = HashMap::new();

            let vector_pointer = BFieldElement::new(100);
            encode_to_memory(&mut memory, vector_pointer, &vector);
            stack.push(vector_pointer);
            stack.push(BFieldElement::primitive_root_of_unity(n as u64).unwrap());

            FunctionInitialState { stack, memory }
        }
    }

    /// Runs the Rust shadow and the TASM snippet on several random inputs. Checks
    /// that they agree on public output, the complete final stack, and the
    /// transformed list in memory.
    #[test]
    fn test() {
        let function = ShadowedFunction::new(XfeNtt);
        let num_states = 5;
        let mut rng = rand::rng();

        for _ in 0..num_states {
            let seed: [u8; 32] = rng.random();
            let FunctionInitialState { stack, memory } =
                XfeNtt.pseudorandom_initial_state(seed, None);
            // `*x` sits directly below `omega`, which is on top of the stack.
            let vector_address = stack[stack.len() - 2];

            let stdin = vec![];

            let init_stack = stack.clone();
            let nondeterminism = NonDeterminism::default().with_ram(memory);

            let rust = rust_final_state(&function, &stack, &stdin, &nondeterminism, &None);

            // run tvm
            let tasm = tasm_final_state(&function, &stack, &stdin, nondeterminism, &None);

            assert_eq!(
                rust.public_output, tasm.public_output,
                "Rust shadowing and VM std out must agree"
            );

            // Compare the complete final stacks, not a hand-sliced prefix.
            verify_stack_equivalence(
                "Rust-shadow",
                &rust.stack,
                "TASM execution",
                &tasm.op_stack.stack,
            );
            verify_stack_growth(&function, &init_stack, &tasm.op_stack.stack);

            // read out the output vectors and test agreement
            let rust_result =
                *Vec::<XFieldElement>::decode_from_memory(&rust.ram, vector_address).unwrap();
            let tasm_result =
                *Vec::<XFieldElement>::decode_from_memory(&tasm.ram, vector_address).unwrap();
            assert_eq!(
                rust_result,
                tasm_result,
                "\nrust: {}\ntasm: {}",
                rust_result.iter().join(" | "),
                tasm_result.iter().join(" | ")
            );
        }
    }
}
574
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks `XfeNtt` through the `ShadowedFunction` harness. The inputs
    /// presumably come from `pseudorandom_initial_state`'s benchmark cases.
    #[test]
    fn benchmark() {
        let snippet = XfeNtt;
        ShadowedFunction::new(snippet).bench();
    }
}