// Source: sp1_core_machine/operations/mul.rs

1use std::num::Wrapping;
2
3use crate::{
4    air::{SP1Operation, SP1OperationBuilder, WordAirBuilder},
5    operations::{U16MSBOperation, U16toU8OperationSafe, U16toU8OperationSafeInput},
6};
7use sp1_core_executor::{
8    events::{ByteLookupEvent, ByteRecord},
9    ByteOpcode,
10};
11use sp1_hypercube::{air::SP1AirBuilder, Word};
12use struct_reflection::{StructReflection, StructReflectionHelper};
13
14use slop_air::AirBuilder;
15use slop_algebra::{AbstractField, Field};
16use sp1_derive::{AlignedBorrow, InputExpr, InputParams, IntoShape, SP1OperationBuilder};
17use sp1_primitives::consts::{
18    u64_to_u16_limbs, BYTE_SIZE, LONG_WORD_BYTE_SIZE, WORD_BYTE_SIZE, WORD_SIZE,
19};
20
21use super::{U16MSBOperationInput, U16toU8Operation};
22
/// The mask for a byte: every bit set (`0xff`).
///
/// Used as the value of each upper byte when a negative operand is sign-extended.
const BYTE_MASK: u8 = 0xff;
25
/// Returns the most significant bit (0 or 1) of an 8-byte little-endian value.
///
/// The input is expected to be in the layout produced by `u64::to_le_bytes`, so the last byte
/// is the most significant one and its top bit is the two's-complement sign bit.
pub const fn get_msb(a: [u8; 8]) -> u8 {
    // Little-endian: the last byte is the most significant one.
    let [.., most_significant_byte] = a;
    // Shifting a `u8` right by `u8::BITS - 1` keeps only its top bit, so the result is already
    // 0 or 1 and needs neither a mask nor a cast.
    most_significant_byte >> (u8::BITS - 1)
}
29
30/// A set of columns needed for the MUL operations.
31#[derive(
32    AlignedBorrow, Default, Debug, Clone, Copy, IntoShape, SP1OperationBuilder, StructReflection,
33)]
34#[repr(C)]
35pub struct MulOperation<T> {
36    /// Trace.
37    pub carry: [T; LONG_WORD_BYTE_SIZE],
38
39    /// An array storing the product of `b * c` after the carry propagation.
40    pub product: [T; LONG_WORD_BYTE_SIZE],
41
42    /// The lower byte of two limbs of `b`.
43    pub b_lower_byte: U16toU8Operation<T>,
44
45    /// The lower byte of two limbs of `c`.
46    pub c_lower_byte: U16toU8Operation<T>,
47
48    /// The most significant bit of `b`.
49    pub b_msb: T,
50
51    /// The most significant bit of `c`.
52    pub c_msb: T,
53
54    /// The most significant bit of the product.
55    pub product_msb: U16MSBOperation<T>,
56
57    /// The sign extension of `b`.
58    pub b_sign_extend: T,
59
60    /// The sign extension of `c`.
61    pub c_sign_extend: T,
62}
63
64impl<F: Field> MulOperation<F> {
65    /// Populate the MUL operation from an event.
66    pub fn populate(
67        &mut self,
68        record: &mut impl ByteRecord,
69        b_u64: u64,
70        c_u64: u64,
71        is_mulh: bool,
72        is_mulhsu: bool,
73        is_mulw: bool,
74    ) {
75        let b_word = b_u64.to_le_bytes();
76        let c_word = c_u64.to_le_bytes();
77
78        let mulw_value = (Wrapping(b_u64 as i32) * Wrapping(c_u64 as i32)).0 as i64 as u64;
79        let limbs = u64_to_u16_limbs(mulw_value);
80
81        if is_mulw {
82            self.product_msb.populate_msb(record, limbs[1]);
83        } else {
84            self.product_msb.msb = F::zero();
85        }
86
87        let mut b = b_word.to_vec();
88        let mut c = c_word.to_vec();
89
90        self.b_lower_byte.populate_u16_to_u8_safe(record, b_u64);
91        self.c_lower_byte.populate_u16_to_u8_safe(record, c_u64);
92
93        // Handle b and c's signs.
94        {
95            let b_msb = get_msb(b_word);
96            self.b_msb = F::from_canonical_u8(b_msb);
97            let c_msb = get_msb(c_word);
98            self.c_msb = F::from_canonical_u8(c_msb);
99
100            // If b is signed and it is negative, sign extend b.
101            if (is_mulh || is_mulhsu) && b_msb == 1 {
102                self.b_sign_extend = F::one();
103                b.resize(LONG_WORD_BYTE_SIZE, BYTE_MASK);
104            } else {
105                self.b_sign_extend = F::zero();
106            }
107
108            // If c is signed and it is negative, sign extend c.
109            if is_mulh && c_msb == 1 {
110                self.c_sign_extend = F::one();
111                c.resize(LONG_WORD_BYTE_SIZE, BYTE_MASK);
112            } else {
113                self.c_sign_extend = F::zero();
114            }
115
116            // Insert the MSB lookup events.
117            {
118                let words = [b_word, c_word];
119                let mut blu_events: Vec<ByteLookupEvent> = vec![];
120                for word in words.iter() {
121                    let most_significant_byte = word[WORD_BYTE_SIZE - 1];
122                    blu_events.push(ByteLookupEvent {
123                        opcode: ByteOpcode::MSB,
124                        a: get_msb(*word) as u16,
125                        b: most_significant_byte,
126                        c: 0,
127                    });
128                }
129                record.add_byte_lookup_events(blu_events);
130            }
131        }
132
133        let mut product = [0u32; LONG_WORD_BYTE_SIZE];
134        for i in 0..b.len() {
135            for j in 0..c.len() {
136                if i + j < LONG_WORD_BYTE_SIZE {
137                    product[i + j] += (b[i] as u32) * (c[j] as u32);
138                }
139            }
140        }
141
142        // Calculate the correct product using the `product` array. We store the
143        // correct carry value for verification.
144        let base = (1 << BYTE_SIZE) as u32;
145        let mut carry = [0u32; LONG_WORD_BYTE_SIZE];
146        for i in 0..LONG_WORD_BYTE_SIZE {
147            carry[i] = product[i] / base;
148            product[i] %= base;
149            if i + 1 < LONG_WORD_BYTE_SIZE {
150                product[i + 1] += carry[i];
151            }
152            self.carry[i] = F::from_canonical_u32(carry[i]);
153        }
154
155        self.product = product.map(F::from_canonical_u32);
156
157        // Range check.
158        {
159            record.add_u16_range_checks(&carry.map(|x| x as u16));
160            record.add_u8_range_checks(&product.map(|x| x as u8));
161        }
162    }
163
164    /// Evaluate the MUL operation.
165    /// Assumes that `b_word`, `c_word` are valid `Word`s of u16 limbs.
166    /// Constrains that all flags are boolean.
167    /// Constrains that at most one of `is_mul`, `is_mulh`, `is_mulhu`, `is_mulhsu` is true.
168    /// If `is_real` is true, constrains that the product is correctly placed at `a_word`.
169    #[allow(clippy::too_many_arguments)]
170    pub fn eval<
171        AB: SP1AirBuilder
172            + SP1OperationBuilder<U16toU8OperationSafe>
173            + SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>,
174    >(
175        builder: &mut AB,
176        a_word: Word<AB::Expr>,
177        b_word: Word<AB::Expr>,
178        c_word: Word<AB::Expr>,
179        cols: MulOperation<AB::Var>,
180        is_real: AB::Expr,
181        is_mul: AB::Expr,
182        is_mulh: AB::Expr,
183        is_mulw: AB::Expr,
184        is_mulhu: AB::Expr,
185        is_mulhsu: AB::Expr,
186    ) {
187        let zero: AB::Expr = AB::F::zero().into();
188        let base = AB::F::from_canonical_u32(1 << 8);
189        let one: AB::Expr = AB::F::one().into();
190        let byte_mask = AB::F::from_canonical_u8(BYTE_MASK);
191
192        // Uses the safe API to convert the words into eight bytes.
193        let b_input = U16toU8OperationSafeInput::new(b_word.0, cols.b_lower_byte, is_real.clone());
194        let b = U16toU8OperationSafe::eval(builder, b_input);
195        let c_input = U16toU8OperationSafeInput::new(c_word.0, cols.c_lower_byte, is_real.clone());
196        let c = U16toU8OperationSafe::eval(builder, c_input);
197
198        // Calculate the MSBs.
199        let msb_opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32);
200        let (b_msb, c_msb) = {
201            let msb_pairs = [
202                (cols.b_msb, b[WORD_BYTE_SIZE - 1].clone()),
203                (cols.c_msb, c[WORD_BYTE_SIZE - 1].clone()),
204            ];
205
206            for msb_pair in msb_pairs.iter() {
207                let msb = msb_pair.0;
208                let byte = msb_pair.1.clone();
209                builder.send_byte(msb_opcode, msb, byte, zero.clone(), is_real.clone());
210            }
211            (cols.b_msb, cols.c_msb)
212        };
213
214        <U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
215            builder,
216            U16MSBOperationInput::new(a_word.0[1].clone(), cols.product_msb, is_mulw.clone()),
217        );
218
219        // Calculate whether to extend b and c's sign.
220        let (b_sign_extend, c_sign_extend) = {
221            let is_b_i64 = is_mulh.clone() + is_mulhsu.clone();
222            let is_c_i64 = is_mulh.clone();
223
224            builder.assert_eq(cols.b_sign_extend, is_b_i64 * b_msb);
225            builder.assert_eq(cols.c_sign_extend, is_c_i64 * c_msb);
226            (cols.b_sign_extend, cols.c_sign_extend)
227        };
228
229        // Sign extend `b` and `c` whenever appropriate.
230        let (b, c) = {
231            let mut b_extended: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
232            let mut c_extended: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
233            for i in 0..LONG_WORD_BYTE_SIZE {
234                if i < WORD_BYTE_SIZE {
235                    b_extended[i] = b[i].clone();
236                    c_extended[i] = c[i].clone();
237                } else {
238                    b_extended[i] = b_sign_extend * byte_mask;
239                    c_extended[i] = c_sign_extend * byte_mask;
240                }
241            }
242            (b_extended, c_extended)
243        };
244
245        // Compute the uncarried product b(x) * c(x) = m(x).
246        let mut m: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
247        for i in 0..LONG_WORD_BYTE_SIZE {
248            for j in 0..LONG_WORD_BYTE_SIZE {
249                if i + j < LONG_WORD_BYTE_SIZE {
250                    m[i + j] = m[i + j].clone() + b[i].clone() * c[j].clone();
251                }
252            }
253        }
254
255        // Propagate carry.
256        let product = {
257            for i in 0..LONG_WORD_BYTE_SIZE {
258                if i == 0 {
259                    builder
260                        .when(is_real.clone())
261                        .assert_eq(cols.product[i], m[i].clone() - cols.carry[i] * base);
262                } else {
263                    builder.when(is_real.clone()).assert_eq(
264                        cols.product[i],
265                        m[i].clone() + cols.carry[i - 1] - cols.carry[i] * base,
266                    );
267                }
268            }
269            cols.product
270        };
271
272        // Compare the product's appropriate bytes with that of the result.
273        {
274            let is_lower = is_mul.clone();
275            let is_upper = is_mulh.clone() + is_mulhu.clone() + is_mulhsu.clone();
276            let is_word = is_mulw.clone();
277            let u16_max = AB::F::from_canonical_u32((1 << 16) - 1);
278            for i in 0..WORD_SIZE {
279                if i < WORD_SIZE / 2 {
280                    builder.when(is_word.clone()).assert_eq(
281                        product[2 * i] + product[2 * i + 1] * AB::F::from_canonical_u16(1 << 8),
282                        a_word[i].clone(),
283                    );
284                } else {
285                    builder
286                        .when(is_word.clone())
287                        .assert_eq(cols.product_msb.msb * u16_max, a_word[i].clone());
288                }
289                builder.when(is_lower.clone()).assert_eq(
290                    product[2 * i] + product[2 * i + 1] * AB::F::from_canonical_u16(1 << 8),
291                    a_word[i].clone(),
292                );
293                builder.when(is_upper.clone()).assert_eq(
294                    product[2 * i + WORD_BYTE_SIZE]
295                        + product[2 * i + 1 + WORD_BYTE_SIZE] * AB::F::from_canonical_u16(1 << 8),
296                    a_word[i].clone(),
297                );
298            }
299        }
300
301        // Check that the boolean values are indeed boolean values.
302        {
303            let booleans = [
304                cols.b_msb.into(),
305                cols.c_msb.into(),
306                cols.b_sign_extend.into(),
307                cols.c_sign_extend.into(),
308                is_mul.clone(),
309                is_mulh.clone(),
310                is_mulhu.clone(),
311                is_mulhsu.clone(),
312                is_mulw.clone(),
313                is_mul.clone()
314                    + is_mulh.clone()
315                    + is_mulhu.clone()
316                    + is_mulhsu.clone()
317                    + is_mulw.clone(),
318                is_real.clone(),
319            ];
320            for boolean in booleans.iter() {
321                builder.assert_bool(boolean.clone());
322            }
323        }
324
325        // If signed extended, the MSB better be 1.
326        builder.when(cols.b_sign_extend).assert_eq(cols.b_msb, one.clone());
327        builder.when(cols.c_sign_extend).assert_eq(cols.c_msb, one.clone());
328
329        // Range check.
330        {
331            // Ensure that the carry is at most 2^16. This ensures that
332            // product_before_carry_propagation - carry * base + last_carry never overflows or
333            // underflows enough to "wrap" around to create a second solution.
334            builder.slice_range_check_u16(&cols.carry, is_real.clone());
335            builder.slice_range_check_u8(&cols.product, is_real.clone());
336        }
337    }
338}
339
340#[derive(Debug, Clone, InputExpr, InputParams)]
341pub struct MulOperationInput<AB: SP1AirBuilder> {
342    pub a_word: Word<AB::Expr>,
343    pub b_word: Word<AB::Expr>,
344    pub c_word: Word<AB::Expr>,
345    pub cols: MulOperation<AB::Var>,
346    pub is_real: AB::Expr,
347    pub is_mul: AB::Expr,
348    pub is_mulh: AB::Expr,
349    pub is_mulw: AB::Expr,
350    pub is_mulhu: AB::Expr,
351    pub is_mulhsu: AB::Expr,
352}
353
354impl<AB> SP1Operation<AB> for MulOperation<AB::F>
355where
356    AB: SP1AirBuilder
357        + SP1OperationBuilder<U16toU8OperationSafe>
358        + SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>,
359{
360    type Input = MulOperationInput<AB>;
361    type Output = ();
362
363    fn lower(builder: &mut AB, input: Self::Input) -> Self::Output {
364        Self::eval(
365            builder,
366            input.a_word,
367            input.b_word,
368            input.c_word,
369            input.cols,
370            input.is_real,
371            input.is_mul,
372            input.is_mulh,
373            input.is_mulw,
374            input.is_mulhu,
375            input.is_mulhsu,
376        );
377    }
378}