Skip to main content

tasm_lib/arithmetic/bfe/
primitive_root_of_unity.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4use twenty_first::math::traits::PrimitiveRootOfUnity as PRU;
5
6use crate::prelude::*;
7use crate::traits::basic_snippet::Reviewer;
8use crate::traits::basic_snippet::SignOffFingerprint;
9
/// Look up the primitive root of unity for a given order.
///
/// The order is taken from the stack and replaced by the corresponding root.
///
/// ### Pre-conditions
///
/// - the order is a correctly [encoded](BFieldCodec) `u64`
/// - the order is non-zero
/// - the order is a power of two
/// - the order does not exceed 2^32
///
/// ### Post-conditions
///
/// - the returned element is a primitive root of unity of the requested
///   order in the field with [`BFieldElement::P`] elements
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PrimitiveRootOfUnity;
25
impl BasicSnippet for PrimitiveRootOfUnity {
    /// The single argument is the order, a `u64`. The stack comments in
    /// [`code`](Self::code) show it as the two words `order_hi order_lo`, with
    /// `order_lo` on top.
    fn parameters(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U64, "order".to_owned())]
    }

    /// The single return value is the root of unity as one field element.
    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Bfe, "root_of_unity".to_string())]
    }

    /// Label under which the snippet's code is emitted.
    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_bfe_primitive_root_of_unity".to_string()
    }

    /// Emit the assembly for the lookup.
    ///
    /// The code is a linear chain of "compare and conditionally push" steps,
    /// one per supported power of two, followed by a check that some step
    /// actually pushed a root. Illegal inputs make the VM crash at one of the
    /// three `assert`s (error ids 140, 141, 142).
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        // Root of order 2^pow, evaluated while generating the code and
        // interpolated into the `push` instructions below as a constant. The
        // `unwrap` fires at code-generation time if no such root is known.
        let root_of_pow = |pow: u64| BFieldElement::primitive_root_of_unity(1 << pow).unwrap();

        triton_asm!(
            {self.entrypoint()}:
            // _ order_hi order_lo

            /* Assert correct encoding of the input. `order_hi` is checked later. */

            // Split `order_lo` and require the high half to be 0, i.e., that
            // `order_lo` fits in a u32. Fails with error id 142 otherwise.
            dup 0
            split
            pop 1
            push 0
            eq
            assert error_id 142

            /* check if order is 2^32, i.e., (order_hi, order_lo) == (1, 0) */

            dup 1
            push 1
            eq
            // _ order_hi order_lo (order_hi == 1)

            dup 1
            push 0
            eq
            mul
            // _ order_hi order_lo (order_hi == 1 && order_lo == 0)

            skiz
                push {root_of_pow(32)}
            // _ order_hi order_lo [root]

            /* At this point, `st1` *must* be zero:
             * if order == 2^32:      _ 1 0        root
             * any other legal order: _ 0 order_lo
             */

            // Fails with error id 140 if `st1` is not zero.
            dup 1
            push 0
            eq
            assert error_id 140

            /* Now we only have to check `order_lo`. We can ignore `order_hi` as we've
             * verified that it's 0 in case the order was not 2^32.
             * Furthermore, the primitive root of order 2^32 is not itself a legal order
             * of some other primitive root.
             */

            // Each line compares the top of the stack against one power of two
            // and pushes the matching root on equality. Once a root has been
            // pushed, the top of the stack is that root, which matches none of
            // the remaining comparisons (see the assumption spelled out below).
            dup 0 push 1             eq skiz push {root_of_pow(0)}
            dup 0 push {1_u32 << 1}  eq skiz push {root_of_pow(1)}
            dup 0 push {1_u32 << 2}  eq skiz push {root_of_pow(2)}
            dup 0 push {1_u32 << 3}  eq skiz push {root_of_pow(3)}
            dup 0 push {1_u32 << 4}  eq skiz push {root_of_pow(4)}
            dup 0 push {1_u32 << 5}  eq skiz push {root_of_pow(5)}
            dup 0 push {1_u32 << 6}  eq skiz push {root_of_pow(6)}
            dup 0 push {1_u32 << 7}  eq skiz push {root_of_pow(7)}
            dup 0 push {1_u32 << 8}  eq skiz push {root_of_pow(8)}
            dup 0 push {1_u32 << 9}  eq skiz push {root_of_pow(9)}
            dup 0 push {1_u32 << 10} eq skiz push {root_of_pow(10)}
            dup 0 push {1_u32 << 11} eq skiz push {root_of_pow(11)}
            dup 0 push {1_u32 << 12} eq skiz push {root_of_pow(12)}
            dup 0 push {1_u32 << 13} eq skiz push {root_of_pow(13)}
            dup 0 push {1_u32 << 14} eq skiz push {root_of_pow(14)}
            dup 0 push {1_u32 << 15} eq skiz push {root_of_pow(15)}
            dup 0 push {1_u32 << 16} eq skiz push {root_of_pow(16)}
            dup 0 push {1_u32 << 17} eq skiz push {root_of_pow(17)}
            dup 0 push {1_u32 << 18} eq skiz push {root_of_pow(18)}
            dup 0 push {1_u32 << 19} eq skiz push {root_of_pow(19)}
            dup 0 push {1_u32 << 20} eq skiz push {root_of_pow(20)}
            dup 0 push {1_u32 << 21} eq skiz push {root_of_pow(21)}
            dup 0 push {1_u32 << 22} eq skiz push {root_of_pow(22)}
            dup 0 push {1_u32 << 23} eq skiz push {root_of_pow(23)}
            dup 0 push {1_u32 << 24} eq skiz push {root_of_pow(24)}
            dup 0 push {1_u32 << 25} eq skiz push {root_of_pow(25)}
            dup 0 push {1_u32 << 26} eq skiz push {root_of_pow(26)}
            dup 0 push {1_u32 << 27} eq skiz push {root_of_pow(27)}
            dup 0 push {1_u32 << 28} eq skiz push {root_of_pow(28)}
            dup 0 push {1_u32 << 29} eq skiz push {root_of_pow(29)}
            dup 0 push {1_u32 << 30} eq skiz push {root_of_pow(30)}
            dup 0 push {1_u32 << 31} eq skiz push {root_of_pow(31)}

            /* Since all roots happen to be either 1 or larger than `u32::MAX`, we can
             * test if the top element is a root or not. If this assumption
             * were to change, VM execution would crash here, and tests would
             * catch that.
             */

            // stack if result found:     _ order_hi order_lo root
            // stack if result not found: _ order_hi order_lo

            dup 0
            push 1
            eq
            // Result found:     _ order_hi order_lo root (root == 1)
            // Result not found: _ order_hi order_lo (order_lo == 1)
            //      If order_lo is 1, a primitive root exists, i.e., a result was found. This
            //      contradicts this case's assumption. Therefore, order_lo cannot be 1 here,
            //      and the stack is:
            //                   _ order_hi order_lo 0

            dup 1
            split
            // Result found:     _ order_hi order_lo root (root == 1) root_hi root_lo
            // Result not found: _ order_hi order_lo 0 0 order_lo

            pop 1
            // Result found:     _ order_hi order_lo root (root == 1) root_hi
            // Result not found: _ order_hi order_lo 0 0

            // `push 0 eq` twice turns the top element into the bit (top != 0).
            push 0
            eq
            push 0
            eq
            // Result found:     _ order_hi order_lo root (root == 1) (root_hi != 0)
            // Result not found: _ order_hi order_lo 0 (0 != 0)
            //                                         ~~~~~~~~
            //                                           == 0

            // Logical "or" of the two bits: add them, then normalize to 0 or 1.
            add
            push 0
            eq
            push 0
            eq
            // Result found:     _ order_hi order_lo root ((root == 1) || (root_hi != 0))
            // Result not found: _ order_hi order_lo (0 || 0)
            //                                       ~~~~~~~~
            //                                         == 0

            // Fails with error id 141 if none of the comparisons above pushed a root.
            assert error_id 141
            // Result found:     _ order_hi order_lo root
            // Result not found: VM crashed

            // Move the root below the two order words, then drop those words.
            place 2
            pop 2
            // _ root

            return
        )
    }

    /// Reviewer sign-offs recorded for this snippet.
    ///
    // NOTE(review): the fingerprint presumably is tied to the instructions
    // produced by `code`; confirm before changing any instruction above.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0x7e7f78606c82b7d1.into());

        sign_offs
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::empty_stack;
    use crate::test_prelude::*;

    impl Closure for PrimitiveRootOfUnity {
        type Args = u64;

        /// Reference behavior: pop the order, push the matching primitive
        /// root of unity. Fails for order 0 and for any order the field
        /// element type has no root for.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let order = pop_encodable::<Self::Args>(stack)?;
            if order == 0 {
                // no root of order 0 exists
                return Err(RustShadowError::Other);
            }
            let root_of_unity =
                BFieldElement::primitive_root_of_unity(order).ok_or(RustShadowError::Other)?;
            stack.push(root_of_unity);

            Ok(())
        }

        /// Random legal orders are 2^0 through 2^32. The exponent range
        /// starts at 0 so that the legal order 1 is exercised, too.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => 1_u64 << 10,
                Some(BenchmarkCase::WorstCase) => 1 << 32,
                None => 1 << StdRng::from_seed(seed).random_range(0..=32),
            }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn primitive_root_of_order_2_pow_32_is_not_a_legal_order() {
        let root = BFieldElement::primitive_root_of_unity(1 << 32).unwrap();

        // this assumption is made in the snippet
        assert!(BFieldElement::primitive_root_of_unity(root.value()).is_none());
    }

    #[macro_rules_attr::apply(test)]
    fn all_primitive_roots_are_either_1_or_larger_than_u32_max() {
        // Exponent 0 (order 1) is included: it is a legal order with its own
        // branch in the snippet, and the one the `root == 1` clause is for.
        for pow in 0..=32 {
            let root = BFieldElement::primitive_root_of_unity(1 << pow)
                .unwrap()
                .value();

            // this assumption is made in the snippet
            assert!(root == 1 || root > u64::from(u32::MAX));
        }
    }

    #[macro_rules_attr::apply(test)]
    fn primitive_root_of_unity_pbt() {
        ShadowedClosure::new(PrimitiveRootOfUnity).test()
    }

    /// Every legal order, from 2^0 up to and including 2^32, yields the
    /// expected root and leaves nothing else behind on the stack.
    #[macro_rules_attr::apply(test)]
    fn primitive_root_of_unity_unit_test() {
        for log2_order in 0..=32 {
            let order = 1u64 << log2_order;
            let mut init_stack = empty_stack();
            init_stack.extend(order.encode().iter().rev());

            let expected = BFieldElement::primitive_root_of_unity(order).unwrap();
            let expected_final_stack = [empty_stack(), vec![expected]].concat();
            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(PrimitiveRootOfUnity),
                &init_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }

    /// Illegal orders must crash the VM with error id 140 or 141.
    #[macro_rules_attr::apply(test)]
    fn primitive_root_negative_test() {
        let small_non_powers_of_two = (0_u64..100).filter(|x| !x.is_power_of_two());
        let just_below_2_pow_32 = (1_u64..50).map(|x| (1 << 32) - x);
        // `order_hi == 1` but `order_lo != 0`: looks like 2^32 in the high
        // word only and must still be rejected.
        let just_above_2_pow_32 = (1_u64..50).map(|x| (1 << 32) + x);
        let too_large_powers_of_two = (33..64).map(|x| 1_u64 << x);

        for order in small_non_powers_of_two
            .chain(just_below_2_pow_32)
            .chain(just_above_2_pow_32)
            .chain(too_large_powers_of_two)
        {
            // Test output is shown only on failure; this identifies the culprit.
            eprintln!("testing illegal order {order}");
            let mut init_stack = empty_stack();
            init_stack.extend(order.encode().iter().rev());

            test_assertion_failure(
                &ShadowedClosure::new(PrimitiveRootOfUnity),
                InitVmState::with_stack(init_stack),
                &[140, 141],
            );
        }
    }

    /// An `order_lo` word that does not fit in a u32 must crash the VM with
    /// error id 142, even if its low 32 bits look like a legal order.
    #[macro_rules_attr::apply(proptest)]
    fn triton_vm_crashes_if_order_lo_is_not_u32(
        #[strategy(1_u8..=32)] log_2_order: u8,
        // The noise must be non-zero: with `noise == 0` the input stays a
        // perfectly legal order and the VM (correctly) does not crash.
        #[strategy(1..=u32::MAX)]
        #[map(u64::from)]
        noise: u64,
    ) {
        let [mut order_lo, order_hi] = (1_u64 << log_2_order).encode()[..] else {
            unreachable!()
        };
        order_lo += bfe!(noise << 32);
        prop_assume!((order_lo.value() >> 32) == noise); // no finite-field wrap-around shenanigans

        test_assertion_failure(
            &ShadowedClosure::new(PrimitiveRootOfUnity),
            InitVmState::with_stack([empty_stack(), vec![order_hi, order_lo]].concat()),
            &[142],
        );
    }
}
313
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Run the snippet's benchmark through its shadowed-closure wrapper.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(PrimitiveRootOfUnity);
        shadowed_snippet.bench()
    }
}
323}