1use super::*;
17
18impl<E: Environment, I: IntegerType> Mul<Integer<E, I>> for Integer<E, I> {
19 type Output = Self;
20
21 fn mul(self, other: Self) -> Self::Output {
22 self * &other
23 }
24}
25
26impl<E: Environment, I: IntegerType> Mul<Integer<E, I>> for &Integer<E, I> {
27 type Output = Integer<E, I>;
28
29 fn mul(self, other: Integer<E, I>) -> Self::Output {
30 self * &other
31 }
32}
33
34impl<E: Environment, I: IntegerType> Mul<&Integer<E, I>> for Integer<E, I> {
35 type Output = Self;
36
37 fn mul(self, other: &Self) -> Self::Output {
38 &self * other
39 }
40}
41
42impl<E: Environment, I: IntegerType> Mul<&Integer<E, I>> for &Integer<E, I> {
43 type Output = Integer<E, I>;
44
45 fn mul(self, other: &Integer<E, I>) -> Self::Output {
46 let mut output = self.clone();
47 output *= other;
48 output
49 }
50}
51
52impl<E: Environment, I: IntegerType> MulAssign<Integer<E, I>> for Integer<E, I> {
53 fn mul_assign(&mut self, other: Integer<E, I>) {
54 *self *= &other;
55 }
56}
57
58impl<E: Environment, I: IntegerType> MulAssign<&Integer<E, I>> for Integer<E, I> {
59 fn mul_assign(&mut self, other: &Integer<E, I>) {
60 *self = self.mul_checked(other);
62 }
63}
64
65impl<E: Environment, I: IntegerType> Metrics<dyn Mul<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
66 type Case = (Mode, Mode);
67
68 fn count(case: &Self::Case) -> Count {
69 <Self as Metrics<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::count(case)
70 }
71}
72
73impl<E: Environment, I: IntegerType> OutputMode<dyn Mul<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
74 type Case = (Mode, Mode);
75
76 fn output_mode(case: &Self::Case) -> Mode {
77 <Self as OutputMode<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::output_mode(case)
78 }
79}
80
impl<E: Environment, I: IntegerType> MulChecked<Self> for Integer<E, I> {
    type Output = Self;

    /// Returns the product of `self` and `other`, enforcing that it fits in the integer type `I`.
    ///
    /// - If both operands are constants, the product is computed natively with `checked_mul`
    ///   and returned as a new constant. On overflow this halts with
    ///   "Integer overflow on multiplication of two constants".
    /// - If `I` is signed:
    ///   1. the absolute values are multiplied via `mul_and_check`;
    ///   2. two assertions constrain the unsigned magnitude to the range allowed by the
    ///      result's sign;
    ///   3. the sign is reapplied.
    /// - Otherwise (unsigned), the operands are multiplied directly via `mul_and_check`.
    #[inline]
    fn mul_checked(&self, other: &Integer<E, I>) -> Self::Output {
        if self.is_constant() && other.is_constant() {
            // Both operands are known: compute the product natively and return it as a constant.
            match self.eject_value().checked_mul(&other.eject_value()) {
                Some(value) => Integer::new(Mode::Constant, console::Integer::new(value)),
                None => E::halt("Integer overflow on multiplication of two constants"),
            }
        } else if I::is_signed() {
            // Multiply the magnitudes. `abs_wrapped` wraps, so `I::MIN` is expected to stay `I::MIN`.
            // Its bit pattern then reads as the unsigned magnitude 2^(BITS - 1); the
            // negative-range check below relies on that interpretation.
            let product = Self::mul_and_check(&self.abs_wrapped(), &other.abs_wrapped());

            // The operands have the same sign iff their sign bits (MSBs) are equal.
            // In that case the true product is non-negative.
            let operands_same_sign = &self.msb().is_equal(other.msb());
            // A non-negative result must be at most `I::MAX`, i.e. the magnitude's MSB must be clear.
            let positive_product_overflows = operands_same_sign & product.msb();
            E::assert_eq(positive_product_overflows, E::zero());

            // A negative result's magnitude may be at most 2^(BITS - 1) (the magnitude of `I::MIN`).
            // That holds when either the MSB is clear, or the MSB is set and every lower bit is zero.
            let negative_product_underflows = {
                // OR of all bits strictly below the MSB.
                let lower_product_bits_nonzero =
                    product.bits_le[..(I::BITS as usize - 1)].iter().fold(Boolean::constant(false), |a, b| a | b);
                let negative_product_lt_or_eq_signed_min =
                    !product.msb() | (product.msb() & !lower_product_bits_nonzero);
                !operands_same_sign & !negative_product_lt_or_eq_signed_min
            };
            E::assert_eq(negative_product_underflows, E::zero());

            // Reapply the sign. Keep the magnitude when the signs agree; otherwise negate it as
            // `0 - product` (wrapping, so a magnitude of 2^(BITS - 1) becomes `I::MIN`).
            Self::ternary(operands_same_sign, &product, &Self::zero().sub_wrapped(&product))
        } else {
            // Unsigned: `mul_and_check` already enforces that the product fits in `I::BITS` bits.
            Self::mul_and_check(self, other)
        }
    }
}
122
impl<E: Environment, I: IntegerType> Integer<E, I> {
    /// Multiplies the bits of `this` and `that`, read as unsigned values, in the base field.
    /// Enforces that the product fits in `I::BITS` bits and returns it as an `Integer<E, I>`.
    ///
    /// The strategy depends on how many integer widths fit below the field's capacity
    /// (`size_in_bits - 1`):
    /// - Case 1: `2 * I::BITS` fits. The product is witnessed and a single multiplication
    ///   constraint ties it to the operands.
    /// - Case 2: `1.5 * I::BITS` fits. One step of Karatsuba multiplication is used, and the
    ///   carry bits and high-order term are asserted to be zero.
    /// - Otherwise this halts with "Multiplication of integers of size {BITS} is not supported".
    #[inline]
    fn mul_and_check(this: &Integer<E, I>, that: &Integer<E, I>) -> Integer<E, I> {
        if 2 * I::BITS < (E::BaseField::size_in_bits() - 1) as u64 {
            // Witness the product, computed natively with wrapping multiplication.
            let product: Integer<E, I> = witness!(|this, that| this.mul_wrapped(&that));

            // Enforce the multiplication constraint `this * that = product` over the base field.
            // The full product of two `I::BITS`-bit values cannot wrap the field here, and
            // `product` has only `I::BITS` bits. So this is satisfiable only if the true product
            // equals the witnessed value, i.e. only if there is no overflow.
            E::enforce(|| (this.to_field(), that.to_field(), product.to_field()));

            product
        }
        // Case 2: only one and a half integer widths fit in a field element.
        else if (I::BITS + I::BITS / 2) < (E::BaseField::size_in_bits() - 1) as u64 {
            // Karatsuba multiplication returns the low `I::BITS` bits of the product, the
            // remaining upper bits of `z_0 + z_1 * B`, and the high-order term `z_2`.
            let (product, z_1_upper_bits, z2) = Self::karatsuba_multiply(this, that);

            // Any set bit above `I::BITS` in the low-order part means overflow.
            Boolean::assert_bits_are_zero(&z_1_upper_bits);

            // `z_2` is scaled by 2^BITS in the full product, so it must be zero to avoid overflow.
            E::assert_eq(z2, E::zero());

            product
        } else {
            E::halt(format!("Multiplication of integers of size {} is not supported", I::BITS))
        }
    }
}
157
impl<E: Environment, I: IntegerType> Integer<E, I> {
    /// Multiplies the bits of `this` and `that`, read as unsigned values, using one step of
    /// Karatsuba multiplication over the base field.
    ///
    /// With `m = I::BITS / 2` and `B = 2^m`, each operand is split into halves
    /// `x = x_1 * B + x_0` and `y = y_1 * B + y_0`, so that
    /// `x * y = z_2 * B^2 + z_1 * B + z_0`, where
    /// `z_0 = x_0 * y_0`, `z_2 = x_1 * y_1` and `z_1 = (x_1 + x_0) * (y_1 + y_0) - z_2 - z_0`.
    ///
    /// Returns a tuple of:
    /// - the lower `I::BITS` bits of `z_0 + z_1 * B`, as an integer. For even `I::BITS`,
    ///   `z_2 * B^2 = z_2 * 2^BITS` contributes no low bits, so these are the low bits of `x * y`;
    /// - the remaining upper bits of `z_0 + z_1 * B` (the carry bits);
    /// - the field element `z_2`.
    ///
    /// The caller decides how to treat the carry bits and `z_2`; the checked path asserts both are zero.
    /// NOTE(review): `to_lower_bits_le` is assumed to constrain `z_0 + z_1 * B` to fit in the
    /// requested number of bits — confirm against `Field::to_lower_bits_le`.
    #[inline]
    pub(super) fn karatsuba_multiply(
        this: &Integer<E, I>,
        that: &Integer<E, I>,
    ) -> (Integer<E, I>, Vec<Boolean<E>>, Field<E>) {
        // Split each operand into upper (`_1`) and lower (`_0`) halves, as field elements.
        let x_1 = Field::from_bits_le(&this.bits_le[(I::BITS as usize / 2)..]);
        let x_0 = Field::from_bits_le(&this.bits_le[..(I::BITS as usize / 2)]);
        let y_1 = Field::from_bits_le(&that.bits_le[(I::BITS as usize / 2)..]);
        let y_0 = Field::from_bits_le(&that.bits_le[..(I::BITS as usize / 2)]);

        // Karatsuba terms: three variable multiplications instead of four.
        let z_0 = &x_0 * &y_0;
        let z_2 = &x_1 * &y_1;
        let z_1 = (&x_1 + &x_0) * (&y_1 + &y_0) - &z_2 - &z_0;

        // Build the constant B = 2^(BITS / 2) from its little-endian bits.
        let mut b_m_bits = vec![Boolean::constant(false); I::BITS as usize / 2];
        b_m_bits.push(Boolean::constant(true));

        let b_m = Field::from_bits_le(&b_m_bits);
        // Low-order part of the product. The `z_2 * B^2` term is returned separately.
        let z_0_plus_scaled_z_1 = &z_0 + (&z_1 * &b_m);

        // Decompose the low-order part into `I::BITS + I::BITS / 2 + 1` little-endian bits.
        let bits_le = z_0_plus_scaled_z_1.to_lower_bits_le(I::BITS as usize + I::BITS as usize / 2 + 1);

        // The lowest `I::BITS` bits form the product; the rest are carry bits.
        let (bits_le, carry) = bits_le.split_at(I::BITS as usize);

        (Integer::from_bits_le(bits_le), carry.to_vec(), z_2)
    }
}
209
210impl<E: Environment, I: IntegerType> Metrics<dyn MulChecked<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
211 type Case = (Mode, Mode);
212
213 fn count(case: &Self::Case) -> Count {
214 if 2 * I::BITS < (E::BaseField::size_in_bits() - 1) as u64 {
216 match I::is_signed() {
217 true => match (case.0, case.1) {
219 (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
220 (Mode::Constant, _) | (_, Mode::Constant) => {
221 Count::is(4 * I::BITS, 0, (6 * I::BITS) + 4, (6 * I::BITS) + 9)
222 }
223 (_, _) => Count::is(3 * I::BITS, 0, (8 * I::BITS) + 6, (8 * I::BITS) + 12),
224 },
225 false => match (case.0, case.1) {
227 (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
228 (Mode::Constant, _) | (_, Mode::Constant) => Count::is(0, 0, I::BITS, I::BITS + 1),
229 (_, _) => Count::is(0, 0, I::BITS, I::BITS + 1),
230 },
231 }
232 }
233 else if (I::BITS + I::BITS / 2) < (E::BaseField::size_in_bits() - 1) as u64 {
235 match I::is_signed() {
236 true => match (case.0, case.1) {
238 (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
239 (Mode::Constant, _) | (_, Mode::Constant) => Count::less_than(833, 0, 837, 844),
240 (_, _) => Count::is(3 * I::BITS, 0, 1098, 1106),
241 },
242 false => match (case.0, case.1) {
244 (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
245 (Mode::Constant, _) | (_, Mode::Constant) => Count::less_than(193, 0, 193, 199),
246 (_, _) => Count::is(0, 0, 196, 199),
247 },
248 }
249 } else {
250 E::halt(format!("Multiplication of integers of size {} is not supported", I::BITS))
251 }
252 }
253}
254
255impl<E: Environment, I: IntegerType> OutputMode<dyn MulChecked<Integer<E, I>, Output = Integer<E, I>>>
256 for Integer<E, I>
257{
258 type Case = (Mode, Mode);
259
260 fn output_mode(case: &Self::Case) -> Mode {
261 match (case.0, case.1) {
262 (Mode::Constant, Mode::Constant) => Mode::Constant,
263 _ => Mode::Private,
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_environment::Circuit;

    use test_utilities::*;

    use core::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of random operand pairs sampled per mode combination in `run_test`.
    const ITERATIONS: u64 = 32;

    /// Multiplies `first` and `second` in the circuit with the given modes and checks the result
    /// against native `checked_mul`.
    ///
    /// - If the native product exists, the circuit output must equal it and the constraint counts
    ///   must match the `MulChecked` metrics.
    /// - If it overflows and both operands are constants, the operation must halt.
    /// - If it overflows otherwise, the scope is checked with `assert_count_fails!` (presumably
    ///   asserting the constraints are unsatisfied — confirm against the macro).
    ///
    /// The circuit is reset afterwards.
    fn check_mul<I: IntegerType + RefUnwindSafe>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, I>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let a = Integer::<Circuit, I>::new(mode_a, first);
        let b = Integer::<Circuit, I>::new(mode_b, second);
        match first.checked_mul(&second) {
            Some(expected) => Circuit::scope(name, || {
                let candidate = a.mul_checked(&b);
                assert_eq!(expected, *candidate.eject_value());
                assert_eq!(console::Integer::new(expected), candidate.eject_value());
                assert_count!(MulChecked(Integer<I>, Integer<I>) => Integer<I>, &(mode_a, mode_b));
            }),
            // Overflow: constants halt eagerly; variables yield a failing circuit.
            None => match (mode_a, mode_b) {
                (Mode::Constant, Mode::Constant) => check_operation_halts(&a, &b, Integer::mul_checked),
                _ => Circuit::scope(name, || {
                    let _candidate = a.mul_checked(&b);
                    assert_count_fails!(MulChecked(Integer<I>, Integer<I>) => Integer<I>, &(mode_a, mode_b));
                }),
            },
        }
        Circuit::reset();
    }

    /// Exercises `mul_checked` for one pair of operand modes.
    ///
    /// It covers random operands (with commuted, doubling and squaring variants), fixed edge
    /// cases around 0, 1, 2, `MIN` and `MAX`, and, for signed types, products with -1 and -2.
    fn run_test<I: IntegerType + RefUnwindSafe>(mode_a: Mode, mode_b: Mode) {
        let mut rng = TestRng::default();

        for i in 0..ITERATIONS {
            let first = Uniform::rand(&mut rng);
            let second = Uniform::rand(&mut rng);

            let name = format!("Mul: {mode_a} * {mode_b} {i}");
            check_mul::<I>(&name, first, second, mode_a, mode_b);
            check_mul::<I>(&name, second, first, mode_a, mode_b); // Commute the operation.

            let name = format!("Double: {mode_a} * {mode_b} {i}");
            check_mul::<I>(&name, first, console::Integer::one() + console::Integer::one(), mode_a, mode_b);
            check_mul::<I>(&name, console::Integer::one() + console::Integer::one(), first, mode_a, mode_b); // Commute the operation.

            let name = format!("Square: {mode_a} * {mode_b} {i}");
            check_mul::<I>(&name, first, first, mode_a, mode_b);
        }

        // Edge cases common to signed and unsigned integers.
        check_mul::<I>("1 * MAX", console::Integer::one(), console::Integer::MAX, mode_a, mode_b);
        check_mul::<I>("MAX * 1", console::Integer::MAX, console::Integer::one(), mode_a, mode_b);
        check_mul::<I>("1 * MIN", console::Integer::one(), console::Integer::MIN, mode_a, mode_b);
        check_mul::<I>("MIN * 1", console::Integer::MIN, console::Integer::one(), mode_a, mode_b);
        check_mul::<I>("0 * MAX", console::Integer::zero(), console::Integer::MAX, mode_a, mode_b);
        check_mul::<I>("MAX * 0", console::Integer::MAX, console::Integer::zero(), mode_a, mode_b);
        check_mul::<I>("0 * MIN", console::Integer::zero(), console::Integer::MIN, mode_a, mode_b);
        check_mul::<I>("MIN * 0", console::Integer::MIN, console::Integer::zero(), mode_a, mode_b);
        check_mul::<I>("1 * 1", console::Integer::one(), console::Integer::one(), mode_a, mode_b);

        // MAX * 2 overflows for every integer type.
        check_mul::<I>(
            "MAX * 2",
            console::Integer::MAX,
            console::Integer::one() + console::Integer::one(),
            mode_a,
            mode_b,
        );
        check_mul::<I>(
            "2 * MAX",
            console::Integer::one() + console::Integer::one(),
            console::Integer::MAX,
            mode_a,
            mode_b,
        );

        if I::is_signed() {
            // Signed-only edge cases with negative operands, including the `MIN` magnitude boundary.
            check_mul::<I>("MAX * -1", console::Integer::MAX, -console::Integer::one(), mode_a, mode_b);
            check_mul::<I>("-1 * MAX", -console::Integer::one(), console::Integer::MAX, mode_a, mode_b);
            check_mul::<I>("MIN * -1", console::Integer::MIN, -console::Integer::one(), mode_a, mode_b);
            check_mul::<I>("-1 * MIN", -console::Integer::one(), console::Integer::MIN, mode_a, mode_b);
            check_mul::<I>(
                "MIN * -2",
                console::Integer::MIN,
                -console::Integer::one() - console::Integer::one(),
                mode_a,
                mode_b,
            );
            check_mul::<I>(
                "-2 * MIN",
                -console::Integer::one() - console::Integer::one(),
                console::Integer::MIN,
                mode_a,
                mode_b,
            );
        }
    }

    /// Checks `mul_checked` on every pair of values of `I` (intended for 8-bit types).
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe>(mode_a: Mode, mode_b: Mode)
    where
        RangeInclusive<I>: Iterator<Item = I>,
    {
        for first in I::MIN..=I::MAX {
            for second in I::MIN..=I::MAX {
                let first = console::Integer::<_, I>::new(first);
                let second = console::Integer::<_, I>::new(second);

                let name = format!("Mul: ({first} * {second})");
                check_mul::<I>(&name, first, second, mode_a, mode_b);
            }
        }
    }

    // Randomized and edge-case tests for each signed type.
    test_integer_binary!(run_test, i8, times);
    test_integer_binary!(run_test, i16, times);
    test_integer_binary!(run_test, i32, times);
    test_integer_binary!(run_test, i64, times);
    test_integer_binary!(run_test, i128, times);

    // Randomized and edge-case tests for each unsigned type.
    test_integer_binary!(run_test, u8, times);
    test_integer_binary!(run_test, u16, times);
    test_integer_binary!(run_test, u32, times);
    test_integer_binary!(run_test, u64, times);
    test_integer_binary!(run_test, u128, times);

    // Exhaustive 8-bit tests (256 * 256 pairs per mode combination) are marked `#[ignore]`.
    test_integer_binary!(#[ignore], run_exhaustive_test, u8, times, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, times, exhaustive);
}