Skip to main content

sp1_core_machine/operations/
fixed_shift_right.rs

1use slop_air::AirBuilder;
2use slop_algebra::{AbstractField, Field};
3use sp1_core_executor::{events::ByteRecord, ByteOpcode};
4use sp1_derive::AlignedBorrow;
5use sp1_hypercube::air::SP1AirBuilder;
6use sp1_primitives::consts::u32_to_u16_limbs;
7
8use crate::utils::u32_to_half_word;
9
10/// A set of columns needed to compute `>>` of an u32 with a fixed offset R.
11///
12/// Note that we decompose shifts into a limb shift and a bit shift.
13#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
14#[repr(C)]
15pub struct FixedShiftRightOperation<T> {
16    /// The output value.
17    pub value: [T; 2],
18
19    /// The higher bits of each limb.
20    pub higher_limb: [T; 2],
21}
22
23impl<F: Field> FixedShiftRightOperation<F> {
24    pub const fn nb_limbs_to_shift(rotation: usize) -> usize {
25        rotation / 16
26    }
27
28    pub const fn nb_bits_to_shift(rotation: usize) -> usize {
29        rotation % 16
30    }
31
32    pub const fn carry_multiplier(rotation: usize) -> u32 {
33        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
34        1 << (16 - nb_bits_to_shift)
35    }
36
37    pub fn populate(&mut self, record: &mut impl ByteRecord, input: u32, rotation: usize) -> u32 {
38        let input_limbs = u32_to_u16_limbs(input);
39        let expected = input >> rotation;
40        self.value = u32_to_half_word(expected);
41
42        // Compute some constants with respect to the rotation needed for the rotation.
43        let nb_limbs_to_shift = Self::nb_limbs_to_shift(rotation);
44        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
45
46        // Perform the limb shift.
47        let mut word = [0u16; 2];
48        for i in 0..2 {
49            if i + nb_limbs_to_shift < 2 {
50                word[i] = input_limbs[i + nb_limbs_to_shift];
51            }
52        }
53
54        for i in (0..2).rev() {
55            let limb = word[i];
56            let lower_limb = (limb & ((1 << nb_bits_to_shift) - 1)) as u16;
57            let higher_limb = (limb >> nb_bits_to_shift) as u16;
58            self.higher_limb[i] = F::from_canonical_u16(higher_limb);
59            record.add_bit_range_check(lower_limb, nb_bits_to_shift as u8);
60            record.add_bit_range_check(higher_limb, (16 - nb_bits_to_shift) as u8);
61        }
62
63        expected
64    }
65
66    /// Evaluates the u32 fixed shift right. Constrains that `is_real` is boolean.
67    /// If `is_real` is true, the result `value` will be the correct result with two u16 limbs.
68    /// This function assumes that the `input` is a u32 with valid two u16 limbs.
69    pub fn eval<AB: SP1AirBuilder>(
70        builder: &mut AB,
71        input: [AB::Var; 2],
72        rotation: usize,
73        cols: FixedShiftRightOperation<AB::Var>,
74        is_real: AB::Var,
75    ) {
76        builder.assert_bool(is_real);
77
78        // Compute some constants with respect to the rotation needed for the rotation.
79        let nb_limbs_to_shift = Self::nb_limbs_to_shift(rotation);
80        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
81        let carry_multiplier = AB::F::from_canonical_u32(Self::carry_multiplier(rotation));
82
83        // Perform the limb shift.
84        let input_limbs_shifted: [AB::Expr; 2] = std::array::from_fn(|i| {
85            if i + nb_limbs_to_shift < 2 {
86                input[i + nb_limbs_to_shift].into()
87            } else {
88                AB::Expr::zero()
89            }
90        });
91
92        // For each limb, constrain the lower and higher parts of the limb.
93        let mut lower_limb = [AB::Expr::zero(), AB::Expr::zero()];
94        for i in 0..2 {
95            let limb = input_limbs_shifted[i].clone();
96
97            // Break down the limb into lower and higher parts.
98            //  - `limb = lower_limb + higher_limb * 2^bit_shift`
99            //  - `lower_limb < 2^(bit_shift)`
100            //  - `higher_limb < 2^(16 - bit_shift)`
101            lower_limb[i] =
102                limb - cols.higher_limb[i] * AB::Expr::from_canonical_u32(1 << nb_bits_to_shift);
103
104            // Check that `lower_limb < 2^(bit_shift)`
105            builder.send_byte(
106                AB::F::from_canonical_u32(ByteOpcode::Range as u32),
107                lower_limb[i].clone(),
108                AB::F::from_canonical_u32(nb_bits_to_shift as u32),
109                AB::Expr::zero(),
110                is_real,
111            );
112            // Check that `higher_limb < 2^(16 - bit_shift)`
113            builder.send_byte(
114                AB::F::from_canonical_u32(ByteOpcode::Range as u32),
115                cols.higher_limb[i],
116                AB::Expr::from_canonical_u32(16 - nb_bits_to_shift as u32),
117                AB::Expr::zero(),
118                is_real,
119            );
120        }
121
122        // Constrain the resulting value using the lower and higher parts.
123        builder.when(is_real).assert_eq(cols.value[1], cols.higher_limb[1]);
124        builder.when(is_real).assert_eq(
125            cols.value[0],
126            cols.higher_limb[0] + lower_limb[1].clone() * carry_multiplier,
127        );
128    }
129}