sp1_core_machine/operations/
mul.rs1use std::num::Wrapping;
2
3use crate::{
4 air::{SP1Operation, SP1OperationBuilder, WordAirBuilder},
5 operations::{U16MSBOperation, U16toU8OperationSafe, U16toU8OperationSafeInput},
6};
7use sp1_core_executor::{
8 events::{ByteLookupEvent, ByteRecord},
9 ByteOpcode,
10};
11use sp1_hypercube::{air::SP1AirBuilder, Word};
12use struct_reflection::{StructReflection, StructReflectionHelper};
13
14use slop_air::AirBuilder;
15use slop_algebra::{AbstractField, Field};
16use sp1_derive::{AlignedBorrow, InputExpr, InputParams, IntoShape, SP1OperationBuilder};
17use sp1_primitives::consts::{
18 u64_to_u16_limbs, BYTE_SIZE, LONG_WORD_BYTE_SIZE, WORD_BYTE_SIZE, WORD_SIZE,
19};
20
21use super::{U16MSBOperationInput, U16toU8Operation};
22
/// A byte with every bit set; used as the padding byte when an operand is sign-extended.
const BYTE_MASK: u8 = u8::MAX;
25
/// Returns the most significant bit (`0` or `1`) of a little-endian 8-byte value, i.e. the
/// top bit of its last byte.
pub const fn get_msb(a: [u8; 8]) -> u8 {
    let top_byte = a[7];
    // Shifting a `u8` right by 7 leaves only its top bit, so no masking is needed.
    top_byte >> (u8::BITS - 1)
}
29
/// Trace columns for the byte-wise multiplication gadget.
///
/// The operands are decomposed into bytes, optionally sign-extended to
/// `LONG_WORD_BYTE_SIZE` bytes, and multiplied column by column with explicit carries.
///
/// `#[repr(C)]` fixes the in-memory field order. NOTE(review): `AlignedBorrow` presumably
/// maps trace rows onto this layout, so fields should not be reordered.
#[derive(
    AlignedBorrow, Default, Debug, Clone, Copy, IntoShape, SP1OperationBuilder, StructReflection,
)]
#[repr(C)]
pub struct MulOperation<T> {
    /// Carry out of each byte column of the product (range-checked as u16 in `eval`).
    pub carry: [T; LONG_WORD_BYTE_SIZE],

    /// Byte columns of the truncated product (range-checked as u8 in `eval`).
    pub product: [T; LONG_WORD_BYTE_SIZE],

    /// Decomposition of the `b` operand's u16 limbs into bytes.
    pub b_lower_byte: U16toU8Operation<T>,

    /// Decomposition of the `c` operand's u16 limbs into bytes.
    pub c_lower_byte: U16toU8Operation<T>,

    /// Most significant bit of `b`.
    pub b_msb: T,

    /// Most significant bit of `c`.
    pub c_msb: T,

    /// MSB of limb 1 of the result; drives the sign extension of MULW results.
    pub product_msb: U16MSBOperation<T>,

    /// `1` when `b` is sign-extended to `LONG_WORD_BYTE_SIZE` bytes, else `0`.
    pub b_sign_extend: T,

    /// `1` when `c` is sign-extended to `LONG_WORD_BYTE_SIZE` bytes, else `0`.
    pub c_sign_extend: T,
}
63
impl<F: Field> MulOperation<F> {
    /// Populates this gadget's trace columns for one multiplication and records the
    /// byte-lookup and range-check events that `eval` expects.
    ///
    /// - `b_u64`, `c_u64`: the two operands.
    /// - `is_mulh`: sign-extend `b` and `c` when their sign bits are set.
    /// - `is_mulhsu`: sign-extend only `b` when its sign bit is set.
    /// - `is_mulw`: populate `product_msb` from the sign-extended 32-bit product.
    ///
    /// NOTE(review): the MSB byte lookups are recorded unconditionally, while `eval` sends
    /// them with multiplicity `is_real`, so this is presumably only called for real rows.
    pub fn populate(
        &mut self,
        record: &mut impl ByteRecord,
        b_u64: u64,
        c_u64: u64,
        is_mulh: bool,
        is_mulhsu: bool,
        is_mulw: bool,
    ) {
        // Little-endian byte decompositions of the operands.
        let b_word = b_u64.to_le_bytes();
        let c_word = c_u64.to_le_bytes();

        // Low 32 bits of the product, sign-extended to 64 bits (the MULW result).
        let mulw_value = (Wrapping(b_u64 as i32) * Wrapping(c_u64 as i32)).0 as i64 as u64;
        let limbs = u64_to_u16_limbs(mulw_value);

        if is_mulw {
            // Assumes little-endian limbs (TODO confirm), so limb 1 holds bits 16..32 and
            // its MSB is the sign bit of the 32-bit result.
            self.product_msb.populate_msb(record, limbs[1]);
        } else {
            self.product_msb.msb = F::zero();
        }

        // Operand byte vectors. They are widened to `LONG_WORD_BYTE_SIZE` with `BYTE_MASK`
        // below when sign-extended; otherwise the missing upper bytes contribute nothing.
        let mut b = b_word.to_vec();
        let mut c = c_word.to_vec();

        // Columns (and events) for splitting each operand's u16 limbs into bytes.
        self.b_lower_byte.populate_u16_to_u8_safe(record, b_u64);
        self.c_lower_byte.populate_u16_to_u8_safe(record, c_u64);

        {
            // Sign bits of the 64-bit operands.
            let b_msb = get_msb(b_word);
            self.b_msb = F::from_canonical_u8(b_msb);
            let c_msb = get_msb(c_word);
            self.c_msb = F::from_canonical_u8(c_msb);

            // `b` is sign-extended for MULH and MULHSU when negative.
            if (is_mulh || is_mulhsu) && b_msb == 1 {
                self.b_sign_extend = F::one();
                b.resize(LONG_WORD_BYTE_SIZE, BYTE_MASK);
            } else {
                self.b_sign_extend = F::zero();
            }

            // `c` is sign-extended only for MULH when negative.
            if is_mulh && c_msb == 1 {
                self.c_sign_extend = F::one();
                c.resize(LONG_WORD_BYTE_SIZE, BYTE_MASK);
            } else {
                self.c_sign_extend = F::zero();
            }

            {
                // One MSB lookup per operand, (msb, most significant byte), mirroring the
                // `send_byte(MSB, ...)` calls in `eval`.
                let words = [b_word, c_word];
                let mut blu_events: Vec<ByteLookupEvent> = vec![];
                for word in words.iter() {
                    let most_significant_byte = word[WORD_BYTE_SIZE - 1];
                    blu_events.push(ByteLookupEvent {
                        opcode: ByteOpcode::MSB,
                        a: get_msb(*word) as u16,
                        b: most_significant_byte,
                        c: 0,
                    });
                }
                record.add_byte_lookup_events(blu_events);
            }
        }

        // Uncarried schoolbook product, keeping only the lowest LONG_WORD_BYTE_SIZE columns.
        let mut product = [0u32; LONG_WORD_BYTE_SIZE];
        for i in 0..b.len() {
            for j in 0..c.len() {
                if i + j < LONG_WORD_BYTE_SIZE {
                    product[i + j] += (b[i] as u32) * (c[j] as u32);
                }
            }
        }

        // Carry propagation: reduce each column modulo `base` and push the quotient into the
        // next column. The last carry is stored in `self.carry` but not propagated further.
        let base = (1 << BYTE_SIZE) as u32;
        let mut carry = [0u32; LONG_WORD_BYTE_SIZE];
        for i in 0..LONG_WORD_BYTE_SIZE {
            carry[i] = product[i] / base;
            product[i] %= base;
            if i + 1 < LONG_WORD_BYTE_SIZE {
                product[i + 1] += carry[i];
            }
            self.carry[i] = F::from_canonical_u32(carry[i]);
        }

        self.product = product.map(F::from_canonical_u32);

        {
            // Events backing `slice_range_check_u16` / `slice_range_check_u8` in `eval`.
            record.add_u16_range_checks(&carry.map(|x| x as u16));
            record.add_u8_range_checks(&product.map(|x| x as u8));
        }
    }

    /// Evaluates the multiplication gadget's constraints.
    ///
    /// `b_word` and `c_word` are decomposed into bytes and optionally sign-extended to
    /// `LONG_WORD_BYTE_SIZE` bytes. They are multiplied column-wise with carries in
    /// `cols.carry`, and the resulting `cols.product` bytes are matched against `a_word`
    /// according to the opcode flag:
    /// - `is_mul`: `a_word` equals the product bytes starting at index 0.
    /// - `is_mulh` / `is_mulhu` / `is_mulhsu`: `a_word` equals the product bytes starting at
    ///   `WORD_BYTE_SIZE`.
    /// - `is_mulw`: limbs `i < WORD_SIZE / 2` of `a_word` equal the low product bytes, and
    ///   the remaining limbs equal `0xffff * msb`, where `msb` is the MSB of `a_word[1]`.
    ///
    /// NOTE(review): the opcode flags are only constrained here to be boolean with a boolean
    /// sum (at most one set); tying them to `is_real` is presumably the caller's job.
    // One parameter per field of `MulOperationInput`.
    #[allow(clippy::too_many_arguments)]
    pub fn eval<
        AB: SP1AirBuilder
            + SP1OperationBuilder<U16toU8OperationSafe>
            + SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>,
    >(
        builder: &mut AB,
        a_word: Word<AB::Expr>,
        b_word: Word<AB::Expr>,
        c_word: Word<AB::Expr>,
        cols: MulOperation<AB::Var>,
        is_real: AB::Expr,
        is_mul: AB::Expr,
        is_mulh: AB::Expr,
        is_mulw: AB::Expr,
        is_mulhu: AB::Expr,
        is_mulhsu: AB::Expr,
    ) {
        let zero: AB::Expr = AB::F::zero().into();
        let base = AB::F::from_canonical_u32(1 << 8);
        let one: AB::Expr = AB::F::one().into();
        let byte_mask = AB::F::from_canonical_u8(BYTE_MASK);

        // Byte decompositions of `b` and `c`, constrained by `U16toU8OperationSafe`.
        let b_input = U16toU8OperationSafeInput::new(b_word.0, cols.b_lower_byte, is_real.clone());
        let b = U16toU8OperationSafe::eval(builder, b_input);
        let c_input = U16toU8OperationSafeInput::new(c_word.0, cols.c_lower_byte, is_real.clone());
        let c = U16toU8OperationSafe::eval(builder, c_input);

        // Tie `b_msb` / `c_msb` to each operand's most significant byte via MSB byte lookups.
        let msb_opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32);
        let (b_msb, c_msb) = {
            let msb_pairs = [
                (cols.b_msb, b[WORD_BYTE_SIZE - 1].clone()),
                (cols.c_msb, c[WORD_BYTE_SIZE - 1].clone()),
            ];

            for msb_pair in msb_pairs.iter() {
                let msb = msb_pair.0;
                let byte = msb_pair.1.clone();
                builder.send_byte(msb_opcode, msb, byte, zero.clone(), is_real.clone());
            }
            (cols.b_msb, cols.c_msb)
        };

        // When `is_mulw`, constrain `product_msb` to the MSB of limb 1 of `a_word`.
        <U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
            builder,
            U16MSBOperationInput::new(a_word.0[1].clone(), cols.product_msb, is_mulw.clone()),
        );

        // Sign-extension flags: `b` is extended for MULH/MULHSU and `c` only for MULH, in
        // each case only when the operand's sign bit is set.
        let (b_sign_extend, c_sign_extend) = {
            let is_b_i64 = is_mulh.clone() + is_mulhsu.clone();
            let is_c_i64 = is_mulh.clone();

            builder.assert_eq(cols.b_sign_extend, is_b_i64 * b_msb);
            builder.assert_eq(cols.c_sign_extend, is_c_i64 * c_msb);
            (cols.b_sign_extend, cols.c_sign_extend)
        };

        // Widen both operands to LONG_WORD_BYTE_SIZE bytes. The upper bytes are `byte_mask`
        // when sign-extending and zero otherwise.
        let (b, c) = {
            let mut b_extended: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
            let mut c_extended: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
            for i in 0..LONG_WORD_BYTE_SIZE {
                if i < WORD_BYTE_SIZE {
                    b_extended[i] = b[i].clone();
                    c_extended[i] = c[i].clone();
                } else {
                    b_extended[i] = b_sign_extend * byte_mask;
                    c_extended[i] = c_sign_extend * byte_mask;
                }
            }
            (b_extended, c_extended)
        };

        // m[k] = sum over i + j = k of b[i] * c[j]: the uncarried product, truncated to
        // LONG_WORD_BYTE_SIZE columns (mirrors the loop in `populate`).
        let mut m: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
        for i in 0..LONG_WORD_BYTE_SIZE {
            for j in 0..LONG_WORD_BYTE_SIZE {
                if i + j < LONG_WORD_BYTE_SIZE {
                    m[i + j] = m[i + j].clone() + b[i].clone() * c[j].clone();
                }
            }
        }

        // Carry propagation: product[i] = m[i] + carry[i - 1] - carry[i] * base, with no
        // incoming carry for column 0.
        let product = {
            for i in 0..LONG_WORD_BYTE_SIZE {
                if i == 0 {
                    builder
                        .when(is_real.clone())
                        .assert_eq(cols.product[i], m[i].clone() - cols.carry[i] * base);
                } else {
                    builder.when(is_real.clone()).assert_eq(
                        cols.product[i],
                        m[i].clone() + cols.carry[i - 1] - cols.carry[i] * base,
                    );
                }
            }
            cols.product
        };

        // Match the result limbs of `a_word` to the product bytes selected by the opcode.
        {
            let is_lower = is_mul.clone();
            let is_upper = is_mulh.clone() + is_mulhu.clone() + is_mulhsu.clone();
            let is_word = is_mulw.clone();
            let u16_max = AB::F::from_canonical_u32((1 << 16) - 1);
            for i in 0..WORD_SIZE {
                if i < WORD_SIZE / 2 {
                    builder.when(is_word.clone()).assert_eq(
                        product[2 * i] + product[2 * i + 1] * AB::F::from_canonical_u16(1 << 8),
                        a_word[i].clone(),
                    );
                } else {
                    // MULW sign extension: upper limbs must equal `0xffff * msb`.
                    builder
                        .when(is_word.clone())
                        .assert_eq(cols.product_msb.msb * u16_max, a_word[i].clone());
                }
                builder.when(is_lower.clone()).assert_eq(
                    product[2 * i] + product[2 * i + 1] * AB::F::from_canonical_u16(1 << 8),
                    a_word[i].clone(),
                );
                builder.when(is_upper.clone()).assert_eq(
                    product[2 * i + WORD_BYTE_SIZE]
                        + product[2 * i + 1 + WORD_BYTE_SIZE] * AB::F::from_canonical_u16(1 << 8),
                    a_word[i].clone(),
                );
            }
        }

        // Boolean constraints. Requiring the sum of the opcode flags to be boolean as well
        // means at most one opcode flag can be set.
        {
            let booleans = [
                cols.b_msb.into(),
                cols.c_msb.into(),
                cols.b_sign_extend.into(),
                cols.c_sign_extend.into(),
                is_mul.clone(),
                is_mulh.clone(),
                is_mulhu.clone(),
                is_mulhsu.clone(),
                is_mulw.clone(),
                is_mul.clone()
                    + is_mulh.clone()
                    + is_mulhu.clone()
                    + is_mulhsu.clone()
                    + is_mulw.clone(),
                is_real.clone(),
            ];
            for boolean in booleans.iter() {
                builder.assert_bool(boolean.clone());
            }
        }

        // A set sign-extension flag requires the corresponding sign bit to be 1.
        builder.when(cols.b_sign_extend).assert_eq(cols.b_msb, one.clone());
        builder.when(cols.c_sign_extend).assert_eq(cols.c_msb, one.clone());

        // Range checks: carries fit in u16 and product columns fit in a byte.
        {
            builder.slice_range_check_u16(&cols.carry, is_real.clone());
            builder.slice_range_check_u8(&cols.product, is_real.clone());
        }
    }
}
339
/// Arguments of [`MulOperation::eval`], bundled so the gadget can be invoked through
/// [`SP1Operation::lower`].
///
/// NOTE(review): the `InputParams` derive presumably generates a constructor whose parameter
/// order follows the field order below, so fields should not be reordered.
#[derive(Debug, Clone, InputExpr, InputParams)]
pub struct MulOperationInput<AB: SP1AirBuilder> {
    /// Result word `a` of the multiplication.
    pub a_word: Word<AB::Expr>,
    /// First operand `b`.
    pub b_word: Word<AB::Expr>,
    /// Second operand `c`.
    pub c_word: Word<AB::Expr>,
    /// The gadget's trace columns.
    pub cols: MulOperation<AB::Var>,
    /// Row selector gating lookups, range checks and product constraints.
    pub is_real: AB::Expr,
    /// Opcode flag: `a` is the low part of the product.
    pub is_mul: AB::Expr,
    /// Opcode flag: high part with both operands sign-extended.
    pub is_mulh: AB::Expr,
    /// Opcode flag: sign-extended 32-bit product.
    pub is_mulw: AB::Expr,
    /// Opcode flag: high part with neither operand sign-extended.
    pub is_mulhu: AB::Expr,
    /// Opcode flag: high part with only `b` sign-extended.
    pub is_mulhsu: AB::Expr,
}
353
354impl<AB> SP1Operation<AB> for MulOperation<AB::F>
355where
356 AB: SP1AirBuilder
357 + SP1OperationBuilder<U16toU8OperationSafe>
358 + SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>,
359{
360 type Input = MulOperationInput<AB>;
361 type Output = ();
362
363 fn lower(builder: &mut AB, input: Self::Input) -> Self::Output {
364 Self::eval(
365 builder,
366 input.a_word,
367 input.b_word,
368 input.c_word,
369 input.cols,
370 input.is_real,
371 input.is_mul,
372 input.is_mulh,
373 input.is_mulw,
374 input.is_mulhu,
375 input.is_mulhsu,
376 );
377 }
378}