snarkvm_circuit_types_integers/
mul_checked.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use super::*;
17
18impl<E: Environment, I: IntegerType> Mul<Integer<E, I>> for Integer<E, I> {
19    type Output = Self;
20
21    fn mul(self, other: Self) -> Self::Output {
22        self * &other
23    }
24}
25
26impl<E: Environment, I: IntegerType> Mul<Integer<E, I>> for &Integer<E, I> {
27    type Output = Integer<E, I>;
28
29    fn mul(self, other: Integer<E, I>) -> Self::Output {
30        self * &other
31    }
32}
33
34impl<E: Environment, I: IntegerType> Mul<&Integer<E, I>> for Integer<E, I> {
35    type Output = Self;
36
37    fn mul(self, other: &Self) -> Self::Output {
38        &self * other
39    }
40}
41
42impl<E: Environment, I: IntegerType> Mul<&Integer<E, I>> for &Integer<E, I> {
43    type Output = Integer<E, I>;
44
45    fn mul(self, other: &Integer<E, I>) -> Self::Output {
46        let mut output = self.clone();
47        output *= other;
48        output
49    }
50}
51
52impl<E: Environment, I: IntegerType> MulAssign<Integer<E, I>> for Integer<E, I> {
53    fn mul_assign(&mut self, other: Integer<E, I>) {
54        *self *= &other;
55    }
56}
57
58impl<E: Environment, I: IntegerType> MulAssign<&Integer<E, I>> for Integer<E, I> {
59    fn mul_assign(&mut self, other: &Integer<E, I>) {
60        // Stores the product of `self` and `other` in `self`.
61        *self = self.mul_checked(other);
62    }
63}
64
65impl<E: Environment, I: IntegerType> Metrics<dyn Mul<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
66    type Case = (Mode, Mode);
67
68    fn count(case: &Self::Case) -> Count {
69        <Self as Metrics<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::count(case)
70    }
71}
72
73impl<E: Environment, I: IntegerType> OutputMode<dyn Mul<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
74    type Case = (Mode, Mode);
75
76    fn output_mode(case: &Self::Case) -> Mode {
77        <Self as OutputMode<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::output_mode(case)
78    }
79}
80
81impl<E: Environment, I: IntegerType> MulChecked<Self> for Integer<E, I> {
82    type Output = Self;
83
84    #[inline]
85    fn mul_checked(&self, other: &Integer<E, I>) -> Self::Output {
86        // Determine the variable mode.
87        if self.is_constant() && other.is_constant() {
88            // Compute the product and return the new constant.
89            match self.eject_value().checked_mul(&other.eject_value()) {
90                Some(value) => Integer::new(Mode::Constant, console::Integer::new(value)),
91                None => E::halt("Integer overflow on multiplication of two constants"),
92            }
93        } else if I::is_signed() {
94            // Compute the product of `abs(self)` and `abs(other)`, while checking for an overflow.
95            // Note: it is safe to use `abs_wrapped` as we want `Integer::MIN` to be interpreted as an unsigned number.
96            let product = Self::mul_and_check(&self.abs_wrapped(), &other.abs_wrapped());
97
98            // If the product should be positive, then it cannot exceed the signed maximum.
99            let operands_same_sign = &self.msb().is_equal(other.msb());
100            let positive_product_overflows = operands_same_sign & product.msb();
101            E::assert_eq(positive_product_overflows, E::zero());
102
103            // If the product should be negative, then it cannot exceed the absolute value of the signed minimum.
104            let negative_product_underflows = {
105                let lower_product_bits_nonzero =
106                    product.bits_le[..(I::BITS as usize - 1)].iter().fold(Boolean::constant(false), |a, b| a | b);
107                let negative_product_lt_or_eq_signed_min =
108                    !product.msb() | (product.msb() & !lower_product_bits_nonzero);
109                !operands_same_sign & !negative_product_lt_or_eq_signed_min
110            };
111            E::assert_eq(negative_product_underflows, E::zero());
112
113            // Note that the relevant overflow cases are checked independently above.
114            // Return the product of `self` and `other` with the appropriate sign.
115            Self::ternary(operands_same_sign, &product, &Self::zero().sub_wrapped(&product))
116        } else {
117            // Compute the product of `self` and `other`, while checking for an overflow.
118            Self::mul_and_check(self, other)
119        }
120    }
121}
122
123impl<E: Environment, I: IntegerType> Integer<E, I> {
124    /// Multiply the integer bits of `this` and `that`, while checking for an overflow.
125    /// This function assumes that `this` and `that` are non-negative.
126    #[inline]
127    fn mul_and_check(this: &Integer<E, I>, that: &Integer<E, I>) -> Integer<E, I> {
128        // Case 1 - 2 integers fit in 1 field element (u8, u16, u32, u64, i8, i16, i32, i64).
129        if 2 * I::BITS < (E::BaseField::size_in_bits() - 1) as u64 {
130            // Instead of multiplying the bits of `self` and `other`, witness the integer product.
131            let product: Integer<E, I> = witness!(|this, that| this.mul_wrapped(&that));
132
133            // Check that the computed product is equal to witnessed product, in the base field.
134            // Note: The multiplication is safe as the field twice as large as the maximum integer type supported.
135            E::enforce(|| (this.to_field(), that.to_field(), product.to_field()));
136
137            product
138        }
139        // Case 2 - 1.5 integers fit in 1 field element (u128, i128).
140        else if (I::BITS + I::BITS / 2) < (E::BaseField::size_in_bits() - 1) as u64 {
141            // Use Karatsuba multiplication to compute the product of `self` and `other`.
142            let (product, z_1_upper_bits, z2) = Self::karatsuba_multiply(this, that);
143
144            // Check that the upper bits of z1 are zero.
145            Boolean::assert_bits_are_zero(&z_1_upper_bits);
146
147            // Check that `z2` is zero.
148            E::assert_eq(z2, E::zero());
149
150            // Return the product of `self` and `other`.
151            product
152        } else {
153            E::halt(format!("Multiplication of integers of size {} is not supported", I::BITS))
154        }
155    }
156}
157
158impl<E: Environment, I: IntegerType> Integer<E, I> {
159    /// Multiply the integer bits of `this` and `that`, using Karatsuba multiplication.
160    ///
161    /// See this page for reference: https://en.wikipedia.org/wiki/Karatsuba_algorithm.
162    ///
163    /// We follow the naming convention given in the `Basic Step` section of the cited page.
164    /// The output is the product of `this` and `that`, the upper bits of `z1`, and `z2` as a field element.
165    /// This function assumes that 1.5 * I::BITS fits in 1 field element.
166    #[inline]
167    pub(super) fn karatsuba_multiply(
168        this: &Integer<E, I>,
169        that: &Integer<E, I>,
170    ) -> (Integer<E, I>, Vec<Boolean<E>>, Field<E>) {
171        // Perform multiplication by decomposing it into operations on its upper and lower bits.
172        // Here is a picture of the bits involved, placed according to the power-of-two weights, in little endian order:
173        //   x0: <--I::BITS/2-->
174        //   x1:                <--I::BITS/2-->
175        //   y0: <--I::BITS/2-->
176        //   y1:                <--I::BITS/2-->
177        //   z0: <-----------I::BITS---------->
178        //   z1:                <-----------I::BITS+1--------->
179        //   z2:                               <-----------I::BITS---------->
180        //                                     |   overlap    |
181        // The carry bits include:
182        //   - the overlapping bits of z1 and z2
183        //   - the upper bits of z2
184
185        let x_1 = Field::from_bits_le(&this.bits_le[(I::BITS as usize / 2)..]);
186        let x_0 = Field::from_bits_le(&this.bits_le[..(I::BITS as usize / 2)]);
187        let y_1 = Field::from_bits_le(&that.bits_le[(I::BITS as usize / 2)..]);
188        let y_0 = Field::from_bits_le(&that.bits_le[..(I::BITS as usize / 2)]);
189
190        let z_0 = &x_0 * &y_0;
191        let z_2 = &x_1 * &y_1;
192        let z_1 = (&x_1 + &x_0) * (&y_1 + &y_0) - &z_2 - &z_0;
193
194        let mut b_m_bits = vec![Boolean::constant(false); I::BITS as usize / 2];
195        b_m_bits.push(Boolean::constant(true));
196
197        let b_m = Field::from_bits_le(&b_m_bits);
198        let z_0_plus_scaled_z_1 = &z_0 + (&z_1 * &b_m);
199
200        let bits_le = z_0_plus_scaled_z_1.to_lower_bits_le(I::BITS as usize + I::BITS as usize / 2 + 1);
201
202        // Split the integer bits into product bits and the upper bits of `z_1`.
203        let (bits_le, carry) = bits_le.split_at(I::BITS as usize);
204
205        // Return the product of `self` and `other`, along with the carry bits.
206        (Integer::from_bits_le(bits_le), carry.to_vec(), z_2)
207    }
208}
209
210impl<E: Environment, I: IntegerType> Metrics<dyn MulChecked<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
211    type Case = (Mode, Mode);
212
213    fn count(case: &Self::Case) -> Count {
214        // Case 1 - 2 integers fit in 1 field element (u8, u16, u32, u64, i8, i16, i32, i64).
215        if 2 * I::BITS < (E::BaseField::size_in_bits() - 1) as u64 {
216            match I::is_signed() {
217                // Signed case
218                true => match (case.0, case.1) {
219                    (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
220                    (Mode::Constant, _) | (_, Mode::Constant) => {
221                        Count::is(4 * I::BITS, 0, (6 * I::BITS) + 4, (6 * I::BITS) + 9)
222                    }
223                    (_, _) => Count::is(3 * I::BITS, 0, (8 * I::BITS) + 6, (8 * I::BITS) + 12),
224                },
225                // Unsigned case
226                false => match (case.0, case.1) {
227                    (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
228                    (Mode::Constant, _) | (_, Mode::Constant) => Count::is(0, 0, I::BITS, I::BITS + 1),
229                    (_, _) => Count::is(0, 0, I::BITS, I::BITS + 1),
230                },
231            }
232        }
233        // Case 2 - 1.5 integers fit in 1 field element (u128, i128).
234        else if (I::BITS + I::BITS / 2) < (E::BaseField::size_in_bits() - 1) as u64 {
235            match I::is_signed() {
236                // Signed case
237                true => match (case.0, case.1) {
238                    (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
239                    (Mode::Constant, _) | (_, Mode::Constant) => Count::less_than(833, 0, 837, 844),
240                    (_, _) => Count::is(3 * I::BITS, 0, 1098, 1106),
241                },
242                // Unsigned case
243                false => match (case.0, case.1) {
244                    (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
245                    (Mode::Constant, _) | (_, Mode::Constant) => Count::less_than(193, 0, 193, 199),
246                    (_, _) => Count::is(0, 0, 196, 199),
247                },
248            }
249        } else {
250            E::halt(format!("Multiplication of integers of size {} is not supported", I::BITS))
251        }
252    }
253}
254
255impl<E: Environment, I: IntegerType> OutputMode<dyn MulChecked<Integer<E, I>, Output = Integer<E, I>>>
256    for Integer<E, I>
257{
258    type Case = (Mode, Mode);
259
260    fn output_mode(case: &Self::Case) -> Mode {
261        match (case.0, case.1) {
262            (Mode::Constant, Mode::Constant) => Mode::Constant,
263            _ => Mode::Private,
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_environment::Circuit;

    use test_utilities::*;

    use core::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of random operand pairs drawn per mode combination in `run_test`.
    const ITERATIONS: u64 = 32;

    /// Checks `first * second` in-circuit with the given modes against the console result.
    ///
    /// The outcome depends on the console result:
    /// - No overflow: the circuit value and constraint counts must match.
    /// - Overflow with two constants: the operation must halt.
    /// - Overflow in any other mode combination: `assert_count_fails!` is used.
    ///
    /// The circuit is reset afterwards.
    fn check_mul<I: IntegerType + RefUnwindSafe>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, I>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let a = Integer::<Circuit, I>::new(mode_a, first);
        let b = Integer::<Circuit, I>::new(mode_b, second);
        let both_constant = matches!((mode_a, mode_b), (Mode::Constant, Mode::Constant));

        if let Some(expected) = first.checked_mul(&second) {
            // In range: the circuit must reproduce the console product at the advertised cost.
            Circuit::scope(name, || {
                let candidate = a.mul_checked(&b);
                assert_eq!(expected, *candidate.eject_value());
                assert_eq!(console::Integer::new(expected), candidate.eject_value());
                assert_count!(MulChecked(Integer<I>, Integer<I>) => Integer<I>, &(mode_a, mode_b));
                // assert_output_mode!(MulChecked(Integer<I>, Integer<I>) => Integer<I>, &(mode_a, mode_b), candidate);
            });
        } else if both_constant {
            // Overflow between two constants halts during circuit construction.
            check_operation_halts(&a, &b, Integer::mul_checked);
        } else {
            // Overflow involving at least one variable.
            Circuit::scope(name, || {
                let _candidate = a.mul_checked(&b);
                assert_count_fails!(MulChecked(Integer<I>, Integer<I>) => Integer<I>, &(mode_a, mode_b));
            });
        }

        Circuit::reset();
    }

    /// Runs random, boundary, and overflow multiplication cases for one pair of operand modes.
    fn run_test<I: IntegerType + RefUnwindSafe>(mode_a: Mode, mode_b: Mode) {
        // Every case below uses the same pair of operand modes.
        let check = |name: &str, lhs, rhs| check_mul::<I>(name, lhs, rhs, mode_a, mode_b);

        let mut rng = TestRng::default();
        for i in 0..ITERATIONS {
            // TODO (@pranav) Uniform random sampling almost always produces arguments that result in an overflow.
            //  Is there a better method for sampling arguments?
            let first = Uniform::rand(&mut rng);
            let second = Uniform::rand(&mut rng);

            // Each product is also checked with its operands swapped.
            let mul_name = format!("Mul: {mode_a} * {mode_b} {i}");
            check(&mul_name, first, second);
            check(&mul_name, second, first);

            let double_name = format!("Double: {mode_a} * {mode_b} {i}");
            check(&double_name, first, console::Integer::one() + console::Integer::one());
            check(&double_name, console::Integer::one() + console::Integer::one(), first);

            let square_name = format!("Square: {mode_a} * {mode_b} {i}");
            check(&square_name, first, first);
        }

        // Identity and zero against the extreme values, common to signed and unsigned integers.
        check("1 * MAX", console::Integer::one(), console::Integer::MAX);
        check("MAX * 1", console::Integer::MAX, console::Integer::one());
        check("1 * MIN", console::Integer::one(), console::Integer::MIN);
        check("MIN * 1", console::Integer::MIN, console::Integer::one());
        check("0 * MAX", console::Integer::zero(), console::Integer::MAX);
        check("MAX * 0", console::Integer::MAX, console::Integer::zero());
        check("0 * MIN", console::Integer::zero(), console::Integer::MIN);
        check("MIN * 0", console::Integer::MIN, console::Integer::zero());
        check("1 * 1", console::Integer::one(), console::Integer::one());

        // Common overflow cases.
        check("MAX * 2", console::Integer::MAX, console::Integer::one() + console::Integer::one());
        check("2 * MAX", console::Integer::one() + console::Integer::one(), console::Integer::MAX);

        // Additional corner cases for signed integers. Negation is only evaluated for signed types.
        if I::is_signed() {
            check("MAX * -1", console::Integer::MAX, -console::Integer::one());
            check("-1 * MAX", -console::Integer::one(), console::Integer::MAX);
            check("MIN * -1", console::Integer::MIN, -console::Integer::one());
            check("-1 * MIN", -console::Integer::one(), console::Integer::MIN);
            check("MIN * -2", console::Integer::MIN, -console::Integer::one() - console::Integer::one());
            check("-2 * MIN", -console::Integer::one() - console::Integer::one(), console::Integer::MIN);
        }
    }

    /// Checks every operand pair of the (small) integer type `I`.
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe>(mode_a: Mode, mode_b: Mode)
    where
        RangeInclusive<I>: Iterator<Item = I>,
    {
        for lhs in I::MIN..=I::MAX {
            let first = console::Integer::<_, I>::new(lhs);
            for rhs in I::MIN..=I::MAX {
                let second = console::Integer::<_, I>::new(rhs);

                let name = format!("Mul: ({first} * {second})");
                check_mul::<I>(&name, first, second, mode_a, mode_b);
            }
        }
    }

    test_integer_binary!(run_test, i8, times);
    test_integer_binary!(run_test, i16, times);
    test_integer_binary!(run_test, i32, times);
    test_integer_binary!(run_test, i64, times);
    test_integer_binary!(run_test, i128, times);

    test_integer_binary!(run_test, u8, times);
    test_integer_binary!(run_test, u16, times);
    test_integer_binary!(run_test, u32, times);
    test_integer_binary!(run_test, u64, times);
    test_integer_binary!(run_test, u128, times);

    test_integer_binary!(#[ignore], run_exhaustive_test, u8, times, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, times, exhaustive);
}