rust_bigint/
big_gmp.rs

1//! Implementation of provided traits for GMP (exposed by default)
2
3/*
4    GNU Multiple Precision Arithmetic Library (GMP) support
5    based on MIT-licensed https://github.com/KZen-networks/curv/blob/master/src/arithmetic/big_gmp.rs
6*/
7
8use super::traits::{
9    BitManipulation, ConvertFrom, Converter, Modulo, NumberTests, Samplable, ZeroizeBN, EGCD,
10};
11use super::BigInt;
12use super::HexError;
13use getrandom::getrandom;
14use gmp::mpz::Mpz;
15
16use std::borrow::Borrow;
17use std::sync::atomic;
18
19impl ZeroizeBN for Mpz {
20    fn zeroize_bn(&mut self) {
21        drop(self);
22        atomic::fence(atomic::Ordering::SeqCst);
23        atomic::compiler_fence(atomic::Ordering::SeqCst);
24    }
25}
26
27impl Converter for Mpz {
28    fn to_vec(value: &Mpz) -> Vec<u8> {
29        let bytes: Vec<u8> = value.borrow().into();
30        bytes
31    }
32
33    fn to_hex(&self) -> String {
34        self.to_str_radix(16)
35    }
36
37    fn from_hex(value: &str) -> Result<Mpz, HexError> {
38        BigInt::from_str_radix(value, 16)
39    }
40
41    fn from_bytes(bytes: &[u8]) -> Mpz {
42        BigInt::from(bytes)
43    }
44
45    fn to_bytes(&self) -> Vec<u8> {
46        self.into()
47    }
48}
49
50impl Modulo for Mpz {
51    fn mod_pow(base: &Self, exponent: &Self, modulus: &Self) -> Self {
52        base.powm(exponent, modulus)
53    }
54
55    fn mod_mul(a: &Self, b: &Self, modulus: &Self) -> Self {
56        (a.mod_floor(modulus) * b.mod_floor(modulus)).mod_floor(modulus)
57    }
58
59    fn mod_sub(a: &Self, b: &Self, modulus: &Self) -> Self {
60        let a_m = a.mod_floor(modulus);
61        let b_m = b.mod_floor(modulus);
62
63        let sub_op = a_m - b_m + modulus;
64        sub_op.mod_floor(modulus)
65    }
66
67    fn mod_add(a: &Self, b: &Self, modulus: &Self) -> Self {
68        (a.mod_floor(modulus) + b.mod_floor(modulus)).mod_floor(modulus)
69    }
70
71    fn mod_inv(a: &Self, modulus: &Self) -> Self {
72        a.invert(modulus).unwrap()
73    }
74}
75
76impl Samplable for Mpz {
77    fn sample_below(upper: &Self) -> Self {
78        assert!(*upper > Mpz::zero());
79
80        let bits = upper.bit_length();
81        loop {
82            let n = Self::sample(bits);
83            if n < *upper {
84                return n;
85            }
86        }
87    }
88
89    fn sample_range(lower: &Self, upper: &Self) -> Self {
90        assert!(upper > lower);
91        lower + Self::sample_below(&(upper - lower))
92    }
93
94    fn strict_sample_range(lower: &Self, upper: &Self) -> Self {
95        assert!(upper > lower);
96        loop {
97            let n = lower + Self::sample_below(&(upper - lower));
98            if n > *lower && n < *upper {
99                return n;
100            }
101        }
102    }
103
104    fn sample(bit_size: usize) -> Self {
105        let bytes = (bit_size - 1) / 8 + 1;
106        let mut buf: Vec<u8> = vec![0; bytes];
107        getrandom(&mut buf).unwrap();
108        Self::from(&*buf) >> (bytes * 8 - bit_size)
109    }
110
111    fn strict_sample(bit_size: usize) -> Self {
112        loop {
113            let n = Self::sample(bit_size);
114            if n.bit_length() == bit_size {
115                return n;
116            }
117        }
118    }
119}
120
121impl NumberTests for Mpz {
122    fn is_zero(me: &Self) -> bool {
123        me.is_zero()
124    }
125    fn is_even(me: &Self) -> bool {
126        me.is_multiple_of(&Mpz::from(2))
127    }
128    fn is_negative(me: &Self) -> bool {
129        *me < Mpz::from(0)
130    }
131    fn bits(me: &Self) -> usize {
132        me.bit_length()
133    }
134}
135
136impl EGCD for Mpz {
137    fn egcd(a: &Self, b: &Self) -> (Self, Self, Self) {
138        a.gcdext(b)
139    }
140}
141
142impl BitManipulation for Mpz {
143    fn set_bit(self: &mut Self, bit: usize, bit_val: bool) {
144        if bit_val {
145            self.setbit(bit);
146        } else {
147            self.clrbit(bit);
148        }
149    }
150
151    fn test_bit(self: &Self, bit: usize) -> bool {
152        self.tstbit(bit)
153    }
154}
155
156impl ConvertFrom<Mpz> for u64 {
157    fn _from(x: &Mpz) -> u64 {
158        let opt_x: Option<u64> = x.into();
159        opt_x.unwrap()
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::Converter;
    use super::Modulo;
    use super::Mpz;
    use super::Samplable;

    use std::cmp;

    // Sampling below a non-positive bound must panic.
    #[test]
    #[should_panic]
    fn sample_below_zero_test() {
        Mpz::sample_below(&Mpz::from(-1));
    }

    // Every sample must stay strictly below the upper bound.
    #[test]
    fn sample_below_test() {
        let upper_bound = Mpz::from(10);

        for _ in 1..100 {
            let r = Mpz::sample_below(&upper_bound);
            assert!(r < upper_bound);
        }
    }

    // A range whose upper bound is below its lower bound must panic.
    #[test]
    #[should_panic]
    fn invalid_range_test() {
        Mpz::sample_range(&Mpz::from(10), &Mpz::from(9));
    }

    // Samples must fall in the half-open range [lower, upper).
    #[test]
    fn sample_range_test() {
        let upper_bound = Mpz::from(10);
        let lower_bound = Mpz::from(5);

        for _ in 1..100 {
            let r = Mpz::sample_range(&lower_bound, &upper_bound);
            assert!(r < upper_bound && r >= lower_bound);
        }
    }

    // Samples must fall strictly inside the open range (lower, upper).
    #[test]
    fn strict_sample_range_test() {
        let len = 249;

        for _ in 1..100 {
            let a = Mpz::sample(len);
            let b = Mpz::sample(len);
            // Order the two random endpoints without cloning them.
            let lower_bound = cmp::min(&a, &b);
            let upper_bound = cmp::max(&a, &b);

            let r = Mpz::strict_sample_range(lower_bound, upper_bound);
            // The strict variant excludes both endpoints.
            assert!(r < *upper_bound && r > *lower_bound);
        }
    }

    // Strict samples must have exactly the requested bit length.
    #[test]
    fn strict_sample_test() {
        let len = 249;

        for _ in 1..100 {
            let a = Mpz::strict_sample(len);
            assert_eq!(a.bit_length(), len);
        }
    }

    // mod_sub: a - b mod n where a - b > 0
    #[test]
    fn test_mod_sub_modulo() {
        let a = Mpz::from(10);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_sub(&a, &b, &modulo));
    }

    // mod_sub: a - b mod n where a - b < 0
    #[test]
    fn test_mod_sub_negative_modulo() {
        let a = Mpz::from(5);
        let b = Mpz::from(10);
        let modulo = Mpz::from(3);
        let res = Mpz::from(1);
        assert_eq!(res, Mpz::mod_sub(&a, &b, &modulo));
    }

    // mod_mul: 4 * 5 mod 3 == 2
    #[test]
    fn test_mod_mul() {
        let a = Mpz::from(4);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_mul(&a, &b, &modulo));
    }

    // mod_pow: 2^3 mod 3 == 2
    #[test]
    fn test_mod_pow() {
        let a = Mpz::from(2);
        let b = Mpz::from(3);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_pow(&a, &b, &modulo));
    }

    // to_hex: 11 renders as "b"
    #[test]
    fn test_to_hex() {
        let b = Mpz::from(11);
        assert_eq!("b", b.to_hex());
    }

    // from_hex must invert to_hex
    #[test]
    fn test_from_hex() {
        let a = Mpz::from(11);
        assert_eq!(Mpz::from_hex(&a.to_hex()).unwrap(), a);
    }
}
279}