snarkvm_circuit_types_integers/
shl_checked.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use super::*;
17
18impl<E: Environment, I: IntegerType, M: Magnitude> Shl<Integer<E, M>> for Integer<E, I> {
19    type Output = Self;
20
21    fn shl(self, rhs: Integer<E, M>) -> Self::Output {
22        self << &rhs
23    }
24}
25
26impl<E: Environment, I: IntegerType, M: Magnitude> Shl<Integer<E, M>> for &Integer<E, I> {
27    type Output = Integer<E, I>;
28
29    fn shl(self, rhs: Integer<E, M>) -> Self::Output {
30        self << &rhs
31    }
32}
33
34impl<E: Environment, I: IntegerType, M: Magnitude> Shl<&Integer<E, M>> for Integer<E, I> {
35    type Output = Self;
36
37    fn shl(self, rhs: &Integer<E, M>) -> Self::Output {
38        &self << rhs
39    }
40}
41
42impl<E: Environment, I: IntegerType, M: Magnitude> Shl<&Integer<E, M>> for &Integer<E, I> {
43    type Output = Integer<E, I>;
44
45    fn shl(self, rhs: &Integer<E, M>) -> Self::Output {
46        let mut output = self.clone();
47        output <<= rhs;
48        output
49    }
50}
51
52impl<E: Environment, I: IntegerType, M: Magnitude> ShlAssign<Integer<E, M>> for Integer<E, I> {
53    fn shl_assign(&mut self, rhs: Integer<E, M>) {
54        *self <<= &rhs
55    }
56}
57
58impl<E: Environment, I: IntegerType, M: Magnitude> ShlAssign<&Integer<E, M>> for Integer<E, I> {
59    fn shl_assign(&mut self, rhs: &Integer<E, M>) {
60        // Stores the result of `self` << `other` in `self`.
61        *self = self.shl_checked(rhs);
62    }
63}
64
65impl<E: Environment, I: IntegerType, M: Magnitude> ShlChecked<Integer<E, M>> for Integer<E, I> {
66    type Output = Self;
67
68    #[inline]
69    fn shl_checked(&self, rhs: &Integer<E, M>) -> Self::Output {
70        // Retrieve the index for the first upper bit from the RHS that we mask.
71        let first_upper_bit_index = I::BITS.trailing_zeros() as usize;
72        // Initialize a constant `two`.
73        let two = Self::one() + Self::one();
74        match I::is_signed() {
75            true => {
76                if 3 * I::BITS < E::BaseField::size_in_data_bits() as u64 {
77                    // Enforce that the upper bits of `rhs` are all zero.
78                    Boolean::assert_bits_are_zero(&rhs.bits_le[first_upper_bit_index..]);
79
80                    // Sign-extend `self` to 2 * I::BITS.
81                    let mut bits_le = self.to_bits_le();
82                    bits_le.resize(2 * I::BITS as usize, self.msb().clone());
83
84                    // Calculate the result directly in the field.
85                    // Since 2^{rhs} < Integer::MAX and 3 * I::BITS is less than E::BaseField::size in data bits,
86                    // we know that the operation will not overflow the field modulus.
87                    let mut result = Field::from_bits_le(&bits_le);
88                    for (i, bit) in rhs.bits_le[..first_upper_bit_index].iter().enumerate() {
89                        // In each iteration, multiple the result by 2^(1<<i), if the bit is set.
90                        // Note that instantiating the field from a u128 is safe since it is larger than all eligible integer types.
91                        let constant = Field::constant(console::Field::from_u128(2u128.pow(1 << i)));
92                        let product = &result * &constant;
93                        result = Field::ternary(bit, &product, &result);
94                    }
95                    // Extract the bits of the result, including the carry bits.
96                    let bits_le = result.to_lower_bits_le(3 * I::BITS as usize);
97                    // Split the bits into the lower and upper bits.
98                    let (lower_bits_le, upper_bits_le) = bits_le.split_at(I::BITS as usize);
99                    // Initialize the integer from the lower bits.
100                    let result = Self { bits_le: lower_bits_le.to_vec(), phantom: Default::default() };
101                    // Ensure that the sign of the first I::BITS upper bits match the sign of the result.
102                    for bit in &upper_bits_le[..(I::BITS as usize)] {
103                        E::assert_eq(bit, result.msb());
104                    }
105                    // Return the result.
106                    result
107                } else {
108                    // Compute 2 ^ `rhs` as unsigned integer of the size I::BITS.
109                    // This is necessary to avoid a spurious overflow when `rhs` is I::BITS - 1.
110                    // For example, 2i8 ^ 7i8 overflows, however -1i8 << 7i8 ==> -1i8 * 2i8 ^ 7i8 ==> -128i8, which is a valid i8 value.
111                    let unsigned_two = two.cast_as_dual();
112                    // Note that `pow_checked` is used to enforce that `rhs` < I::BITS.
113                    let unsigned_factor = unsigned_two.pow_checked(rhs);
114                    // For all values of `rhs` such that `rhs` < I::BITS,
115                    //  - if `rhs` == I::BITS - 1, `signed_factor` == I::MIN,
116                    //  - otherwise, `signed_factor` is the same as `unsigned_factor`.
117                    let signed_factor = Self { bits_le: unsigned_factor.bits_le, phantom: Default::default() };
118
119                    // If `signed_factor` is I::MIN, then negate `self` in order to balance the sign of I::MIN.
120                    let signed_factor_is_min = &signed_factor.is_equal(&Self::constant(console::Integer::MIN));
121                    let lhs = Self::ternary(signed_factor_is_min, &Self::zero().sub_wrapped(self), self);
122
123                    // Compute `lhs` * `factor`, which is equivalent to `lhs` * 2 ^ `rhs`.
124                    lhs.mul_checked(&signed_factor)
125                }
126            }
127            false => {
128                if 2 * I::BITS < E::BaseField::size_in_data_bits() as u64 {
129                    // Enforce that the upper bits of `rhs` are all zero.
130                    Boolean::assert_bits_are_zero(&rhs.bits_le[first_upper_bit_index..]);
131
132                    // Calculate the result directly in the field.
133                    // Since 2^{rhs} < Integer::MAX and 2 * I::BITS is less than E::BaseField::size in data bits,
134                    // we know that the operation will not overflow Integer::MAX or the field modulus.
135                    let mut result = self.to_field();
136                    for (i, bit) in rhs.bits_le[..first_upper_bit_index].iter().enumerate() {
137                        // In each iteration, multiply the result by 2^(1<<i), if the bit is set.
138                        // Note that instantiating the field from a u128 is safe since it is larger than all eligible integer types.
139                        let constant = Field::constant(console::Field::from_u128(2u128.pow(1 << i)));
140                        let product = &result * &constant;
141                        result = Field::ternary(bit, &product, &result);
142                    }
143                    // Extract the bits of the result, including the carry bits.
144                    let bits_le = result.to_lower_bits_le(2 * I::BITS as usize);
145                    // Split the bits into the lower and upper bits.
146                    let (lower_bits_le, upper_bits_le) = bits_le.split_at(I::BITS as usize);
147                    // Ensure that the carry bits are all zero.
148                    Boolean::assert_bits_are_zero(upper_bits_le);
149                    // Initialize the integer from the lower bits
150                    Self { bits_le: lower_bits_le.to_vec(), phantom: Default::default() }
151                } else {
152                    // Compute `lhs` * 2 ^ `rhs`.
153                    self.mul_checked(&two.pow_checked(rhs))
154                }
155            }
156        }
157    }
158}
159
160impl<E: Environment, I: IntegerType> Metrics<dyn Shl<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
161    type Case = (Mode, Mode);
162
163    fn count(case: &Self::Case) -> Count {
164        <Self as Metrics<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::count(case)
165    }
166}
167
168impl<E: Environment, I: IntegerType> OutputMode<dyn Shl<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
169    type Case = (Mode, Mode);
170
171    fn output_mode(case: &Self::Case) -> Mode {
172        <Self as OutputMode<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::output_mode(case)
173    }
174}
175
176impl<E: Environment, I: IntegerType, M: Magnitude> Metrics<dyn ShlChecked<Integer<E, M>, Output = Integer<E, I>>>
177    for Integer<E, I>
178{
179    type Case = (Mode, Mode, bool, bool);
180
181    fn count(case: &Self::Case) -> Count {
182        // A quick hack that matches `(u8 -> 0, u16 -> 1, u32 -> 2, u64 -> 3, u128 -> 4)`.
183        let index = |num_bits: u64| match [8, 16, 32, 64, 128].iter().position(|&bits| bits == num_bits) {
184            Some(index) => index as u64,
185            None => E::halt(format!("Integer of {num_bits} bits is not supported")),
186        };
187
188        match (case.0, case.1) {
189            (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
190            (_, Mode::Constant) => Count::is(0, 0, 0, 0),
191            (Mode::Constant, _) | (_, _) => {
192                let wrapped_count = count!(Integer<E, I>, ShlWrapped<Integer<E, M>, Output=Integer<E, I>>, case);
193                wrapped_count + Count::is(0, 0, M::BITS - 4 - index(I::BITS), M::BITS - 3 - index(I::BITS))
194            }
195        }
196    }
197}
198
199impl<E: Environment, I: IntegerType, M: Magnitude> OutputMode<dyn ShlChecked<Integer<E, M>, Output = Integer<E, I>>>
200    for Integer<E, I>
201{
202    type Case = (Mode, Mode);
203
204    fn output_mode(case: &Self::Case) -> Mode {
205        match (case.0, case.1) {
206            (Mode::Constant, Mode::Constant) => Mode::Constant,
207            (mode_a, Mode::Constant) => mode_a,
208            (_, _) => Mode::Private,
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_environment::Circuit;

    use test_utilities::*;

    use core::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of random `(first, second)` pairs sampled per mode combination in `run_test`.
    const ITERATIONS: u64 = 32;

    /// Checks the circuit `shl_checked` against the console reference `checked_shl`.
    ///
    /// - If the reference succeeds, the ejected circuit value must equal it, and the scope must be satisfied.
    /// - If the reference overflows, the circuit must halt in two cases: when both operands are
    ///   constant, or when a constant shift amount is at least `I::BITS`. In every other case the
    ///   scope must be unsatisfied.
    fn check_shl<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe + TryFrom<u64>>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, M>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let a = Integer::<Circuit, I>::new(mode_a, first);
        let b = Integer::<Circuit, M>::new(mode_b, second);

        match first.checked_shl(&second.to_u32().unwrap()) {
            Some(expected) => Circuit::scope(name, || {
                let candidate = a.shl_checked(&b);
                assert_eq!(expected, *candidate.eject_value());
                assert_eq!(console::Integer::new(expected), candidate.eject_value());
                // NOTE(review): the count and output-mode assertions below are disabled.
                // assert_count!(ShlChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b));
                // assert_output_mode!(ShlChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b), candidate);
                assert!(Circuit::is_satisfied_in_scope(), "(is_satisfied_in_scope)");
            }),
            None => match (mode_a, mode_b) {
                (Mode::Constant, Mode::Constant) => check_operation_halts(&a, &b, Integer::shl_checked),
                (_, Mode::Constant) => {
                    // If `second` >= I::BITS, the range check on the constant shift amount is expected
                    // to halt. In the field-based paths this is `assert_bits_are_zero` on the upper bits
                    // of `rhs`; in the fallback paths it is `pow_checked`.
                    // Otherwise, the overflow is only caught by constraints on non-constant values
                    // (the carry/sign-bit assertions, or `mul_checked`), so the scope is unsatisfied.
                    if *second >= M::try_from(I::BITS).unwrap_or_default() {
                        check_operation_halts(&a, &b, Integer::shl_checked);
                    } else {
                        Circuit::scope(name, || {
                            let _candidate = a.shl_checked(&b);
                            // assert_count_fails!(ShlChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b));
                            assert!(!Circuit::is_satisfied_in_scope(), "(!is_satisfied_in_scope)");
                        })
                    }
                }
                _ => Circuit::scope(name, || {
                    let _candidate = a.shl_checked(&b);
                    // assert_count_fails!(ShlChecked(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b));
                    assert!(!Circuit::is_satisfied_in_scope(), "(!is_satisfied_in_scope)");
                }),
            },
        };
        Circuit::reset();
    }

    /// Runs `check_shl` on `ITERATIONS` random pairs.
    ///
    /// Each random `first` is also shifted by fixed amounts of 0, 1, and 2, and zero is shifted by
    /// the random `second`.
    fn run_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe + TryFrom<u64>>(
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let mut rng = TestRng::default();

        for i in 0..ITERATIONS {
            let first = Uniform::rand(&mut rng);
            let second = Uniform::rand(&mut rng);

            let name = format!("Shl: {mode_a} << {mode_b} {i}");
            check_shl::<I, M>(&name, first, second, mode_a, mode_b);

            // Check that shift left by zero is computed correctly.
            let name = format!("Identity: {mode_a} << {mode_b} {i}");
            check_shl::<I, M>(&name, first, console::Integer::zero(), mode_a, mode_b);

            // Check that shift left by one is computed correctly.
            let name = format!("Double: {mode_a} << {mode_b} {i}");
            check_shl::<I, M>(&name, first, console::Integer::one(), mode_a, mode_b);

            // Check that shift left by two is computed correctly.
            let name = format!("Quadruple: {mode_a} << {mode_b} {i}");
            check_shl::<I, M>(&name, first, console::Integer::one() + console::Integer::one(), mode_a, mode_b);

            // Check that zero shifted left by `second` is computed correctly.
            let name = format!("Zero: {mode_a} << {mode_b} {i}");
            check_shl::<I, M>(&name, console::Integer::zero(), second, mode_a, mode_b);
        }
    }

    /// Runs `check_shl` on every `(first, second)` pair over the full ranges of `I` and `M`.
    ///
    /// This is only tractable for 8-bit types; the invocations at the bottom of this module pass
    /// `#[ignore]`.
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe + TryFrom<u64>>(
        mode_a: Mode,
        mode_b: Mode,
    ) where
        RangeInclusive<I>: Iterator<Item = I>,
        RangeInclusive<M>: Iterator<Item = M>,
    {
        for first in I::MIN..=I::MAX {
            for second in M::MIN..=M::MAX {
                let first = console::Integer::<_, I>::new(first);
                let second = console::Integer::<_, M>::new(second);

                let name = format!("Shl: ({first} << {second})");
                check_shl::<I, M>(&name, first, second, mode_a, mode_b);
            }
        }
    }

    // Signed left-hand sides, each shifted by u8, u16, and u32 amounts.
    test_integer_binary!(run_test, i8, u8, shl);
    test_integer_binary!(run_test, i8, u16, shl);
    test_integer_binary!(run_test, i8, u32, shl);

    test_integer_binary!(run_test, i16, u8, shl);
    test_integer_binary!(run_test, i16, u16, shl);
    test_integer_binary!(run_test, i16, u32, shl);

    test_integer_binary!(run_test, i32, u8, shl);
    test_integer_binary!(run_test, i32, u16, shl);
    test_integer_binary!(run_test, i32, u32, shl);

    test_integer_binary!(run_test, i64, u8, shl);
    test_integer_binary!(run_test, i64, u16, shl);
    test_integer_binary!(run_test, i64, u32, shl);

    test_integer_binary!(run_test, i128, u8, shl);
    test_integer_binary!(run_test, i128, u16, shl);
    test_integer_binary!(run_test, i128, u32, shl);

    // Unsigned left-hand sides, each shifted by u8, u16, and u32 amounts.
    test_integer_binary!(run_test, u8, u8, shl);
    test_integer_binary!(run_test, u8, u16, shl);
    test_integer_binary!(run_test, u8, u32, shl);

    test_integer_binary!(run_test, u16, u8, shl);
    test_integer_binary!(run_test, u16, u16, shl);
    test_integer_binary!(run_test, u16, u32, shl);

    test_integer_binary!(run_test, u32, u8, shl);
    test_integer_binary!(run_test, u32, u16, shl);
    test_integer_binary!(run_test, u32, u32, shl);

    test_integer_binary!(run_test, u64, u8, shl);
    test_integer_binary!(run_test, u64, u16, shl);
    test_integer_binary!(run_test, u64, u32, shl);

    test_integer_binary!(run_test, u128, u8, shl);
    test_integer_binary!(run_test, u128, u16, shl);
    test_integer_binary!(run_test, u128, u32, shl);

    // Exhaustive checks over all 8-bit inputs; passed `#[ignore]` because they are expensive.
    test_integer_binary!(#[ignore], run_exhaustive_test, u8, u8, shl, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, u8, shl, exhaustive);
}