sp1_core_machine/operations/
fixed_shift_right.rs

1use p3_field::{AbstractField, Field};
2use sp1_core_executor::{
3    events::{ByteLookupEvent, ByteRecord},
4    ByteOpcode,
5};
6use sp1_derive::AlignedBorrow;
7use sp1_primitives::consts::WORD_SIZE;
8use sp1_stark::{air::SP1AirBuilder, Word};
9
10use crate::bytes::utils::shr_carry;
11
12/// A set of columns needed to compute `>>` of a word with a fixed offset R.
13///
14/// Note that we decompose shifts into a byte shift and a bit shift.
15#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
16#[repr(C)]
17pub struct FixedShiftRightOperation<T> {
18    /// The output value.
19    pub value: Word<T>,
20
21    /// The shift output of `shrcarry` on each byte of a word.
22    pub shift: Word<T>,
23
24    /// The carry ouytput of `shrcarry` on each byte of a word.
25    pub carry: Word<T>,
26}
27
28impl<F: Field> FixedShiftRightOperation<F> {
29    pub const fn nb_bytes_to_shift(rotation: usize) -> usize {
30        rotation / 8
31    }
32
33    pub const fn nb_bits_to_shift(rotation: usize) -> usize {
34        rotation % 8
35    }
36
37    pub const fn carry_multiplier(rotation: usize) -> u32 {
38        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
39        1 << (8 - nb_bits_to_shift)
40    }
41
42    pub fn populate(&mut self, record: &mut impl ByteRecord, input: u32, rotation: usize) -> u32 {
43        let input_bytes = input.to_le_bytes().map(F::from_canonical_u8);
44        let expected = input >> rotation;
45
46        // Compute some constants with respect to the rotation needed for the rotation.
47        let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation);
48        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
49        let carry_multiplier = F::from_canonical_u32(Self::carry_multiplier(rotation));
50
51        // Perform the byte shift.
52        let mut word = [F::zero(); WORD_SIZE];
53        for i in 0..WORD_SIZE {
54            if i + nb_bytes_to_shift < WORD_SIZE {
55                word[i] = input_bytes[(i + nb_bytes_to_shift) % WORD_SIZE];
56            }
57        }
58        let input_bytes_rotated = Word(word);
59
60        // For each byte, calculate the shift and carry. If it's not the first byte, calculate the
61        // new byte value using the current shifted byte and the last carry.
62        let mut first_shift = F::zero();
63        let mut last_carry = F::zero();
64        for i in (0..WORD_SIZE).rev() {
65            let b = input_bytes_rotated[i].to_string().parse::<u8>().unwrap();
66            let c = nb_bits_to_shift as u8;
67            let (shift, carry) = shr_carry(b, c);
68            let byte_event =
69                ByteLookupEvent { opcode: ByteOpcode::ShrCarry, a1: shift as u16, a2: carry, b, c };
70            record.add_byte_lookup_event(byte_event);
71
72            self.shift[i] = F::from_canonical_u8(shift);
73            self.carry[i] = F::from_canonical_u8(carry);
74
75            if i == WORD_SIZE - 1 {
76                first_shift = self.shift[i];
77            } else {
78                self.value[i] = self.shift[i] + last_carry * carry_multiplier;
79            }
80
81            last_carry = self.carry[i];
82        }
83
84        // For the first byte, we don't move over the carry as this is a shift, not a rotate.
85        self.value[WORD_SIZE - 1] = first_shift;
86
87        // Assert the answer is correct.
88        assert_eq!(self.value.to_u32(), expected);
89
90        expected
91    }
92
93    pub fn eval<AB: SP1AirBuilder>(
94        builder: &mut AB,
95        input: Word<AB::Var>,
96        rotation: usize,
97        cols: FixedShiftRightOperation<AB::Var>,
98        is_real: AB::Var,
99    ) {
100        // Compute some constants with respect to the rotation needed for the rotation.
101        let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation);
102        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
103        let carry_multiplier = AB::F::from_canonical_u32(Self::carry_multiplier(rotation));
104
105        // Perform the byte shift.
106        let input_bytes_rotated = Word(std::array::from_fn(|i| {
107            if i + nb_bytes_to_shift < WORD_SIZE {
108                input[(i + nb_bytes_to_shift) % WORD_SIZE].into()
109            } else {
110                AB::Expr::zero()
111            }
112        }));
113
114        // For each byte, calculate the shift and carry. If it's not the first byte, calculate the
115        // new byte value using the current shifted byte and the last carry.
116        let mut first_shift = AB::Expr::zero();
117        let mut last_carry = AB::Expr::zero();
118        for i in (0..WORD_SIZE).rev() {
119            builder.send_byte_pair(
120                AB::F::from_canonical_u32(ByteOpcode::ShrCarry as u32),
121                cols.shift[i],
122                cols.carry[i],
123                input_bytes_rotated[i].clone(),
124                AB::F::from_canonical_usize(nb_bits_to_shift),
125                is_real,
126            );
127
128            if i == WORD_SIZE - 1 {
129                first_shift = cols.shift[i].into();
130            } else {
131                builder.assert_eq(cols.value[i], cols.shift[i] + last_carry * carry_multiplier);
132            }
133
134            last_carry = cols.carry[i].into();
135        }
136
137        // For the first byte, we don't move over the carry as this is a shift, not a rotate.
138        builder.assert_eq(cols.value[WORD_SIZE - 1], first_shift);
139    }
140}