1use super::*;
17
18impl<E: Environment, I: IntegerType, M: Magnitude> Shl<Integer<E, M>> for Integer<E, I> {
19 type Output = Self;
20
21 fn shl(self, rhs: Integer<E, M>) -> Self::Output {
22 self << &rhs
23 }
24}
25
26impl<E: Environment, I: IntegerType, M: Magnitude> Shl<Integer<E, M>> for &Integer<E, I> {
27 type Output = Integer<E, I>;
28
29 fn shl(self, rhs: Integer<E, M>) -> Self::Output {
30 self << &rhs
31 }
32}
33
34impl<E: Environment, I: IntegerType, M: Magnitude> Shl<&Integer<E, M>> for Integer<E, I> {
35 type Output = Self;
36
37 fn shl(self, rhs: &Integer<E, M>) -> Self::Output {
38 &self << rhs
39 }
40}
41
42impl<E: Environment, I: IntegerType, M: Magnitude> Shl<&Integer<E, M>> for &Integer<E, I> {
43 type Output = Integer<E, I>;
44
45 fn shl(self, rhs: &Integer<E, M>) -> Self::Output {
46 let mut output = self.clone();
47 output <<= rhs;
48 output
49 }
50}
51
52impl<E: Environment, I: IntegerType, M: Magnitude> ShlAssign<Integer<E, M>> for Integer<E, I> {
53 fn shl_assign(&mut self, rhs: Integer<E, M>) {
54 *self <<= &rhs
55 }
56}
57
58impl<E: Environment, I: IntegerType, M: Magnitude> ShlAssign<&Integer<E, M>> for Integer<E, I> {
59 fn shl_assign(&mut self, rhs: &Integer<E, M>) {
60 *self = self.shl_checked(rhs);
62 }
63}
64
65impl<E: Environment, I: IntegerType, M: Magnitude> ShlChecked<Integer<E, M>> for Integer<E, I> {
66 type Output = Self;
67
68 #[inline]
69 fn shl_checked(&self, rhs: &Integer<E, M>) -> Self::Output {
70 let first_upper_bit_index = I::BITS.trailing_zeros() as usize;
72 let two = Self::one() + Self::one();
74 match I::is_signed() {
75 true => {
76 if 3 * I::BITS < E::BaseField::size_in_data_bits() as u64 {
77 Boolean::assert_bits_are_zero(&rhs.bits_le[first_upper_bit_index..]);
79
80 let mut bits_le = self.to_bits_le();
82 bits_le.resize(2 * I::BITS as usize, self.msb().clone());
83
84 let mut result = Field::from_bits_le(&bits_le);
88 for (i, bit) in rhs.bits_le[..first_upper_bit_index].iter().enumerate() {
89 let constant = Field::constant(console::Field::from_u128(2u128.pow(1 << i)));
92 let product = &result * &constant;
93 result = Field::ternary(bit, &product, &result);
94 }
95 let bits_le = result.to_lower_bits_le(3 * I::BITS as usize);
97 let (lower_bits_le, upper_bits_le) = bits_le.split_at(I::BITS as usize);
99 let result = Self { bits_le: lower_bits_le.to_vec(), phantom: Default::default() };
101 for bit in &upper_bits_le[..(I::BITS as usize)] {
103 E::assert_eq(bit, result.msb());
104 }
105 result
107 } else {
108 let unsigned_two = two.cast_as_dual();
112 let unsigned_factor = unsigned_two.pow_checked(rhs);
114 let signed_factor = Self { bits_le: unsigned_factor.bits_le, phantom: Default::default() };
118
119 let signed_factor_is_min = &signed_factor.is_equal(&Self::constant(console::Integer::MIN));
121 let lhs = Self::ternary(signed_factor_is_min, &Self::zero().sub_wrapped(self), self);
122
123 lhs.mul_checked(&signed_factor)
125 }
126 }
127 false => {
128 if 2 * I::BITS < E::BaseField::size_in_data_bits() as u64 {
129 Boolean::assert_bits_are_zero(&rhs.bits_le[first_upper_bit_index..]);
131
132 let mut result = self.to_field();
136 for (i, bit) in rhs.bits_le[..first_upper_bit_index].iter().enumerate() {
137 let constant = Field::constant(console::Field::from_u128(2u128.pow(1 << i)));
140 let product = &result * &constant;
141 result = Field::ternary(bit, &product, &result);
142 }
143 let bits_le = result.to_lower_bits_le(2 * I::BITS as usize);
145 let (lower_bits_le, upper_bits_le) = bits_le.split_at(I::BITS as usize);
147 Boolean::assert_bits_are_zero(upper_bits_le);
149 Self { bits_le: lower_bits_le.to_vec(), phantom: Default::default() }
151 } else {
152 self.mul_checked(&two.pow_checked(rhs))
154 }
155 }
156 }
157 }
158}
159
impl<E: Environment, I: IntegerType> Metrics<dyn Shl<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
    // A pair of modes; presumably (lhs mode, rhs mode) — confirm against callers.
    type Case = (Mode, Mode);

    /// Returns the count for `Shl` by forwarding `case` unchanged to another operation's metrics.
    fn count(case: &Self::Case) -> Count {
        // NOTE(review): this delegates to the `DivChecked` metrics rather than to a shift
        // operation's metrics. That looks like a copy-paste leftover — confirm it is intended.
        <Self as Metrics<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::count(case)
    }
}
167
impl<E: Environment, I: IntegerType> OutputMode<dyn Shl<Integer<E, I>, Output = Integer<E, I>>> for Integer<E, I> {
    // A pair of modes; presumably (lhs mode, rhs mode) — confirm against callers.
    type Case = (Mode, Mode);

    /// Returns the output mode for `Shl` by forwarding `case` unchanged to another operation's
    /// output-mode rule.
    fn output_mode(case: &Self::Case) -> Mode {
        // NOTE(review): this delegates to the `DivChecked` output mode rather than to a shift
        // operation's output mode. That looks like a copy-paste leftover — confirm it is intended.
        <Self as OutputMode<dyn DivChecked<Integer<E, I>, Output = Integer<E, I>>>>::output_mode(case)
    }
}
175
176impl<E: Environment, I: IntegerType, M: Magnitude> Metrics<dyn ShlChecked<Integer<E, M>, Output = Integer<E, I>>>
177 for Integer<E, I>
178{
179 type Case = (Mode, Mode, bool, bool);
180
181 fn count(case: &Self::Case) -> Count {
182 let index = |num_bits: u64| match [8, 16, 32, 64, 128].iter().position(|&bits| bits == num_bits) {
184 Some(index) => index as u64,
185 None => E::halt(format!("Integer of {num_bits} bits is not supported")),
186 };
187
188 match (case.0, case.1) {
189 (Mode::Constant, Mode::Constant) => Count::is(I::BITS, 0, 0, 0),
190 (_, Mode::Constant) => Count::is(0, 0, 0, 0),
191 (Mode::Constant, _) | (_, _) => {
192 let wrapped_count = count!(Integer<E, I>, ShlWrapped<Integer<E, M>, Output=Integer<E, I>>, case);
193 wrapped_count + Count::is(0, 0, M::BITS - 4 - index(I::BITS), M::BITS - 3 - index(I::BITS))
194 }
195 }
196 }
197}
198
199impl<E: Environment, I: IntegerType, M: Magnitude> OutputMode<dyn ShlChecked<Integer<E, M>, Output = Integer<E, I>>>
200 for Integer<E, I>
201{
202 type Case = (Mode, Mode);
203
204 fn output_mode(case: &Self::Case) -> Mode {
205 match (case.0, case.1) {
206 (Mode::Constant, Mode::Constant) => Mode::Constant,
207 (mode_a, Mode::Constant) => mode_a,
208 (_, _) => Mode::Private,
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_environment::Circuit;

    use test_utilities::*;

    use core::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of random samples drawn per (mode, mode) combination in `run_test`.
    const ITERATIONS: u64 = 32;

    /// Checks `first << second` in the circuit against the console-level `checked_shl`.
    ///
    /// * If the console shift succeeds, the circuit must produce the same value and be satisfied.
    /// * If it fails, the circuit operation must either halt (both operands constant, or a
    ///   constant shift amount that is at least `I::BITS`) or leave the circuit unsatisfied.
    ///
    /// The circuit is reset before returning.
    fn check_shl<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe + TryFrom<u64>>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, M>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let a = Integer::<Circuit, I>::new(mode_a, first);
        let b = Integer::<Circuit, M>::new(mode_b, second);

        if let Some(expected) = first.checked_shl(&second.to_u32().unwrap()) {
            // The shift is valid: the circuit result must match and the constraints must hold.
            Circuit::scope(name, || {
                let candidate = a.shl_checked(&b);
                assert_eq!(expected, *candidate.eject_value());
                assert_eq!(console::Integer::new(expected), candidate.eject_value());
                assert!(Circuit::is_satisfied_in_scope(), "(is_satisfied_in_scope)");
            });
        } else {
            // The shift is invalid: decide whether the operation halts or merely fails to satisfy.
            let expect_halt = match (mode_a, mode_b) {
                (Mode::Constant, Mode::Constant) => true,
                (_, Mode::Constant) => *second >= M::try_from(I::BITS).unwrap_or_default(),
                _ => false,
            };

            if expect_halt {
                check_operation_halts(&a, &b, Integer::shl_checked);
            } else {
                Circuit::scope(name, || {
                    let _candidate = a.shl_checked(&b);
                    assert!(!Circuit::is_satisfied_in_scope(), "(!is_satisfied_in_scope)");
                });
            }
        }
        Circuit::reset();
    }

    /// Runs `check_shl` on random operands plus a few fixed shift amounts and a zero operand.
    fn run_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe + TryFrom<u64>>(
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let mut rng = TestRng::default();

        for i in 0..ITERATIONS {
            // Sample the operands in the same order on every iteration: value first, then shift.
            let first: console::Integer<<Circuit as Environment>::Network, I> = Uniform::rand(&mut rng);
            let second: console::Integer<<Circuit as Environment>::Network, M> = Uniform::rand(&mut rng);

            // Each entry is (scope name, value, shift amount).
            let cases = [
                (format!("Shl: {mode_a} << {mode_b} {i}"), first, second),
                (format!("Identity: {mode_a} << {mode_b} {i}"), first, console::Integer::zero()),
                (format!("Double: {mode_a} << {mode_b} {i}"), first, console::Integer::one()),
                (
                    format!("Quadruple: {mode_a} << {mode_b} {i}"),
                    first,
                    console::Integer::one() + console::Integer::one(),
                ),
                (format!("Zero: {mode_a} << {mode_b} {i}"), console::Integer::zero(), second),
            ];

            for (name, value, shift) in cases {
                check_shl::<I, M>(&name, value, shift, mode_a, mode_b);
            }
        }
    }

    /// Runs `check_shl` on every (value, shift amount) pair of the given integer types.
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe + TryFrom<u64>>(
        mode_a: Mode,
        mode_b: Mode,
    ) where
        RangeInclusive<I>: Iterator<Item = I>,
        RangeInclusive<M>: Iterator<Item = M>,
    {
        for value in I::MIN..=I::MAX {
            let first = console::Integer::<_, I>::new(value);
            for shift in M::MIN..=M::MAX {
                let second = console::Integer::<_, M>::new(shift);

                let name = format!("Shl: ({first} << {second})");
                check_shl::<I, M>(&name, first, second, mode_a, mode_b);
            }
        }
    }

    test_integer_binary!(run_test, i8, u8, shl);
    test_integer_binary!(run_test, i8, u16, shl);
    test_integer_binary!(run_test, i8, u32, shl);

    test_integer_binary!(run_test, i16, u8, shl);
    test_integer_binary!(run_test, i16, u16, shl);
    test_integer_binary!(run_test, i16, u32, shl);

    test_integer_binary!(run_test, i32, u8, shl);
    test_integer_binary!(run_test, i32, u16, shl);
    test_integer_binary!(run_test, i32, u32, shl);

    test_integer_binary!(run_test, i64, u8, shl);
    test_integer_binary!(run_test, i64, u16, shl);
    test_integer_binary!(run_test, i64, u32, shl);

    test_integer_binary!(run_test, i128, u8, shl);
    test_integer_binary!(run_test, i128, u16, shl);
    test_integer_binary!(run_test, i128, u32, shl);

    test_integer_binary!(run_test, u8, u8, shl);
    test_integer_binary!(run_test, u8, u16, shl);
    test_integer_binary!(run_test, u8, u32, shl);

    test_integer_binary!(run_test, u16, u8, shl);
    test_integer_binary!(run_test, u16, u16, shl);
    test_integer_binary!(run_test, u16, u32, shl);

    test_integer_binary!(run_test, u32, u8, shl);
    test_integer_binary!(run_test, u32, u16, shl);
    test_integer_binary!(run_test, u32, u32, shl);

    test_integer_binary!(run_test, u64, u8, shl);
    test_integer_binary!(run_test, u64, u16, shl);
    test_integer_binary!(run_test, u64, u32, shl);

    test_integer_binary!(run_test, u128, u8, shl);
    test_integer_binary!(run_test, u128, u16, shl);
    test_integer_binary!(run_test, u128, u32, shl);

    test_integer_binary!(#[ignore], run_exhaustive_test, u8, u8, shl, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, u8, shl, exhaustive);
}