use super::*;
impl<E: Environment, I: IntegerType, M: Magnitude> ShrWrapped<Integer<E, M>> for Integer<E, I> {
    type Output = Self;

    /// Returns `self >> rhs`, using only the low `log2(I::BITS)` bits of `rhs` as the shift
    /// amount, so the shift wraps modulo `I::BITS` (matching the console `wrapping_shr`).
    ///
    /// Signed values are shifted arithmetically (vacated high bits copy the sign bit).
    /// Unsigned values are shifted logically (vacated high bits are zero).
    ///
    /// Strategy, by operand constness:
    /// - `self` and `rhs` both constant: compute natively and inject the result via `witness!`.
    /// - masked shift amount constant: rewire the bits of `self` directly.
    /// - otherwise: build `2^shift` in the base field and divide `self` by it.
    #[inline]
    fn shr_wrapped(&self, rhs: &Integer<E, M>) -> Self::Output {
        if self.is_constant() && rhs.is_constant() {
            // NOTE(review): `unwrap` assumes every `Magnitude` value fits in a `u32` — confirm
            // against the `Magnitude` definition.
            witness!(|self, rhs| console::Integer::new(self.wrapping_shr(rhs.to_u32().unwrap())))
        } else {
            // `I::BITS` is a power of two (8..=128), so the low `log2(I::BITS)` bits of `rhs`
            // are the shift amount modulo `I::BITS`. Higher bits of `rhs` are ignored.
            let first_upper_bit_index = I::BITS.trailing_zeros() as usize;

            // Pack the masked shift amount into a zero-padded `U8` so that its constness and
            // value can be inspected.
            let mut lower_rhs_bits = Vec::with_capacity(8);
            lower_rhs_bits.extend_from_slice(&rhs.bits_le[..first_upper_bit_index]);
            lower_rhs_bits.resize(8, Boolean::constant(false));
            let rhs_as_u8 = U8 { bits_le: lower_rhs_bits, phantom: Default::default() };

            if rhs_as_u8.is_constant() {
                // Known shift amount: the result is the top `I::BITS - shift_amount` bits of
                // `self`, moved down, followed by `shift_amount` fill bits in the high positions.
                // `shift_amount <= I::BITS - 1`, so the slice below is always in bounds.
                let shift_amount = *rhs_as_u8.eject_value() as usize;
                let mut bits_le = Vec::with_capacity(I::BITS as usize);
                bits_le.extend_from_slice(&self.bits_le[shift_amount..]);
                if I::is_signed() {
                    // Sign extension: replicate the most significant bit.
                    bits_le.extend(core::iter::repeat(self.msb().clone()).take(shift_amount));
                } else {
                    // Zero extension.
                    bits_le.extend(core::iter::repeat(Boolean::constant(false)).take(shift_amount));
                }
                Self { bits_le, phantom: Default::default() }
            } else {
                // Compute `2^shift` in the base field by square-and-multiply over the masked
                // shift bits, most significant bit first.
                let two = Field::one() + Field::one();
                let mut shift_in_field = Field::one();
                for bit in rhs.bits_le[..first_upper_bit_index].iter().rev() {
                    shift_in_field = shift_in_field.square();
                    shift_in_field = Field::ternary(bit, &(&shift_in_field * &two), &shift_in_field);
                }

                // `2^shift` is at most `2^(I::BITS - 1)`, so its low `I::BITS` bits represent it.
                let shift_as_divisor =
                    Self { bits_le: shift_in_field.to_lower_bits_le(I::BITS as usize), phantom: Default::default() };

                if I::is_signed() {
                    // `x >> k == floor(x / 2^k)`. Divide `|x|` by `2^k` as unsigned values, then
                    // fix up the sign and the rounding direction for negative `x`.
                    let unsigned_divided = self.abs_wrapped().cast_as_dual();
                    let unsigned_divisor = shift_as_divisor.cast_as_dual();
                    let (unsigned_quotient, unsigned_remainder) =
                        unsigned_divided.unsigned_division_via_witness(&unsigned_divisor);
                    // Reuse the unsigned quotient's bits as an `I` value.
                    let quotient = Self { bits_le: unsigned_quotient.bits_le, phantom: Default::default() };
                    // Two's-complement negation: `-q == !q + 1`.
                    let negated_quotient = &(!&quotient).add_wrapped(&Self::one());
                    // For negative `x`, `floor(x / 2^k)` is `-q` when the division is exact and
                    // `-q - 1` otherwise: a shift rounds toward negative infinity, not zero.
                    let rounded_negated_quotient = Self::ternary(
                        &unsigned_remainder.to_field().is_equal(&Field::zero()),
                        negated_quotient,
                        &(negated_quotient).sub_wrapped(&Self::one()),
                    );
                    // Pick the negative-path or non-negative-path result by the sign bit.
                    Self::ternary(self.msb(), &rounded_negated_quotient, &quotient)
                } else {
                    // For unsigned `x`, truncating division already equals `floor(x / 2^k)`.
                    self.div_wrapped(&shift_as_divisor)
                }
            }
        }
    }
}
impl<E: Environment, I: IntegerType, M: Magnitude> Metrics<dyn ShrWrapped<Integer<E, M>, Output = Integer<E, I>>>
    for Integer<E, I>
{
    type Case = (Mode, Mode);

    /// Returns the expected circuit cost of `shr_wrapped` for the operand modes `(self, rhs)`.
    ///
    /// - If `rhs` is constant, only the fully-constant case has a nonzero cost.
    /// - Otherwise the cost depends on two things: whether `I` is signed, and whether
    ///   `2 * I::BITS` is below the base field's data capacity.
    /// - Within those cases, a constant `self` yields an upper bound (`Count::less_than`) and a
    ///   non-constant `self` yields an exact count (`Count::is`).
    #[rustfmt::skip]
    fn count(case: &Self::Case) -> Count {
        // Position of a supported bit width within [8, 16, 32, 64, 128]; halts otherwise.
        let width_index = |num_bits: u64| -> u64 {
            match num_bits {
                8 => 0,
                16 => 1,
                32 => 2,
                64 => 3,
                128 => 4,
                _ => E::halt(format!("Integer of {num_bits} bits is not supported")),
            }
        };

        let bits = I::BITS;

        match *case {
            (Mode::Constant, Mode::Constant) => Count::is(bits, 0, 0, 0),
            (_, Mode::Constant) => Count::is(0, 0, 0, 0),
            (mode_a, _) => {
                let is_signed = I::is_signed();
                let fits_in_field = 2 * bits < E::BaseField::size_in_data_bits() as u64;

                // The last two `Count` arguments are identical for constant and non-constant `self`.
                // NOTE(review): presumably these are the private-variable and constraint bounds.
                let (third, fourth) = match (is_signed, fits_in_field) {
                    (true, true) => ((10 * bits) + (2 * width_index(bits)) + 11, (10 * bits) + (2 * width_index(bits)) + 19),
                    (true, false) => (1752, 1957),
                    (false, true) => ((4 * bits) + (2 * width_index(bits)) + 6, (4 * bits) + (2 * width_index(bits)) + 10),
                    (false, false) => (979, 1180),
                };

                let self_is_constant = matches!(mode_a, Mode::Constant);
                match (is_signed, self_is_constant) {
                    (true, true) => Count::less_than(5 * bits, 0, third, fourth),
                    (true, false) => Count::is(4 * bits, 0, third, fourth),
                    (false, true) => Count::less_than(bits, 0, third, fourth),
                    (false, false) => Count::is(bits, 0, third, fourth),
                }
            }
        }
    }
}
impl<E: Environment, I: IntegerType, M: Magnitude> OutputMode<dyn ShrWrapped<Integer<E, M>, Output = Integer<E, I>>>
    for Integer<E, I>
{
    type Case = (Mode, Mode);

    /// Returns the mode of `self.shr_wrapped(rhs)` given the modes `(self, rhs)`.
    ///
    /// A constant `rhs` passes `self`'s mode through. When `self` is also constant, this gives
    /// `Mode::Constant`. Any non-constant `rhs` produces a private result.
    fn output_mode(case: &Self::Case) -> Mode {
        let (self_mode, rhs_mode) = *case;
        match rhs_mode {
            Mode::Constant => self_mode,
            _ => Mode::Private,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_circuit_environment::Circuit;

    use core::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of random `(lhs, rhs)` pairs checked by each `run_test` invocation.
    const ITERATIONS: u64 = 32;

    /// Checks the circuit `first.shr_wrapped(second)` against the console `wrapping_shr`.
    /// It also checks the recorded circuit cost and the output mode.
    fn check_shr<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, M>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        // Reference result from the native (console) implementation.
        let shift = second.to_u32().unwrap();
        let expected = first.wrapping_shr(shift);

        let lhs = Integer::<Circuit, I>::new(mode_a, first);
        let rhs = Integer::<Circuit, M>::new(mode_b, second);

        Circuit::scope(name, || {
            let candidate = lhs.shr_wrapped(&rhs);
            // Compare both the raw primitive and the wrapped console integer.
            assert_eq!(expected, *candidate.eject_value());
            assert_eq!(console::Integer::new(expected), candidate.eject_value());
            assert_count!(ShrWrapped(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b));
            assert_output_mode!(ShrWrapped(Integer<I>, Integer<M>) => Integer<I>, &(mode_a, mode_b), candidate);
        });
        Circuit::reset();
    }

    /// Checks random shifts, plus a shift by one, for `ITERATIONS` random left operands.
    fn run_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode) {
        let mut rng = TestRng::default();

        for iteration in 0..ITERATIONS {
            // The left operand is drawn before the shift amount.
            let lhs = Uniform::rand(&mut rng);
            let rhs = Uniform::rand(&mut rng);

            // A random shift amount.
            let shr_name = format!("Shr: {mode_a} >> {mode_b} {iteration}");
            check_shr::<I, M>(&shr_name, lhs, rhs, mode_a, mode_b);

            // A shift by exactly one.
            let half_name = format!("Half: {mode_a} >> {mode_b} {iteration}");
            check_shr::<I, M>(&half_name, lhs, console::Integer::one(), mode_a, mode_b);
        }
    }

    /// Checks every `(lhs, rhs)` pair in the full ranges of `I` and `M`.
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode)
    where
        RangeInclusive<I>: Iterator<Item = I>,
        RangeInclusive<M>: Iterator<Item = M>,
    {
        for raw_lhs in I::MIN..=I::MAX {
            let lhs = console::Integer::<_, I>::new(raw_lhs);
            for raw_rhs in M::MIN..=M::MAX {
                let rhs = console::Integer::<_, M>::new(raw_rhs);
                check_shr::<I, M>(&format!("Shr: ({lhs} >> {rhs})"), lhs, rhs, mode_a, mode_b);
            }
        }
    }

    // Signed left operands.
    test_integer_binary!(run_test, i8, u8, shr);
    test_integer_binary!(run_test, i8, u16, shr);
    test_integer_binary!(run_test, i8, u32, shr);
    test_integer_binary!(run_test, i16, u8, shr);
    test_integer_binary!(run_test, i16, u16, shr);
    test_integer_binary!(run_test, i16, u32, shr);
    test_integer_binary!(run_test, i32, u8, shr);
    test_integer_binary!(run_test, i32, u16, shr);
    test_integer_binary!(run_test, i32, u32, shr);
    test_integer_binary!(run_test, i64, u8, shr);
    test_integer_binary!(run_test, i64, u16, shr);
    test_integer_binary!(run_test, i64, u32, shr);
    test_integer_binary!(run_test, i128, u8, shr);
    test_integer_binary!(run_test, i128, u16, shr);
    test_integer_binary!(run_test, i128, u32, shr);

    // Unsigned left operands.
    test_integer_binary!(run_test, u8, u8, shr);
    test_integer_binary!(run_test, u8, u16, shr);
    test_integer_binary!(run_test, u8, u32, shr);
    test_integer_binary!(run_test, u16, u8, shr);
    test_integer_binary!(run_test, u16, u16, shr);
    test_integer_binary!(run_test, u16, u32, shr);
    test_integer_binary!(run_test, u32, u8, shr);
    test_integer_binary!(run_test, u32, u16, shr);
    test_integer_binary!(run_test, u32, u32, shr);
    test_integer_binary!(run_test, u64, u8, shr);
    test_integer_binary!(run_test, u64, u16, shr);
    test_integer_binary!(run_test, u64, u32, shr);
    test_integer_binary!(run_test, u128, u8, shr);
    test_integer_binary!(run_test, u128, u16, shr);
    test_integer_binary!(run_test, u128, u32, shr);

    // Exhaustive (slow) checks, ignored by default.
    test_integer_binary!(#[ignore], run_exhaustive_test, u8, u8, shr, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, u8, shr, exhaustive);
}