1use super::*;
17
18impl<E: Environment, I: IntegerType, M: Magnitude> Pow<Integer<E, M>> for Integer<E, I> {
19 type Output = Integer<E, I>;
20
21 #[inline]
23 fn pow(self, other: Integer<E, M>) -> Self::Output {
24 self.pow_checked(&other)
25 }
26}
27
28impl<E: Environment, I: IntegerType, M: Magnitude> Pow<&Integer<E, M>> for Integer<E, I> {
29 type Output = Integer<E, I>;
30
31 #[inline]
33 fn pow(self, other: &Integer<E, M>) -> Self::Output {
34 self.pow_checked(other)
35 }
36}
37
impl<E: Environment, I: IntegerType, M: Magnitude> PowChecked<Integer<E, M>> for Integer<E, I> {
    type Output = Self;

    /// Returns `self` raised to the power `other`, enforcing that the result does not overflow.
    ///
    /// - If both operands are constants, the power is computed natively with `checked_pow` and
    ///   returned as a constant; if it overflows, the environment halts.
    /// - Otherwise, the power is built in-circuit by square-and-multiply over the bits of `other`.
    ///   Overflow is ruled out by `E::assert_eq` constraints rather than a native check, so an
    ///   overflowing power leaves the circuit unsatisfied instead of halting.
    #[inline]
    fn pow_checked(&self, other: &Integer<E, M>) -> Self::Output {
        if self.is_constant() && other.is_constant() {
            // NOTE(review): this `unwrap` assumes every `Magnitude` exponent fits in a `u32`
            // (the tests only instantiate `u8`, `u16`, and `u32`) — confirm against the trait.
            match self.eject_value().checked_pow(&other.eject_value().to_u32().unwrap()) {
                Some(value) => Integer::new(Mode::Constant, console::Integer::new(value)),
                None => E::halt("Integer overflow on exponentiation of two constants"),
            }
        } else {
            // Accumulator. It stays 1 until the first set exponent bit is reached, so a zero
            // exponent yields 1.
            let mut result = Self::one();

            // Walk the exponent bits in reverse order of `bits_le`, i.e. most-significant first
            // (assuming `bits_le` is little-endian, as its name suggests).
            for bit in other.bits_le.iter().rev() {
                // Square step. This runs unconditionally; being a `mul_checked`, it is expected
                // to reject an overflowing square whatever `bit` is.
                result = result.mul_checked(&result);

                // Candidate for `result * self`; it is only selected (and its overflow only
                // enforced) when `bit` is set.
                let result_times_self = if I::is_signed() {
                    // Multiply the magnitudes as unsigned bit patterns.
                    // NOTE(review): relies on `abs_wrapped` leaving `I::MIN` unchanged so its bits
                    // read as 2^(BITS-1) when interpreted unsigned — confirm.
                    let (product, overflow) = Self::mul_with_flags(&(&result).abs_wrapped(), &self.abs_wrapped());

                    // Equal sign bits mean the true product is non-negative.
                    let operands_same_sign = &result.msb().is_equal(self.msb());
                    // A non-negative product must leave the top bit clear (i.e. be at most `I::MAX`).
                    let positive_product_overflows = operands_same_sign & product.msb();

                    // A negative product's magnitude may be at most 2^(BITS-1), the magnitude of
                    // `I::MIN`: either the top bit is clear, or it is set and all lower bits are clear.
                    let negative_product_underflows = {
                        let lower_product_bits_nonzero = product.bits_le[..(I::BITS as usize - 1)]
                            .iter()
                            .fold(Boolean::constant(false), |a, b| a | b);
                        let negative_product_lt_or_eq_signed_min =
                            !product.msb() | (product.msb() & !lower_product_bits_nonzero);
                        !operands_same_sign & !negative_product_lt_or_eq_signed_min
                    };

                    // Enforce "no overflow" only when this bit actually multiplies by `self`.
                    let overflow = overflow | positive_product_overflows | negative_product_underflows;
                    E::assert_eq(overflow & bit, E::zero());

                    // Restore the sign: keep the magnitude, or negate it in two's complement (`!x + 1`).
                    Self::ternary(operands_same_sign, &product, &(!&product).add_wrapped(&Self::one()))
                } else {
                    let (product, overflow) = Self::mul_with_flags(&result, self);

                    // Enforce "no overflow" only when this bit actually multiplies by `self`.
                    E::assert_eq(overflow & bit, E::zero());

                    product
                };

                // Multiply step: take `result * self` if the exponent bit is set, else keep `result`.
                result = Self::ternary(bit, &result_times_self, &result);
            }
            result
        }
    }
}
102
103impl<E: Environment, I: IntegerType> Integer<E, I> {
104 #[inline]
107 fn mul_with_flags(this: &Integer<E, I>, that: &Integer<E, I>) -> (Integer<E, I>, Boolean<E>) {
108 if 2 * I::BITS < (E::BaseField::size_in_bits() - 1) as u64 {
110 let product: Integer<E, I> = witness!(|this, that| this.mul_wrapped(&that));
112
113 let computed_product = this.to_field() * that.to_field();
116 let witnessed_product = product.to_field();
117 let flag = computed_product.is_not_equal(&witnessed_product);
118
119 (product, flag)
121 }
122 else if (I::BITS + I::BITS / 2) < (E::BaseField::size_in_bits() - 1) as u64 {
124 let (product, z_1_upper_bits, z2) = Self::karatsuba_multiply(this, that);
126 let z_1_upper_field = Field::from_bits_le(&z_1_upper_bits);
128 let z_1_upper_field_plus_z_2 = &z_1_upper_field + &z2;
130 let flag = z_1_upper_field_plus_z_2.is_not_equal(&Field::zero());
131
132 (product, flag)
134 } else {
135 E::halt(format!("Multiplication of integers of size {} is not supported", I::BITS))
136 }
137 }
138}
139
140impl<E: Environment, I: IntegerType, M: Magnitude> Metrics<dyn PowChecked<Integer<E, M>, Output = Integer<E, I>>>
141 for Integer<E, I>
142{
143 type Case = (Mode, Mode, bool, bool);
144
145 fn count(case: &Self::Case) -> Count {
146 match (case.0, case.1) {
147 (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
148 (Mode::Constant, _) | (_, Mode::Constant) => {
149 let mul_count = count!(Integer<E, I>, MulWrapped<Integer<E, I>, Output=Integer<E, I>>, case);
150 (2 * M::BITS * mul_count) + Count::is(2 * I::BITS, 0, I::BITS, I::BITS)
151 }
152 (_, _) => {
153 let mul_count = count!(Integer<E, I>, MulWrapped<Integer<E, I>, Output=Integer<E, I>>, case);
154 (2 * M::BITS * mul_count) + Count::is(2 * I::BITS, 0, I::BITS, I::BITS)
155 }
156 }
157 }
158}
159
160impl<E: Environment, I: IntegerType, M: Magnitude> OutputMode<dyn PowChecked<Integer<E, M>, Output = Integer<E, I>>>
161 for Integer<E, I>
162{
163 type Case = (Mode, CircuitType<Integer<E, M>>);
164
165 fn output_mode(case: &Self::Case) -> Mode {
166 match (case.0, (case.1.mode(), &case.1)) {
167 (Mode::Constant, (Mode::Constant, _)) => Mode::Constant,
168 (Mode::Constant, (mode, _)) => match mode {
169 Mode::Constant => Mode::Constant,
170 _ => Mode::Private,
171 },
172 (_, (Mode::Constant, case)) => match case {
173 CircuitType::Constant(constant) => match constant.eject_value().is_zero() {
175 true => Mode::Constant,
176 false => Mode::Private,
177 },
178 _ => E::halt("The constant is required for the output mode of `pow_wrapped` with a constant."),
179 },
180 (_, _) => Mode::Private,
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utilities::*;
    use snarkvm_circuit_environment::Circuit;

    use std::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of random (base, exponent) samples drawn per mode combination.
    const ITERATIONS: u64 = 4;

    /// Checks `first ** second` in the circuit against the native `checked_pow` result.
    ///
    /// - If the native power fits, the circuit value must match and the scope must be satisfied.
    /// - If it overflows and both operands are constant, the operation must halt.
    /// - If it overflows otherwise, the scope must be unsatisfied.
    ///
    /// The circuit is reset afterwards.
    fn check_pow<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, M>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let base = Integer::<Circuit, I>::new(mode_a, first);
        let exponent = Integer::<Circuit, M>::new(mode_b, second);

        if let Some(expected) = first.checked_pow(&second.to_u32().unwrap()) {
            Circuit::scope(name, || {
                let candidate = base.pow_checked(&exponent);
                assert_eq!(expected, *candidate.eject_value());
                assert_eq!(console::Integer::new(expected), candidate.eject_value());
                assert!(Circuit::is_satisfied_in_scope(), "(is_satisfied_in_scope)");
            });
        } else if matches!((mode_a, mode_b), (Mode::Constant, Mode::Constant)) {
            check_operation_halts(&base, &exponent, Integer::pow_checked);
        } else {
            Circuit::scope(name, || {
                let _candidate = base.pow_checked(&exponent);
                assert!(!Circuit::is_satisfied_in_scope(), "(!is_satisfied_in_scope)");
            });
        }

        Circuit::reset();
    }

    /// Exercises random bases against a random exponent and the exponents 0 through 3, then a
    /// handful of `MIN`/`MAX` edge cases.
    fn run_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode) {
        let mut rng = TestRng::default();

        for i in 0..ITERATIONS {
            let first = Uniform::rand(&mut rng);
            let second = Uniform::rand(&mut rng);

            let one: console::Integer<<Circuit as Environment>::Network, M> = console::Integer::one();
            let exponents = [
                ("Pow", second),
                ("Pow Zero", console::Integer::zero()),
                ("Pow One", one),
                ("Square", one + one),
                ("Cube", one + one + one),
            ];

            for (label, exponent) in exponents {
                let name = format!("{label}: {mode_a} ** {mode_b} {i}");
                check_pow::<I, M>(&name, first, exponent, mode_a, mode_b);
            }
        }

        check_pow::<I, M>("MAX ** MAX", console::Integer::MAX, console::Integer::MAX, mode_a, mode_b);
        check_pow::<I, M>("MIN ** 0", console::Integer::MIN, console::Integer::zero(), mode_a, mode_b);
        check_pow::<I, M>("MAX ** 0", console::Integer::MAX, console::Integer::zero(), mode_a, mode_b);
        check_pow::<I, M>("MIN ** 1", console::Integer::MIN, console::Integer::one(), mode_a, mode_b);
        check_pow::<I, M>("MAX ** 1", console::Integer::MAX, console::Integer::one(), mode_a, mode_b);
    }

    /// Checks every (base, exponent) pair over the full ranges of `I` and `M`.
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode)
    where
        RangeInclusive<I>: Iterator<Item = I>,
        RangeInclusive<M>: Iterator<Item = M>,
    {
        for raw_base in I::MIN..=I::MAX {
            let first = console::Integer::<_, I>::new(raw_base);
            for raw_exponent in M::MIN..=M::MAX {
                let second = console::Integer::<_, M>::new(raw_exponent);
                check_pow::<I, M>(&format!("Pow: ({first} ** {second})"), first, second, mode_a, mode_b);
            }
        }
    }

    test_integer_binary!(run_test, i8, u8, pow);
    test_integer_binary!(run_test, i8, u16, pow);
    test_integer_binary!(run_test, i8, u32, pow);

    test_integer_binary!(run_test, i16, u8, pow);
    test_integer_binary!(run_test, i16, u16, pow);
    test_integer_binary!(run_test, i16, u32, pow);

    test_integer_binary!(run_test, i32, u8, pow);
    test_integer_binary!(run_test, i32, u16, pow);
    test_integer_binary!(run_test, i32, u32, pow);

    test_integer_binary!(run_test, i64, u8, pow);
    test_integer_binary!(run_test, i64, u16, pow);
    test_integer_binary!(run_test, i64, u32, pow);

    test_integer_binary!(run_test, i128, u8, pow);
    test_integer_binary!(run_test, i128, u16, pow);
    test_integer_binary!(run_test, i128, u32, pow);

    test_integer_binary!(run_test, u8, u8, pow);
    test_integer_binary!(run_test, u8, u16, pow);
    test_integer_binary!(run_test, u8, u32, pow);

    test_integer_binary!(run_test, u16, u8, pow);
    test_integer_binary!(run_test, u16, u16, pow);
    test_integer_binary!(run_test, u16, u32, pow);

    test_integer_binary!(run_test, u32, u8, pow);
    test_integer_binary!(run_test, u32, u16, pow);
    test_integer_binary!(run_test, u32, u32, pow);

    test_integer_binary!(run_test, u64, u8, pow);
    test_integer_binary!(run_test, u64, u16, pow);
    test_integer_binary!(run_test, u64, u32, pow);

    test_integer_binary!(run_test, u128, u8, pow);
    test_integer_binary!(run_test, u128, u16, pow);
    test_integer_binary!(run_test, u128, u32, pow);

    test_integer_binary!(#[ignore], run_exhaustive_test, u8, u8, pow, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, u8, pow, exhaustive);
}