1use super::traits::{
9 BitManipulation, ConvertFrom, Converter, Modulo, NumberTests, Samplable, ZeroizeBN, EGCD,
10};
11use super::BigInt;
12use super::HexError;
13use getrandom::getrandom;
14use gmp::mpz::Mpz;
15
16use std::borrow::Borrow;
17use std::sync::atomic;
18
19impl ZeroizeBN for Mpz {
20 fn zeroize_bn(&mut self) {
21 drop(self);
22 atomic::fence(atomic::Ordering::SeqCst);
23 atomic::compiler_fence(atomic::Ordering::SeqCst);
24 }
25}
26
27impl Converter for Mpz {
28 fn to_vec(value: &Mpz) -> Vec<u8> {
29 let bytes: Vec<u8> = value.borrow().into();
30 bytes
31 }
32
33 fn to_hex(&self) -> String {
34 self.to_str_radix(16)
35 }
36
37 fn from_hex(value: &str) -> Result<Mpz, HexError> {
38 BigInt::from_str_radix(value, 16)
39 }
40
41 fn from_bytes(bytes: &[u8]) -> Mpz {
42 BigInt::from(bytes)
43 }
44
45 fn to_bytes(&self) -> Vec<u8> {
46 self.into()
47 }
48}
49
50impl Modulo for Mpz {
51 fn mod_pow(base: &Self, exponent: &Self, modulus: &Self) -> Self {
52 base.powm(exponent, modulus)
53 }
54
55 fn mod_mul(a: &Self, b: &Self, modulus: &Self) -> Self {
56 (a.mod_floor(modulus) * b.mod_floor(modulus)).mod_floor(modulus)
57 }
58
59 fn mod_sub(a: &Self, b: &Self, modulus: &Self) -> Self {
60 let a_m = a.mod_floor(modulus);
61 let b_m = b.mod_floor(modulus);
62
63 let sub_op = a_m - b_m + modulus;
64 sub_op.mod_floor(modulus)
65 }
66
67 fn mod_add(a: &Self, b: &Self, modulus: &Self) -> Self {
68 (a.mod_floor(modulus) + b.mod_floor(modulus)).mod_floor(modulus)
69 }
70
71 fn mod_inv(a: &Self, modulus: &Self) -> Self {
72 a.invert(modulus).unwrap()
73 }
74}
75
76impl Samplable for Mpz {
77 fn sample_below(upper: &Self) -> Self {
78 assert!(*upper > Mpz::zero());
79
80 let bits = upper.bit_length();
81 loop {
82 let n = Self::sample(bits);
83 if n < *upper {
84 return n;
85 }
86 }
87 }
88
89 fn sample_range(lower: &Self, upper: &Self) -> Self {
90 assert!(upper > lower);
91 lower + Self::sample_below(&(upper - lower))
92 }
93
94 fn strict_sample_range(lower: &Self, upper: &Self) -> Self {
95 assert!(upper > lower);
96 loop {
97 let n = lower + Self::sample_below(&(upper - lower));
98 if n > *lower && n < *upper {
99 return n;
100 }
101 }
102 }
103
104 fn sample(bit_size: usize) -> Self {
105 let bytes = (bit_size - 1) / 8 + 1;
106 let mut buf: Vec<u8> = vec![0; bytes];
107 getrandom(&mut buf).unwrap();
108 Self::from(&*buf) >> (bytes * 8 - bit_size)
109 }
110
111 fn strict_sample(bit_size: usize) -> Self {
112 loop {
113 let n = Self::sample(bit_size);
114 if n.bit_length() == bit_size {
115 return n;
116 }
117 }
118 }
119}
120
121impl NumberTests for Mpz {
122 fn is_zero(me: &Self) -> bool {
123 me.is_zero()
124 }
125 fn is_even(me: &Self) -> bool {
126 me.is_multiple_of(&Mpz::from(2))
127 }
128 fn is_negative(me: &Self) -> bool {
129 *me < Mpz::from(0)
130 }
131 fn bits(me: &Self) -> usize {
132 me.bit_length()
133 }
134}
135
136impl EGCD for Mpz {
137 fn egcd(a: &Self, b: &Self) -> (Self, Self, Self) {
138 a.gcdext(b)
139 }
140}
141
142impl BitManipulation for Mpz {
143 fn set_bit(self: &mut Self, bit: usize, bit_val: bool) {
144 if bit_val {
145 self.setbit(bit);
146 } else {
147 self.clrbit(bit);
148 }
149 }
150
151 fn test_bit(self: &Self, bit: usize) -> bool {
152 self.tstbit(bit)
153 }
154}
155
156impl ConvertFrom<Mpz> for u64 {
157 fn _from(x: &Mpz) -> u64 {
158 let opt_x: Option<u64> = x.into();
159 opt_x.unwrap()
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::Converter;
    use super::Modulo;
    use super::Mpz;
    use super::Samplable;

    use std::cmp;

    #[test]
    #[should_panic]
    fn sample_below_zero_test() {
        Mpz::sample_below(&Mpz::from(-1));
    }

    // An upper bound of exactly zero leaves nothing to sample, so it must be rejected as well.
    #[test]
    #[should_panic]
    fn sample_below_exact_zero_test() {
        Mpz::sample_below(&Mpz::from(0));
    }

    #[test]
    fn sample_below_test() {
        let upper_bound = Mpz::from(10);

        for _ in 1..100 {
            let r = Mpz::sample_below(&upper_bound);
            assert!(r < upper_bound);
        }
    }

    #[test]
    #[should_panic]
    fn invalid_range_test() {
        Mpz::sample_range(&Mpz::from(10), &Mpz::from(9));
    }

    #[test]
    fn sample_range_test() {
        let upper_bound = Mpz::from(10);
        let lower_bound = Mpz::from(5);

        for _ in 1..100 {
            let r = Mpz::sample_range(&lower_bound, &upper_bound);
            assert!(r < upper_bound && r >= lower_bound);
        }
    }

    #[test]
    fn strict_sample_range_test() {
        let len = 249;

        for _ in 1..100 {
            let a = Mpz::sample(len);
            let b = Mpz::sample(len);
            let lower_bound = cmp::min(a.clone(), b.clone());
            let upper_bound = cmp::max(a.clone(), b.clone());

            let r = Mpz::strict_sample_range(&lower_bound, &upper_bound);
            // A strict sample excludes both bounds.
            assert!(r < upper_bound && r > lower_bound);
        }
    }

    // With only one integer strictly between the bounds, that integer must always be returned.
    #[test]
    fn strict_sample_range_single_candidate_test() {
        let lower_bound = Mpz::from(5);
        let upper_bound = Mpz::from(7);

        for _ in 1..20 {
            let r = Mpz::strict_sample_range(&lower_bound, &upper_bound);
            assert_eq!(r, Mpz::from(6));
        }
    }

    #[test]
    fn strict_sample_test() {
        let len = 249;

        for _ in 1..100 {
            let a = Mpz::strict_sample(len);
            assert_eq!(a.bit_length(), len);
        }
    }

    #[test]
    fn test_mod_sub_modulo() {
        let a = Mpz::from(10);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_sub(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_sub_negative_modulo() {
        let a = Mpz::from(5);
        let b = Mpz::from(10);
        let modulo = Mpz::from(3);
        let res = Mpz::from(1);
        assert_eq!(res, Mpz::mod_sub(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_add() {
        let a = Mpz::from(4);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(0);
        assert_eq!(res, Mpz::mod_add(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_mul() {
        let a = Mpz::from(4);
        let b = Mpz::from(5);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_mul(&a, &b, &modulo));
    }

    #[test]
    fn test_mod_pow() {
        let a = Mpz::from(2);
        let b = Mpz::from(3);
        let modulo = Mpz::from(3);
        let res = Mpz::from(2);
        assert_eq!(res, Mpz::mod_pow(&a, &b, &modulo));
    }

    // 3 * 5 = 15 = 2 * 7 + 1, so 5 is the inverse of 3 modulo 7.
    #[test]
    fn test_mod_inv() {
        let a = Mpz::from(3);
        let modulo = Mpz::from(7);
        let res = Mpz::from(5);
        assert_eq!(res, Mpz::mod_inv(&a, &modulo));
    }

    #[test]
    fn test_to_hex() {
        let b = Mpz::from(11);
        assert_eq!("b", b.to_hex());
    }

    #[test]
    fn test_from_hex() {
        let a = Mpz::from(11);
        assert_eq!(Mpz::from_hex(&a.to_hex()).unwrap(), a);
    }
}