tasm_lib/arithmetic/bfe/
primitive_root_of_unity.rs1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4use twenty_first::math::traits::PrimitiveRootOfUnity as PRU;
5
6use crate::prelude::*;
7use crate::traits::basic_snippet::Reviewer;
8use crate::traits::basic_snippet::SignOffFingerprint;
9
/// Get a primitive root of unity of a given order.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [order: u64]
/// AFTER:  _ [root_of_unity: BFieldElement]
/// ```
///
/// ### Preconditions
///
/// - `order` is a power of two no larger than 2^32
/// - the low limb of `order` is a u32
///
/// ### Postconditions
///
/// - `root_of_unity` is the primitive root of unity of order `order` that is
///   returned by `BFieldElement::primitive_root_of_unity`
///
/// ### Crashes
///
/// - error 142 if the low limb of `order` is not a u32
/// - error 140 if `order` is larger than 2^32
/// - error 141 if `order` is not a power of two (this includes `order == 0`)
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PrimitiveRootOfUnity;
25
26impl BasicSnippet for PrimitiveRootOfUnity {
27 fn parameters(&self) -> Vec<(DataType, String)> {
28 vec![(DataType::U64, "order".to_owned())]
29 }
30
31 fn return_values(&self) -> Vec<(DataType, String)> {
32 vec![(DataType::Bfe, "root_of_unity".to_string())]
33 }
34
35 fn entrypoint(&self) -> String {
36 "tasmlib_arithmetic_bfe_primitive_root_of_unity".to_string()
37 }
38
39 fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
40 let root_of_pow = |pow: u64| BFieldElement::primitive_root_of_unity(1 << pow).unwrap();
41
42 triton_asm!(
43 {self.entrypoint()}:
44 dup 0
49 split
50 pop 1
51 push 0
52 eq
53 assert error_id 142
54
55 dup 1
58 push 1
59 eq
60 dup 1
63 push 0
64 eq
65 mul
66 skiz
69 push {root_of_pow(32)}
70 dup 1
78 push 0
79 eq
80 assert error_id 140
81
82 dup 0 push 1 eq skiz push {root_of_pow(0)}
89 dup 0 push {1_u32 << 1} eq skiz push {root_of_pow(1)}
90 dup 0 push {1_u32 << 2} eq skiz push {root_of_pow(2)}
91 dup 0 push {1_u32 << 3} eq skiz push {root_of_pow(3)}
92 dup 0 push {1_u32 << 4} eq skiz push {root_of_pow(4)}
93 dup 0 push {1_u32 << 5} eq skiz push {root_of_pow(5)}
94 dup 0 push {1_u32 << 6} eq skiz push {root_of_pow(6)}
95 dup 0 push {1_u32 << 7} eq skiz push {root_of_pow(7)}
96 dup 0 push {1_u32 << 8} eq skiz push {root_of_pow(8)}
97 dup 0 push {1_u32 << 9} eq skiz push {root_of_pow(9)}
98 dup 0 push {1_u32 << 10} eq skiz push {root_of_pow(10)}
99 dup 0 push {1_u32 << 11} eq skiz push {root_of_pow(11)}
100 dup 0 push {1_u32 << 12} eq skiz push {root_of_pow(12)}
101 dup 0 push {1_u32 << 13} eq skiz push {root_of_pow(13)}
102 dup 0 push {1_u32 << 14} eq skiz push {root_of_pow(14)}
103 dup 0 push {1_u32 << 15} eq skiz push {root_of_pow(15)}
104 dup 0 push {1_u32 << 16} eq skiz push {root_of_pow(16)}
105 dup 0 push {1_u32 << 17} eq skiz push {root_of_pow(17)}
106 dup 0 push {1_u32 << 18} eq skiz push {root_of_pow(18)}
107 dup 0 push {1_u32 << 19} eq skiz push {root_of_pow(19)}
108 dup 0 push {1_u32 << 20} eq skiz push {root_of_pow(20)}
109 dup 0 push {1_u32 << 21} eq skiz push {root_of_pow(21)}
110 dup 0 push {1_u32 << 22} eq skiz push {root_of_pow(22)}
111 dup 0 push {1_u32 << 23} eq skiz push {root_of_pow(23)}
112 dup 0 push {1_u32 << 24} eq skiz push {root_of_pow(24)}
113 dup 0 push {1_u32 << 25} eq skiz push {root_of_pow(25)}
114 dup 0 push {1_u32 << 26} eq skiz push {root_of_pow(26)}
115 dup 0 push {1_u32 << 27} eq skiz push {root_of_pow(27)}
116 dup 0 push {1_u32 << 28} eq skiz push {root_of_pow(28)}
117 dup 0 push {1_u32 << 29} eq skiz push {root_of_pow(29)}
118 dup 0 push {1_u32 << 30} eq skiz push {root_of_pow(30)}
119 dup 0 push {1_u32 << 31} eq skiz push {root_of_pow(31)}
120
121 dup 0
131 push 1
132 eq
133 dup 1
141 split
142 pop 1
146 push 0
150 eq
151 push 0
152 eq
153 add
159 push 0
160 eq
161 push 0
162 eq
163 assert error_id 141
169 place 2
173 pop 2
174
175 return
176 )
177 }
178
179 fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
180 let mut sign_offs = HashMap::new();
181 sign_offs.insert(Reviewer("ferdinand"), 0x7e7f78606c82b7d1.into());
182
183 sign_offs
184 }
185}
186
#[cfg(test)]
mod tests {
    use num_traits::Zero;

    use super::*;
    use crate::empty_stack;
    use crate::test_prelude::*;

    impl Closure for PrimitiveRootOfUnity {
        type Args = u64;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let order = pop_encodable::<Self::Args>(stack)?;
            if order.is_zero() {
                // Order 0 is rejected explicitly rather than delegated to the library.
                return Err(RustShadowError::Other);
            }
            let root_of_unity =
                BFieldElement::primitive_root_of_unity(order).ok_or(RustShadowError::Other)?;
            stack.push(root_of_unity);

            Ok(())
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => 1_u64 << 10,
                Some(BenchmarkCase::WorstCase) => 1 << 32,
                None => 1 << StdRng::from_seed(seed).random_range(1..=32),
            }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn primitive_root_of_order_2_pow_32_is_not_a_legal_order() {
        let root = BFieldElement::primitive_root_of_unity(1 << 32).unwrap();

        assert!(BFieldElement::primitive_root_of_unity(root.value()).is_none());
    }

    /// The snippet's lookup chain relies on every root it embeds (orders 2^0
    /// through 2^32) being either 1 or larger than u32::MAX.
    #[macro_rules_attr::apply(test)]
    fn all_primitive_roots_are_either_1_or_larger_than_u32_max() {
        for pow in 0..=32 {
            let root = BFieldElement::primitive_root_of_unity(1 << pow)
                .unwrap()
                .value();

            assert!(root == 1 || root > u64::from(u32::MAX));
        }
    }

    #[macro_rules_attr::apply(test)]
    fn primitive_root_of_unity_pbt() {
        ShadowedClosure::new(PrimitiveRootOfUnity).test()
    }

    /// Exercise every supported order, including the trivial order 1 (= 2^0),
    /// which neither the property test nor the negative test covers.
    #[macro_rules_attr::apply(test)]
    fn primitive_root_of_unity_unit_test() {
        for log2_order in 0..=32 {
            let order = 1u64 << log2_order;
            let mut init_stack = empty_stack();
            init_stack.extend(order.encode().iter().rev());

            let expected = BFieldElement::primitive_root_of_unity(order).unwrap();
            let expected_final_stack = [empty_stack(), vec![expected]].concat();
            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(PrimitiveRootOfUnity),
                &init_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }

    #[macro_rules_attr::apply(test)]
    fn primitive_root_negative_test() {
        let small_non_powers_of_two = (0_u64..100).filter(|x| !x.is_power_of_two());
        let larger_non_powers_of_two = (1_u64..50).map(|x| (1 << 32) - x);
        let too_large_powers_of_two = (33..64).map(|x| 1_u64 << x);

        for order in small_non_powers_of_two
            .chain(larger_non_powers_of_two)
            .chain(too_large_powers_of_two)
        {
            dbg!(order);
            let mut init_stack = empty_stack();
            init_stack.extend(order.encode().iter().rev());

            test_assertion_failure(
                &ShadowedClosure::new(PrimitiveRootOfUnity),
                InitVmState::with_stack(init_stack),
                &[140, 141],
            );
        }
    }

    #[macro_rules_attr::apply(proptest)]
    fn triton_vm_crashes_if_order_lo_is_not_u32(
        #[strategy(1_u8..=32)] log_2_order: u8,
        // Noise must be non-zero: with zero noise `order_lo` stays a valid u32,
        // the snippet succeeds, and the expected crash would not happen.
        #[strategy(1..=u32::MAX)]
        #[map(u64::from)]
        noise: u64,
    ) {
        let [mut order_lo, order_hi] = (1_u64 << log_2_order).encode()[..] else {
            unreachable!()
        };
        order_lo += bfe!(noise << 32);

        // Reduction modulo the field prime may wipe out the injected high bits.
        prop_assume!((order_lo.value() >> 32) == noise);

        test_assertion_failure(
            &ShadowedClosure::new(PrimitiveRootOfUnity),
            InitVmState::with_stack([empty_stack(), vec![order_hi, order_lo]].concat()),
            &[142],
        );
    }
}
313
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Run the benchmark harness for the `PrimitiveRootOfUnity` snippet.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(PrimitiveRootOfUnity);
        shadowed_snippet.bench();
    }
}