twenty_first/math/
traits.rs1use std::fmt::Debug;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::ops::Add;
5use std::ops::AddAssign;
6use std::ops::Div;
7use std::ops::Mul;
8use std::ops::MulAssign;
9use std::ops::Neg;
10use std::ops::Sub;
11use std::ops::SubAssign;
12
13use num_traits::ConstOne;
14use num_traits::ConstZero;
15use num_traits::Zero;
16use serde::Serialize;
17use serde::de::DeserializeOwned;
18
/// Enumeration of the cyclic group generated by an element.
pub trait CyclicGroupGenerator: Sized {
    /// Returns the elements of the cyclic group generated by `self`.
    ///
    /// NOTE(review): `max` presumably caps the number of returned elements,
    /// with `None` meaning "no cap" — confirm against the implementations.
    fn get_cyclic_group_elements(&self, max: Option<usize>) -> Vec<Self>;
}
25
26pub trait Inverse
28where
29 Self: Sized + Zero,
30{
31 fn inverse(&self) -> Self;
38
39 fn inverse_or_zero(&self) -> Self {
40 if self.is_zero() {
41 Self::zero()
42 } else {
43 self.inverse()
44 }
45 }
46}
47
/// Construction of primitive roots of unity.
pub trait PrimitiveRootOfUnity: Sized {
    /// Returns a primitive `n`-th root of unity.
    ///
    /// NOTE(review): `None` presumably signals that no primitive `n`-th root
    /// exists for the implementing type — confirm against the implementations.
    fn primitive_root_of_unity(n: u64) -> Option<Self>;
}
54
/// Exponentiation by a `u64` exponent.
pub trait ModPowU64 {
    /// Returns `self` raised to the `exp`-th power, under the implementing
    /// type's own arithmetic (presumably modular, per the name).
    ///
    /// The parameter is named `exp` to match [`ModPowU32::mod_pow_u32`].
    #[must_use]
    fn mod_pow_u64(&self, exp: u64) -> Self;
}
59
/// Exponentiation by a `u32` exponent.
pub trait ModPowU32 {
    /// Returns `self` raised to the `exp`-th power, under the implementing
    /// type's own arithmetic (presumably modular, per the name).
    #[must_use]
    fn mod_pow_u32(&self, exp: u32) -> Self;
}
64
65pub trait FiniteField:
66 Copy
67 + Debug
68 + Display
69 + Eq
70 + Serialize
71 + DeserializeOwned
72 + Hash
73 + ConstZero
74 + ConstOne
75 + Add<Output = Self>
76 + Mul<Output = Self>
77 + Sub<Output = Self>
78 + Div<Output = Self>
79 + Neg<Output = Self>
80 + AddAssign
81 + MulAssign
82 + SubAssign
83 + CyclicGroupGenerator
84 + PrimitiveRootOfUnity
85 + Inverse
86 + ModPowU32
87 + From<u64>
88 + Send
89 + Sync
90{
91 fn batch_inversion(input: Vec<Self>) -> Vec<Self> {
94 let input_length = input.len();
95 if input_length == 0 {
96 return Vec::<Self>::new();
97 }
98
99 let zero = Self::zero();
100 let one = Self::one();
101 let mut scratch: Vec<Self> = vec![zero; input_length];
102 let mut acc = one;
103 scratch[0] = input[0];
104
105 for i in 0..input_length {
106 assert!(!input[i].is_zero(), "Cannot do batch inversion on zero");
107 scratch[i] = acc;
108 acc *= input[i];
109 }
110
111 acc = acc.inverse();
112
113 let mut res = input;
114 for i in (0..input_length).rev() {
115 let tmp = acc * res[i];
116 res[i] = acc * scratch[i];
117 acc = tmp;
118 }
119
120 res
121 }
122
123 #[inline(always)]
124 fn square(self) -> Self {
125 self * self
126 }
127}