snarkvm_circuit_types_integers/
pow_checked.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17
18impl<E: Environment, I: IntegerType, M: Magnitude> Pow<Integer<E, M>> for Integer<E, I> {
19    type Output = Integer<E, I>;
20
21    /// Returns the `power` of `self` to the power of `other`.
22    #[inline]
23    fn pow(self, other: Integer<E, M>) -> Self::Output {
24        self.pow_checked(&other)
25    }
26}
27
28impl<E: Environment, I: IntegerType, M: Magnitude> Pow<&Integer<E, M>> for Integer<E, I> {
29    type Output = Integer<E, I>;
30
31    /// Returns the `power` of `self` to the power of `other`.
32    #[inline]
33    fn pow(self, other: &Integer<E, M>) -> Self::Output {
34        self.pow_checked(other)
35    }
36}
37
38impl<E: Environment, I: IntegerType, M: Magnitude> PowChecked<Integer<E, M>> for Integer<E, I> {
39    type Output = Self;
40
41    /// Returns the `power` of `self` to the power of `other`.
42    #[inline]
43    fn pow_checked(&self, other: &Integer<E, M>) -> Self::Output {
44        // Determine the variable mode.
45        if self.is_constant() && other.is_constant() {
46            // Compute the result and return the new constant.
47            // This cast is safe since `Magnitude`s can only be `u8`, `u16`, or `u32`.
48            match self.eject_value().checked_pow(&other.eject_value().to_u32().unwrap()) {
49                Some(value) => Integer::new(Mode::Constant, console::Integer::new(value)),
50                None => E::halt("Integer overflow on exponentiation of two constants"),
51            }
52        } else {
53            let mut result = Self::one();
54
55            // TODO (@pranav) In each step, we check that we have not overflowed,
56            //  yet we know that in the first step, we do not need to check and
57            //  in general we do not need to check for overflow until we have found
58            //  the second bit that has been set. Optimize.
59            for bit in other.bits_le.iter().rev() {
60                result = result.mul_checked(&result);
61
62                let result_times_self = if I::is_signed() {
63                    // Multiply the absolute value of `self` and `other` in the base field.
64                    // Note: it is safe to use `abs_wrapped` since we want `Integer::MIN` to be interpreted as an unsigned number.
65                    let (product, overflow) = Self::mul_with_flags(&(&result).abs_wrapped(), &self.abs_wrapped());
66
67                    // If the product should be positive, then it cannot exceed the signed maximum.
68                    let operands_same_sign = &result.msb().is_equal(self.msb());
69                    let positive_product_overflows = operands_same_sign & product.msb();
70
71                    // If the product should be negative, then it cannot exceed the absolute value of the signed minimum.
72                    let negative_product_underflows = {
73                        let lower_product_bits_nonzero = product.bits_le[..(I::BITS as usize - 1)]
74                            .iter()
75                            .fold(Boolean::constant(false), |a, b| a | b);
76                        let negative_product_lt_or_eq_signed_min =
77                            !product.msb() | (product.msb() & !lower_product_bits_nonzero);
78                        !operands_same_sign & !negative_product_lt_or_eq_signed_min
79                    };
80
81                    let overflow = overflow | positive_product_overflows | negative_product_underflows;
82                    E::assert_eq(overflow & bit, E::zero());
83
84                    // Return the product of `self` and `other` with the appropriate sign.
85                    Self::ternary(operands_same_sign, &product, &(!&product).add_wrapped(&Self::one()))
86                } else {
87                    let (product, overflow) = Self::mul_with_flags(&result, self);
88
89                    // For unsigned multiplication, check that the overflow flag is not set.
90                    E::assert_eq(overflow & bit, E::zero());
91
92                    // Return the product of `self` and `other`.
93                    product
94                };
95
96                result = Self::ternary(bit, &result_times_self, &result);
97            }
98            result
99        }
100    }
101}
102
103impl<E: Environment, I: IntegerType> Integer<E, I> {
104    /// Multiply the integer bits of `this` and `that`, returning a flag indicating whether the product overflowed.
105    /// This method assumes that the `this` and `that` are both positive.
106    #[inline]
107    fn mul_with_flags(this: &Integer<E, I>, that: &Integer<E, I>) -> (Integer<E, I>, Boolean<E>) {
108        // Case 1 - 2 integers fit in 1 field element (u8, u16, u32, u64, i8, i16, i32, i64).
109        if 2 * I::BITS < (E::BaseField::size_in_bits() - 1) as u64 {
110            // Instead of multiplying the bits of `self` and `other`, witness the integer product.
111            let product: Integer<E, I> = witness!(|this, that| this.mul_wrapped(&that));
112
113            // Check that the computed product is not equal to witnessed product, in the base field.
114            // Note: The multiplication is safe as the field twice as large as the maximum integer type supported.
115            let computed_product = this.to_field() * that.to_field();
116            let witnessed_product = product.to_field();
117            let flag = computed_product.is_not_equal(&witnessed_product);
118
119            // Return the product of `self` and `other` and the overflow flag.
120            (product, flag)
121        }
122        // Case 2 - 1.5 integers fit in 1 field element (u128, i128).
123        else if (I::BITS + I::BITS / 2) < (E::BaseField::size_in_bits() - 1) as u64 {
124            // Use Karatsuba multiplication to compute the product of `self` and `other` and the carry bits.
125            let (product, z_1_upper_bits, z2) = Self::karatsuba_multiply(this, that);
126            // Reconstruct the upper bits of z_1 in the field.
127            let z_1_upper_field = Field::from_bits_le(&z_1_upper_bits);
128            // Compute whether the sum of z_1_field and z_2 is zero.
129            let z_1_upper_field_plus_z_2 = &z_1_upper_field + &z2;
130            let flag = z_1_upper_field_plus_z_2.is_not_equal(&Field::zero());
131
132            // Return the product of `self` and `other` and the overflow flag.
133            (product, flag)
134        } else {
135            E::halt(format!("Multiplication of integers of size {} is not supported", I::BITS))
136        }
137    }
138}
139
140impl<E: Environment, I: IntegerType, M: Magnitude> Metrics<dyn PowChecked<Integer<E, M>, Output = Integer<E, I>>>
141    for Integer<E, I>
142{
143    type Case = (Mode, Mode, bool, bool);
144
145    fn count(case: &Self::Case) -> Count {
146        match (case.0, case.1) {
147            (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
148            (Mode::Constant, _) | (_, Mode::Constant) => {
149                let mul_count = count!(Integer<E, I>, MulWrapped<Integer<E, I>, Output=Integer<E, I>>, case);
150                (2 * M::BITS * mul_count) + Count::is(2 * I::BITS, 0, I::BITS, I::BITS)
151            }
152            (_, _) => {
153                let mul_count = count!(Integer<E, I>, MulWrapped<Integer<E, I>, Output=Integer<E, I>>, case);
154                (2 * M::BITS * mul_count) + Count::is(2 * I::BITS, 0, I::BITS, I::BITS)
155            }
156        }
157    }
158}
159
160impl<E: Environment, I: IntegerType, M: Magnitude> OutputMode<dyn PowChecked<Integer<E, M>, Output = Integer<E, I>>>
161    for Integer<E, I>
162{
163    type Case = (Mode, CircuitType<Integer<E, M>>);
164
165    fn output_mode(case: &Self::Case) -> Mode {
166        match (case.0, (case.1.mode(), &case.1)) {
167            (Mode::Constant, (Mode::Constant, _)) => Mode::Constant,
168            (Mode::Constant, (mode, _)) => match mode {
169                Mode::Constant => Mode::Constant,
170                _ => Mode::Private,
171            },
172            (_, (Mode::Constant, case)) => match case {
173                // Determine if the constant is all zeros.
174                CircuitType::Constant(constant) => match constant.eject_value().is_zero() {
175                    true => Mode::Constant,
176                    false => Mode::Private,
177                },
178                _ => E::halt("The constant is required for the output mode of `pow_wrapped` with a constant."),
179            },
180            (_, _) => Mode::Private,
181        }
182    }
183}
184
185#[cfg(test)]
186mod tests {
187    use super::*;
188    use crate::test_utilities::*;
189    use snarkvm_circuit_environment::Circuit;
190
191    use std::{ops::RangeInclusive, panic::RefUnwindSafe};
192
193    // Lowered to 4; we run (~5 * ITERATIONS) cases for most tests.
194    const ITERATIONS: u64 = 4;
195
196    fn check_pow<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(
197        name: &str,
198        first: console::Integer<<Circuit as Environment>::Network, I>,
199        second: console::Integer<<Circuit as Environment>::Network, M>,
200        mode_a: Mode,
201        mode_b: Mode,
202    ) {
203        let a = Integer::<Circuit, I>::new(mode_a, first);
204        let b = Integer::<Circuit, M>::new(mode_b, second);
205        match first.checked_pow(&second.to_u32().unwrap()) {
206            Some(expected) => Circuit::scope(name, || {
207                let candidate = a.pow_checked(&b);
208                assert_eq!(expected, *candidate.eject_value());
209                assert_eq!(console::Integer::new(expected), candidate.eject_value());
210                // assert_count!(PowChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b));
211                // assert_output_mode!(PowChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, CircuitType::from(&b)), candidate);
212                assert!(Circuit::is_satisfied_in_scope(), "(is_satisfied_in_scope)");
213            }),
214            None => {
215                match (mode_a, mode_b) {
216                    (Mode::Constant, Mode::Constant) => check_operation_halts(&a, &b, Integer::pow_checked),
217                    _ => Circuit::scope(name, || {
218                        let _candidate = a.pow_checked(&b);
219                        // assert_count_fails!(PowChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b));A
220                        assert!(!Circuit::is_satisfied_in_scope(), "(!is_satisfied_in_scope)");
221                    }),
222                }
223            }
224        }
225        Circuit::reset();
226    }
227
228    fn run_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode) {
229        let mut rng = TestRng::default();
230
231        for i in 0..ITERATIONS {
232            let first = Uniform::rand(&mut rng);
233            let second = Uniform::rand(&mut rng);
234
235            let name = format!("Pow: {mode_a} ** {mode_b} {i}");
236            check_pow::<I, M>(&name, first, second, mode_a, mode_b);
237
238            let name = format!("Pow Zero: {mode_a} ** {mode_b} {i}");
239            check_pow::<I, M>(&name, first, console::Integer::zero(), mode_a, mode_b);
240
241            let name = format!("Pow One: {mode_a} ** {mode_b} {i}");
242            check_pow::<I, M>(&name, first, console::Integer::one(), mode_a, mode_b);
243
244            // Check that the square is computed correctly.
245            let name = format!("Square: {mode_a} ** {mode_b} {i}");
246            check_pow::<I, M>(&name, first, console::Integer::one() + console::Integer::one(), mode_a, mode_b);
247
248            // Check that the cube is computed correctly.
249            let name = format!("Cube: {mode_a} ** {mode_b} {i}");
250            check_pow::<I, M>(
251                &name,
252                first,
253                console::Integer::one() + console::Integer::one() + console::Integer::one(),
254                mode_a,
255                mode_b,
256            );
257        }
258
259        // Test corner cases for exponentiation.
260        check_pow::<I, M>("MAX ** MAX", console::Integer::MAX, console::Integer::MAX, mode_a, mode_b);
261        check_pow::<I, M>("MIN ** 0", console::Integer::MIN, console::Integer::zero(), mode_a, mode_b);
262        check_pow::<I, M>("MAX ** 0", console::Integer::MAX, console::Integer::zero(), mode_a, mode_b);
263        check_pow::<I, M>("MIN ** 1", console::Integer::MIN, console::Integer::one(), mode_a, mode_b);
264        check_pow::<I, M>("MAX ** 1", console::Integer::MAX, console::Integer::one(), mode_a, mode_b);
265    }
266
267    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode)
268    where
269        RangeInclusive<I>: Iterator<Item = I>,
270        RangeInclusive<M>: Iterator<Item = M>,
271    {
272        for first in I::MIN..=I::MAX {
273            for second in M::MIN..=M::MAX {
274                let first = console::Integer::<_, I>::new(first);
275                let second = console::Integer::<_, M>::new(second);
276
277                let name = format!("Pow: ({first} ** {second})");
278                check_pow::<I, M>(&name, first, second, mode_a, mode_b);
279            }
280        }
281    }
282
283    test_integer_binary!(run_test, i8, u8, pow);
284    test_integer_binary!(run_test, i8, u16, pow);
285    test_integer_binary!(run_test, i8, u32, pow);
286
287    test_integer_binary!(run_test, i16, u8, pow);
288    test_integer_binary!(run_test, i16, u16, pow);
289    test_integer_binary!(run_test, i16, u32, pow);
290
291    test_integer_binary!(run_test, i32, u8, pow);
292    test_integer_binary!(run_test, i32, u16, pow);
293    test_integer_binary!(run_test, i32, u32, pow);
294
295    test_integer_binary!(run_test, i64, u8, pow);
296    test_integer_binary!(run_test, i64, u16, pow);
297    test_integer_binary!(run_test, i64, u32, pow);
298
299    test_integer_binary!(run_test, i128, u8, pow);
300    test_integer_binary!(run_test, i128, u16, pow);
301    test_integer_binary!(run_test, i128, u32, pow);
302
303    test_integer_binary!(run_test, u8, u8, pow);
304    test_integer_binary!(run_test, u8, u16, pow);
305    test_integer_binary!(run_test, u8, u32, pow);
306
307    test_integer_binary!(run_test, u16, u8, pow);
308    test_integer_binary!(run_test, u16, u16, pow);
309    test_integer_binary!(run_test, u16, u32, pow);
310
311    test_integer_binary!(run_test, u32, u8, pow);
312    test_integer_binary!(run_test, u32, u16, pow);
313    test_integer_binary!(run_test, u32, u32, pow);
314
315    test_integer_binary!(run_test, u64, u8, pow);
316    test_integer_binary!(run_test, u64, u16, pow);
317    test_integer_binary!(run_test, u64, u32, pow);
318
319    test_integer_binary!(run_test, u128, u8, pow);
320    test_integer_binary!(run_test, u128, u16, pow);
321    test_integer_binary!(run_test, u128, u32, pow);
322
323    test_integer_binary!(#[ignore], run_exhaustive_test, u8, u8, pow, exhaustive);
324    test_integer_binary!(#[ignore], run_exhaustive_test, i8, u8, pow, exhaustive);
325}