// rusty_gadgets/field_element.rs

1use std::fmt;
2use std::fmt::Formatter;
3use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Shl, Sub, SubAssign};
4use std::str::FromStr;
5
6use num::bigint::ParseBigIntError;
7use num::bigint::RandBigInt;
8use num::BigUint;
9use num_traits::One;
10use num_traits::Zero;
11use rand::Rng;
12
13#[derive(Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
14pub struct FieldElement {
15    value: BigUint,
16}
17
18impl FieldElement {
19    /// The prime field size.
20    pub fn size() -> BigUint {
21        BigUint::from_str(
22            "21888242871839275222246405745257275088548364400416034343698204186575808495617").unwrap()
23    }
24
25    pub fn max_value() -> BigUint {
26        FieldElement::size() - BigUint::one()
27    }
28
29    /// The number of bits needed to encode every field element.
30    pub fn max_bits() -> usize {
31        FieldElement::max_value().bits()
32    }
33
34    pub fn zero() -> Self {
35        FieldElement::from(0)
36    }
37
38    pub fn one() -> Self {
39        FieldElement::from(1)
40    }
41
42    /// The additive inverse of 1.
43    pub fn neg_one() -> Self {
44        FieldElement::one().multiplicative_inverse()
45    }
46
47    /// Return a random field element, uniformly distributed in [0, size()).
48    pub fn random(rng: &mut impl Rng) -> Self {
49        loop {
50            let r = rng.gen_biguint(FieldElement::max_bits());
51            if r < FieldElement::size() {
52                return FieldElement::from(r);
53            }
54        }
55    }
56
57    pub fn is_zero(&self) -> bool {
58        self.value.is_zero()
59    }
60
61    pub fn is_nonzero(&self) -> bool {
62        !self.is_zero()
63    }
64
65    pub fn is_one(&self) -> bool {
66        self.value.is_one()
67    }
68
69    pub fn multiplicative_inverse(&self) -> FieldElement {
70        assert_ne!(*self, FieldElement::zero(), "Zero does not have a multiplicative inverse");
71        // From Euler's theorem.
72        // TODO: Use a faster method, like the one described in "Fast Modular Reciprocals".
73        // Or just wait for https://github.com/rust-num/num-bigint/issues/60
74        FieldElement::from(self.value.modpow(
75            &(FieldElement::size() - BigUint::from(2u128)),
76            &FieldElement::size()))
77    }
78
79    pub fn integer_division(&self, rhs: FieldElement) -> FieldElement {
80        FieldElement::from(self.value.clone() / rhs.value)
81    }
82
83    pub fn integer_modulus(&self, rhs: FieldElement) -> FieldElement {
84        FieldElement::from(self.value.clone() % rhs.value)
85    }
86
87    /// The number of bits needed to encode this particular field element.
88    pub fn bits(&self) -> usize {
89        self.value.bits()
90    }
91
92    /// Return the i'th least significant bit. So, for example, x.bit(0) returns the least
93    /// significant bit of x.
94    pub fn bit(&self, i: usize) -> bool {
95        ((self.value.clone() >> i) & BigUint::one()).is_one()
96    }
97}
98
99impl From<BigUint> for FieldElement {
100    fn from(value: BigUint) -> FieldElement {
101        assert!(value >= BigUint::zero());
102        assert!(value < FieldElement::size());
103        FieldElement { value }
104    }
105}
106
107impl From<u128> for FieldElement {
108    fn from(value: u128) -> FieldElement {
109        FieldElement { value: BigUint::from(value) }
110    }
111}
112
113impl From<bool> for FieldElement {
114    fn from(value: bool) -> FieldElement {
115        FieldElement::from(value as u128)
116    }
117}
118
119impl FromStr for FieldElement {
120    type Err = ParseBigIntError;
121
122    fn from_str(s: &str) -> Result<Self, Self::Err> {
123        BigUint::from_str(s).map(|n| FieldElement::from(n))
124    }
125}
126
127impl Neg for FieldElement {
128    type Output = FieldElement;
129
130    fn neg(self) -> FieldElement {
131        if self.is_zero() {
132            self
133        } else {
134            FieldElement::from(FieldElement::size() - self.value)
135        }
136    }
137}
138
139impl Add<FieldElement> for FieldElement {
140    type Output = FieldElement;
141
142    fn add(self, rhs: FieldElement) -> FieldElement {
143        FieldElement::from((self.value + rhs.value) % FieldElement::size())
144    }
145}
146
147impl AddAssign for FieldElement {
148    fn add_assign(&mut self, rhs: FieldElement) {
149        *self = self.clone() + rhs;
150    }
151}
152
153impl Sub<FieldElement> for FieldElement {
154    type Output = FieldElement;
155
156    fn sub(self, rhs: FieldElement) -> FieldElement {
157        self + -rhs
158    }
159}
160
161impl SubAssign for FieldElement {
162    fn sub_assign(&mut self, rhs: FieldElement) {
163        *self = self.clone() - rhs;
164    }
165}
166
167impl Mul<FieldElement> for FieldElement {
168    type Output = FieldElement;
169
170    fn mul(self, rhs: FieldElement) -> FieldElement {
171        FieldElement::from((self.value * rhs.value) % FieldElement::size())
172    }
173}
174
175impl Mul<u128> for FieldElement {
176    type Output = FieldElement;
177
178    fn mul(self, rhs: u128) -> FieldElement {
179        self * FieldElement::from(rhs)
180    }
181}
182
183impl MulAssign for FieldElement {
184    fn mul_assign(&mut self, rhs: FieldElement) {
185        *self = self.clone() * rhs;
186    }
187}
188
189impl MulAssign<u128> for FieldElement {
190    fn mul_assign(&mut self, rhs: u128) {
191        *self = self.clone() * rhs;
192    }
193}
194
195impl Div<FieldElement> for FieldElement {
196    type Output = FieldElement;
197
198    fn div(self, rhs: FieldElement) -> FieldElement {
199        self * rhs.multiplicative_inverse()
200    }
201}
202
203impl Shl<usize> for FieldElement {
204    type Output = FieldElement;
205
206    fn shl(self, rhs: usize) -> FieldElement {
207        FieldElement::from(self.value << rhs)
208    }
209}
210
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::field_element::FieldElement;

    /// Parses a decimal literal into a field element, panicking on bad input.
    fn fe(s: &str) -> FieldElement {
        FieldElement::from_str(s).unwrap()
    }

    #[test]
    fn addition() {
        let one_plus_one = FieldElement::one() + FieldElement::one();
        assert_eq!(FieldElement::from(2), one_plus_one);

        let sum = FieldElement::from(13) + FieldElement::from(20);
        assert_eq!(FieldElement::from(33), sum);
    }

    #[test]
    fn addition_overflow() {
        // (p - 2) + 5 wraps around to 3.
        let near_max =
            fe("21888242871839275222246405745257275088548364400416034343698204186575808495615");
        assert_eq!(fe("3"), near_max + fe("5"));
    }

    #[test]
    fn additive_inverse() {
        let p_minus_one =
            fe("21888242871839275222246405745257275088548364400416034343698204186575808495616");
        assert_eq!(p_minus_one, -FieldElement::one());

        let x = FieldElement::from(123);
        assert_eq!(FieldElement::zero(), x.clone() + -x);
    }

    #[test]
    fn multiplication_overflow() {
        let a = fe("1234567890123456789012345678901234567890");
        let expected =
            fe("13869117166973684714533159833916213390696312133829829072325816326144232854527");
        assert_eq!(expected, a.clone() * a);
    }

    #[test]
    fn bits_0() {
        let x = FieldElement::from(0);
        let n: usize = 300;
        let bits: Vec<bool> = (0..n).map(|i| x.bit(i)).collect();
        assert_eq!(vec![false; n], bits);
    }

    #[test]
    fn bits_19() {
        // 19 = 0b10011, listed least significant bit first.
        let x = FieldElement::from(19);
        let expected = [true, true, false, false, true, false, false];
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(want, x.bit(i), "bit {}", i);
        }
    }
}
284
285impl fmt::Display for FieldElement {
286    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
287        // As a UX optimization, display "-1" for the largest field element.
288        let s = if self.is_one() {
289            "-1".to_string()
290        } else {
291            self.value.to_string()
292        };
293        write!(f, "{}", s)
294    }
295}