// sp1_core_machine/operations/add4.rs

1use p3_air::AirBuilder;
2use p3_field::{AbstractField, Field};
3use sp1_derive::AlignedBorrow;
4
5use sp1_core_executor::events::ByteRecord;
6use sp1_primitives::consts::WORD_SIZE;
7use sp1_stark::{air::SP1AirBuilder, Word};
8
9use crate::air::WordAirBuilder;
10
11/// A set of columns needed to compute the add of four words.
12#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
13#[repr(C)]
14pub struct Add4Operation<T> {
15    /// The result of `a + b + c + d`.
16    pub value: Word<T>,
17
18    /// Indicates if the carry for the `i`th digit is 0.
19    pub is_carry_0: Word<T>,
20
21    /// Indicates if the carry for the `i`th digit is 1.
22    pub is_carry_1: Word<T>,
23
24    /// Indicates if the carry for the `i`th digit is 2.
25    pub is_carry_2: Word<T>,
26
27    /// Indicates if the carry for the `i`th digit is 3. The carry when adding 4 words is at most
28    /// 3.
29    pub is_carry_3: Word<T>,
30
31    /// The carry for the `i`th digit.
32    pub carry: Word<T>,
33}
34
35impl<F: Field> Add4Operation<F> {
36    #[allow(clippy::too_many_arguments)]
37    pub fn populate(
38        &mut self,
39        record: &mut impl ByteRecord,
40        a_u32: u32,
41        b_u32: u32,
42        c_u32: u32,
43        d_u32: u32,
44    ) -> u32 {
45        let expected = a_u32.wrapping_add(b_u32).wrapping_add(c_u32).wrapping_add(d_u32);
46        self.value = Word::from(expected);
47        let a = a_u32.to_le_bytes();
48        let b = b_u32.to_le_bytes();
49        let c = c_u32.to_le_bytes();
50        let d = d_u32.to_le_bytes();
51
52        let base = 256;
53        let mut carry = [0u8, 0u8, 0u8, 0u8];
54        for i in 0..WORD_SIZE {
55            let mut res = (a[i] as u32) + (b[i] as u32) + (c[i] as u32) + (d[i] as u32);
56            if i > 0 {
57                res += carry[i - 1] as u32;
58            }
59            carry[i] = (res / base) as u8;
60            self.is_carry_0[i] = F::from_bool(carry[i] == 0);
61            self.is_carry_1[i] = F::from_bool(carry[i] == 1);
62            self.is_carry_2[i] = F::from_bool(carry[i] == 2);
63            self.is_carry_3[i] = F::from_bool(carry[i] == 3);
64            self.carry[i] = F::from_canonical_u8(carry[i]);
65            debug_assert!(carry[i] <= 3);
66            debug_assert_eq!(self.value[i], F::from_canonical_u32(res % base));
67        }
68
69        // Range check.
70        {
71            record.add_u8_range_checks(&a);
72            record.add_u8_range_checks(&b);
73            record.add_u8_range_checks(&c);
74            record.add_u8_range_checks(&d);
75            record.add_u8_range_checks(&expected.to_le_bytes());
76        }
77        expected
78    }
79
80    #[allow(clippy::too_many_arguments)]
81    pub fn eval<AB: SP1AirBuilder>(
82        builder: &mut AB,
83        a: Word<AB::Var>,
84        b: Word<AB::Var>,
85        c: Word<AB::Var>,
86        d: Word<AB::Var>,
87        is_real: AB::Var,
88        cols: Add4Operation<AB::Var>,
89    ) {
90        // Range check each byte.
91        {
92            builder.slice_range_check_u8(&a.0, is_real);
93            builder.slice_range_check_u8(&b.0, is_real);
94            builder.slice_range_check_u8(&c.0, is_real);
95            builder.slice_range_check_u8(&d.0, is_real);
96            builder.slice_range_check_u8(&cols.value.0, is_real);
97        }
98
99        builder.assert_bool(is_real);
100        let mut builder_is_real = builder.when(is_real);
101
102        // Each value in is_carry_{0,1,2,3} is 0 or 1, and exactly one of them is 1 per digit.
103        {
104            for i in 0..WORD_SIZE {
105                builder_is_real.assert_bool(cols.is_carry_0[i]);
106                builder_is_real.assert_bool(cols.is_carry_1[i]);
107                builder_is_real.assert_bool(cols.is_carry_2[i]);
108                builder_is_real.assert_bool(cols.is_carry_3[i]);
109                builder_is_real.assert_eq(
110                    cols.is_carry_0[i] +
111                        cols.is_carry_1[i] +
112                        cols.is_carry_2[i] +
113                        cols.is_carry_3[i],
114                    AB::Expr::one(),
115                );
116            }
117        }
118
119        // Calculates carry from is_carry_{0,1,2,3}.
120        {
121            let one = AB::Expr::one();
122            let two = AB::F::from_canonical_u32(2);
123            let three = AB::F::from_canonical_u32(3);
124
125            for i in 0..WORD_SIZE {
126                builder_is_real.assert_eq(
127                    cols.carry[i],
128                    cols.is_carry_1[i] * one.clone() +
129                        cols.is_carry_2[i] * two +
130                        cols.is_carry_3[i] * three,
131                );
132            }
133        }
134
135        // Compare the sum and summands by looking at carry.
136        {
137            let base = AB::F::from_canonical_u32(256);
138            // For each limb, assert that difference between the carried result and the non-carried
139            // result is the product of carry and base.
140            for i in 0..WORD_SIZE {
141                let mut overflow = a[i] + b[i] + c[i] + d[i] - cols.value[i];
142                if i > 0 {
143                    overflow = overflow.clone() + cols.carry[i - 1].into();
144                }
145                builder_is_real.assert_eq(cols.carry[i] * base, overflow.clone());
146            }
147        }
148    }
149}