mod utils;
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use sp1_derive::AlignedBorrow;
use crate::air::MachineAir;
use crate::air::{SP1AirBuilder, Word};
use crate::alu::mul::utils::get_msb;
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::disassembler::WORD_SIZE;
use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::stark::MachineRecord;
use crate::utils::pad_to_power_of_two;
/// Number of columns in one row of the multiplication trace.
pub const NUM_MUL_COLS: usize = size_of::<MulCols<u8>>();

/// Number of byte limbs in the double-width product of two words.
const PRODUCT_SIZE: usize = WORD_SIZE * 2;

/// Number of bits in one byte limb; `1 << BYTE_SIZE` is the limb base.
const BYTE_SIZE: usize = u8::BITS as usize;

/// Byte with every bit set, used to fill the upper limbs of a sign-extended operand.
const BYTE_MASK: u8 = u8::MAX;
/// Chip that builds the trace and enforces the constraints for the
/// `MUL`, `MULH`, `MULHU` and `MULHSU` opcodes.
#[derive(Default)]
pub struct MulChip;
/// Column layout of one row of the multiplication trace.
///
/// The struct is `#[repr(C)]` and is viewed directly over a row slice, so the
/// field order below is the column order of the trace.
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct MulCols<T> {
/// Shard of the event; forwarded to the byte lookups, range checks and the ALU interaction.
pub shard: T,
/// Channel of the event; forwarded alongside `shard`.
pub channel: T,
/// Row index within the trace: zero on the first row, then incremented by one per row.
pub nonce: T,
/// Result word. Constrained to the low word of `product` for `MUL` and to the
/// high word for `MULH`, `MULHU` and `MULHSU`.
pub a: Word<T>,
/// First operand word.
pub b: Word<T>,
/// Second operand word.
pub c: Word<T>,
/// `carry[i]` is the carry out of limb `i` when the raw limb sums are reduced to bytes.
pub carry: [T; PRODUCT_SIZE],
/// Byte limbs of the (possibly sign-extended) product, after carry propagation.
pub product: [T; PRODUCT_SIZE],
/// MSB flag of `b` as returned by `get_msb`; checked with a `ByteOpcode::MSB`
/// lookup against the top byte of `b`.
pub b_msb: T,
/// MSB flag of `c`, handled the same way as `b_msb`.
pub c_msb: T,
/// Set when `b` is sign-extended: the opcode is `MULH` or `MULHSU` and `b_msb` is one.
pub b_sign_extend: T,
/// Set when `c` is sign-extended: the opcode is `MULH` and `c_msb` is one.
pub c_sign_extend: T,
/// Selector: the row's opcode is `MUL`.
pub is_mul: T,
/// Selector: the row's opcode is `MULH`.
pub is_mulh: T,
/// Selector: the row's opcode is `MULHU`.
pub is_mulhu: T,
/// Selector: the row's opcode is `MULHSU`.
pub is_mulhsu: T,
/// Set to one on every row generated from an event; used as the multiplicity
/// of the lookups and of the ALU interaction.
pub is_real: T,
}
impl<F: PrimeField> MachineAir<F> for MulChip {
    type Record = ExecutionRecord;

    type Program = Program;

    /// Returns the name of this chip.
    fn name(&self) -> String {
        "Mul".to_string()
    }

    /// Builds the multiplication trace from `input.mul_events`.
    ///
    /// Each event becomes one row. The byte-lookup and range-check events that the
    /// rows depend on are collected per chunk and appended to `output`. The trace is
    /// then padded by `pad_to_power_of_two` and every row (padding included) gets its
    /// row index written into the `nonce` column.
    ///
    /// # Panics
    ///
    /// Panics if an event's opcode is not one of `MUL`, `MULH`, `MULHU`, `MULHSU`.
    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        output: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        // The events are only read, so borrow them instead of cloning the whole vector.
        let mul_events = &input.mul_events;
        let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1);

        let rows_and_records = mul_events
            .par_chunks(chunk_size)
            .map(|events| {
                let mut record = ExecutionRecord::default();
                let rows = events
                    .iter()
                    .map(|event| {
                        assert!(matches!(
                            event.opcode,
                            Opcode::MUL | Opcode::MULHU | Opcode::MULH | Opcode::MULHSU
                        ));

                        let mut row = [F::zero(); NUM_MUL_COLS];
                        let cols: &mut MulCols<F> = row.as_mut_slice().borrow_mut();

                        let a_word = event.a.to_le_bytes();
                        let b_word = event.b.to_le_bytes();
                        let c_word = event.c.to_le_bytes();

                        let b_msb = get_msb(b_word);
                        let c_msb = get_msb(c_word);
                        cols.b_msb = F::from_canonical_u8(b_msb);
                        cols.c_msb = F::from_canonical_u8(c_msb);

                        // Operands widened to the product width. The upper limbs stay zero
                        // unless the operand is signed and negative, in which case they are
                        // filled with 0xff (two's-complement sign extension).
                        let mut b = [0u8; PRODUCT_SIZE];
                        let mut c = [0u8; PRODUCT_SIZE];
                        b[..WORD_SIZE].copy_from_slice(&b_word);
                        c[..WORD_SIZE].copy_from_slice(&c_word);

                        // `b` is signed for MULH and MULHSU.
                        if (event.opcode == Opcode::MULH || event.opcode == Opcode::MULHSU)
                            && b_msb == 1
                        {
                            cols.b_sign_extend = F::one();
                            b[WORD_SIZE..].fill(BYTE_MASK);
                        }

                        // `c` is signed only for MULH.
                        if event.opcode == Opcode::MULH && c_msb == 1 {
                            cols.c_sign_extend = F::one();
                            c[WORD_SIZE..].fill(BYTE_MASK);
                        }

                        // One MSB lookup per operand, against the operand's top byte.
                        let blu_events: Vec<ByteLookupEvent> = [(b_msb, b_word), (c_msb, c_word)]
                            .into_iter()
                            .map(|(msb, word)| ByteLookupEvent {
                                shard: event.shard,
                                channel: event.channel,
                                opcode: ByteOpcode::MSB,
                                a1: u32::from(msb),
                                a2: 0,
                                b: u32::from(word[WORD_SIZE - 1]),
                                c: 0,
                            })
                            .collect();
                        record.add_byte_lookup_events(blu_events);

                        // Schoolbook multiplication, truncated to PRODUCT_SIZE limbs.
                        let mut product = [0u32; PRODUCT_SIZE];
                        for (i, &b_byte) in b.iter().enumerate() {
                            for (j, &c_byte) in c.iter().take(PRODUCT_SIZE - i).enumerate() {
                                product[i + j] += u32::from(b_byte) * u32::from(c_byte);
                            }
                        }

                        // Reduce every limb to a byte, pushing the overflow into the next
                        // limb. The carry out of the last limb is recorded but dropped.
                        let base = 1u32 << BYTE_SIZE;
                        let mut carry = [0u32; PRODUCT_SIZE];
                        for i in 0..PRODUCT_SIZE {
                            carry[i] = product[i] / base;
                            product[i] %= base;
                            if i + 1 < PRODUCT_SIZE {
                                product[i + 1] += carry[i];
                            }
                            cols.carry[i] = F::from_canonical_u32(carry[i]);
                        }

                        cols.product = product.map(F::from_canonical_u32);
                        cols.a = Word(a_word.map(F::from_canonical_u8));
                        cols.b = Word(b_word.map(F::from_canonical_u8));
                        cols.c = Word(c_word.map(F::from_canonical_u8));
                        cols.is_real = F::one();
                        cols.is_mul = F::from_bool(event.opcode == Opcode::MUL);
                        cols.is_mulh = F::from_bool(event.opcode == Opcode::MULH);
                        cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU);
                        cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU);
                        cols.shard = F::from_canonical_u32(event.shard);
                        cols.channel = F::from_canonical_u32(event.channel);

                        // Range checks matching the ones performed in `eval`.
                        record.add_u16_range_checks(event.shard, event.channel, &carry);
                        record.add_u8_range_checks(
                            event.shard,
                            event.channel,
                            // Every limb was reduced modulo 256 above, so this cast is lossless.
                            &product.map(|x| x as u8),
                        );

                        row
                    })
                    .collect::<Vec<_>>();
                (rows, record)
            })
            .collect::<Vec<_>>();

        // Merge the per-chunk results in chunk order.
        let mut rows: Vec<[F; NUM_MUL_COLS]> = vec![];
        for (chunk_rows, mut chunk_record) in rows_and_records {
            rows.extend(chunk_rows);
            output.append(&mut chunk_record);
        }

        let mut trace =
            RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MUL_COLS);
        pad_to_power_of_two::<NUM_MUL_COLS, F>(&mut trace.values);

        // Write the row index into the nonce column of every row, padding included.
        for (i, row) in trace.values.chunks_exact_mut(NUM_MUL_COLS).enumerate() {
            let cols: &mut MulCols<F> = row.borrow_mut();
            cols.nonce = F::from_canonical_usize(i);
        }

        trace
    }

    /// Returns `true` when the shard contains at least one multiplication event.
    fn included(&self, shard: &Self::Record) -> bool {
        !shard.mul_events.is_empty()
    }
}
impl<F> BaseAir<F> for MulChip {
/// Returns the number of columns in the multiplication trace.
fn width(&self) -> usize {
NUM_MUL_COLS
}
}
impl<AB> Air<AB> for MulChip
where
AB: SP1AirBuilder,
{
/// Adds the multiplication constraints for the current row (`local`) and the
/// nonce transition to the following row (`next`).
///
/// The constraints tie `product` to the limb-wise product of the (optionally
/// sign-extended) operands, tie `a` to the low or high word of `product`
/// depending on the opcode selector, and receive the ALU interaction.
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local: &MulCols<AB::Var> = (*local).borrow();
let next = main.row_slice(1);
let next: &MulCols<AB::Var> = (*next).borrow();
// Limb base (256) used when propagating carries.
let base = AB::F::from_canonical_u32(1 << 8);
let zero: AB::Expr = AB::F::zero().into();
let one: AB::Expr = AB::F::one().into();
let byte_mask = AB::F::from_canonical_u8(BYTE_MASK);
// The nonce starts at zero and increases by one on every transition.
builder.when_first_row().assert_zero(local.nonce);
builder
.when_transition()
.assert_eq(local.nonce + AB::Expr::one(), next.nonce);
// Check `b_msb` / `c_msb` against the top byte of `b` / `c` with MSB byte lookups.
let (b_msb, c_msb) = {
let msb_pairs = [
(local.b_msb, local.b[WORD_SIZE - 1]),
(local.c_msb, local.c[WORD_SIZE - 1]),
];
let opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32);
for msb_pair in msb_pairs.iter() {
let msb = msb_pair.0;
let byte = msb_pair.1;
builder.send_byte(
opcode,
msb,
byte,
zero.clone(),
local.shard,
local.channel,
local.is_real,
);
}
(local.b_msb, local.c_msb)
};
// An operand is sign-extended when it is treated as signed and its MSB flag is set.
let (b_sign_extend, c_sign_extend) = {
// `b` is signed for MULH or MULHSU; written as x + y - x*y, the OR of two booleans.
let is_b_i32 = local.is_mulh + local.is_mulhsu - local.is_mulh * local.is_mulhsu;
// `c` is signed only for MULH.
let is_c_i32 = local.is_mulh;
builder.assert_eq(local.b_sign_extend, is_b_i32 * b_msb);
builder.assert_eq(local.c_sign_extend, is_c_i32 * c_msb);
(local.b_sign_extend, local.c_sign_extend)
};
// Widen both operands to PRODUCT_SIZE limbs: the upper limbs are 0xff when
// sign-extending and zero otherwise.
let (b, c) = {
let mut b: Vec<AB::Expr> = vec![AB::F::zero().into(); PRODUCT_SIZE];
let mut c: Vec<AB::Expr> = vec![AB::F::zero().into(); PRODUCT_SIZE];
for i in 0..PRODUCT_SIZE {
if i < WORD_SIZE {
b[i] = local.b[i].into();
c[i] = local.c[i].into();
} else {
b[i] = b_sign_extend * byte_mask;
c[i] = c_sign_extend * byte_mask;
}
}
(b, c)
};
// Limb-wise (schoolbook) product before carries: m[k] is the sum of b[i] * c[j]
// over i + j = k, truncated to PRODUCT_SIZE limbs.
let mut m: Vec<AB::Expr> = vec![AB::F::zero().into(); PRODUCT_SIZE];
for i in 0..PRODUCT_SIZE {
for j in 0..PRODUCT_SIZE {
if i + j < PRODUCT_SIZE {
m[i + j] += b[i].clone() * c[j].clone();
}
}
}
// Carry propagation: product[i] = m[i] + carry[i - 1] - carry[i] * base,
// with no incoming carry for the lowest limb.
let product = {
for i in 0..PRODUCT_SIZE {
if i == 0 {
builder.assert_eq(local.product[i], m[i].clone() - local.carry[i] * base);
} else {
builder.assert_eq(
local.product[i],
m[i].clone() + local.carry[i - 1] - local.carry[i] * base,
);
}
}
local.product
};
// `a` is the low word of the product for MUL and the high word for
// MULH / MULHU / MULHSU.
{
let is_lower = local.is_mul;
let is_upper = local.is_mulh + local.is_mulhu + local.is_mulhsu;
for i in 0..WORD_SIZE {
builder.when(is_lower).assert_eq(product[i], local.a[i]);
builder
.when(is_upper.clone())
.assert_eq(product[i + WORD_SIZE], local.a[i]);
}
}
// Every flag column must be boolean.
{
let booleans = [
local.b_msb,
local.c_msb,
local.b_sign_extend,
local.c_sign_extend,
local.is_mul,
local.is_mulh,
local.is_mulhu,
local.is_mulhsu,
local.is_real,
];
for boolean in booleans.iter() {
builder.assert_bool(*boolean);
}
}
// Sign extension is only allowed when the corresponding MSB flag is one.
builder
.when(local.b_sign_extend)
.assert_eq(local.b_msb, one.clone());
builder
.when(local.c_sign_extend)
.assert_eq(local.c_msb, one.clone());
// On real rows the four selectors sum to one; the opcode sent to the ALU
// interaction is the selector-weighted sum of the opcode values.
let opcode = {
builder
.when(local.is_real)
.assert_one(local.is_mul + local.is_mulh + local.is_mulhu + local.is_mulhsu);
let mul: AB::Expr = AB::F::from_canonical_u32(Opcode::MUL as u32).into();
let mulh: AB::Expr = AB::F::from_canonical_u32(Opcode::MULH as u32).into();
let mulhu: AB::Expr = AB::F::from_canonical_u32(Opcode::MULHU as u32).into();
let mulhsu: AB::Expr = AB::F::from_canonical_u32(Opcode::MULHSU as u32).into();
local.is_mul * mul
+ local.is_mulh * mulh
+ local.is_mulhu * mulhu
+ local.is_mulhsu * mulhsu
};
// Range-check the carries (u16) and the product limbs (u8) on real rows.
{
builder.slice_range_check_u16(&local.carry, local.shard, local.channel, local.is_real);
builder.slice_range_check_u8(&local.product, local.shard, local.channel, local.is_real);
}
// Receive the ALU interaction for this row, with `is_real` as multiplicity.
builder.receive_alu(
opcode,
local.a,
local.b,
local.c,
local.shard,
local.channel,
local.nonce,
local.is_real,
);
}
}
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;

    use super::MulChip;
    use crate::{
        air::MachineAir,
        alu::AluEvent,
        runtime::{ExecutionRecord, Opcode},
        stark::StarkGenericConfig,
        utils::{uni_stark_prove as prove, uni_stark_verify as verify, BabyBearPoseidon2},
    };

    /// Generates a trace from a large batch of identical `MULHSU` events.
    #[test]
    fn generate_trace_mul() {
        let event_count = 10i32.pow(7);
        let mul_events: Vec<AluEvent> = (0..event_count)
            .map(|_| {
                AluEvent::new(
                    0,
                    0,
                    0,
                    Opcode::MULHSU,
                    0x80004000,
                    0x80000000,
                    0xffff8000,
                )
            })
            .collect();

        let mut shard = ExecutionRecord::default();
        shard.mul_events = mul_events;

        let chip = MulChip;
        let mut sink = ExecutionRecord::default();
        let _trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&shard, &mut sink);
    }

    /// Proves and verifies a trace built from a table of multiplication test vectors,
    /// topped up to 1000 events with trivial `MUL` instructions.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();

        // Each entry is (opcode, a, b, c), where `a` is the expected result.
        let mul_instructions: Vec<(Opcode, u32, u32, u32)> = vec![
            (Opcode::MUL, 0x00001200, 0x00007e00, 0xb6db6db7),
            (Opcode::MUL, 0x00001240, 0x00007fc0, 0xb6db6db7),
            (Opcode::MUL, 0x00000000, 0x00000000, 0x00000000),
            (Opcode::MUL, 0x00000001, 0x00000001, 0x00000001),
            (Opcode::MUL, 0x00000015, 0x00000003, 0x00000007),
            (Opcode::MUL, 0x00000000, 0x00000000, 0xffff8000),
            (Opcode::MUL, 0x00000000, 0x80000000, 0x00000000),
            (Opcode::MUL, 0x00000000, 0x80000000, 0xffff8000),
            (Opcode::MUL, 0x0000ff7f, 0xaaaaaaab, 0x0002fe7d),
            (Opcode::MUL, 0x0000ff7f, 0x0002fe7d, 0xaaaaaaab),
            (Opcode::MUL, 0x00000000, 0xff000000, 0xff000000),
            (Opcode::MUL, 0x00000001, 0xffffffff, 0xffffffff),
            (Opcode::MUL, 0xffffffff, 0xffffffff, 0x00000001),
            (Opcode::MUL, 0xffffffff, 0x00000001, 0xffffffff),
            (Opcode::MULHU, 0x00000000, 0x00000000, 0x00000000),
            (Opcode::MULHU, 0x00000000, 0x00000001, 0x00000001),
            (Opcode::MULHU, 0x00000000, 0x00000003, 0x00000007),
            (Opcode::MULHU, 0x00000000, 0x00000000, 0xffff8000),
            (Opcode::MULHU, 0x00000000, 0x80000000, 0x00000000),
            (Opcode::MULHU, 0x7fffc000, 0x80000000, 0xffff8000),
            (Opcode::MULHU, 0x0001fefe, 0xaaaaaaab, 0x0002fe7d),
            (Opcode::MULHU, 0x0001fefe, 0x0002fe7d, 0xaaaaaaab),
            (Opcode::MULHU, 0xfe010000, 0xff000000, 0xff000000),
            (Opcode::MULHU, 0xfffffffe, 0xffffffff, 0xffffffff),
            (Opcode::MULHU, 0x00000000, 0xffffffff, 0x00000001),
            (Opcode::MULHU, 0x00000000, 0x00000001, 0xffffffff),
            (Opcode::MULHSU, 0x00000000, 0x00000000, 0x00000000),
            (Opcode::MULHSU, 0x00000000, 0x00000001, 0x00000001),
            (Opcode::MULHSU, 0x00000000, 0x00000003, 0x00000007),
            (Opcode::MULHSU, 0x00000000, 0x00000000, 0xffff8000),
            (Opcode::MULHSU, 0x00000000, 0x80000000, 0x00000000),
            (Opcode::MULHSU, 0x80004000, 0x80000000, 0xffff8000),
            (Opcode::MULHSU, 0xffff0081, 0xaaaaaaab, 0x0002fe7d),
            (Opcode::MULHSU, 0x0001fefe, 0x0002fe7d, 0xaaaaaaab),
            (Opcode::MULHSU, 0xff010000, 0xff000000, 0xff000000),
            (Opcode::MULHSU, 0xffffffff, 0xffffffff, 0xffffffff),
            (Opcode::MULHSU, 0xffffffff, 0xffffffff, 0x00000001),
            (Opcode::MULHSU, 0x00000000, 0x00000001, 0xffffffff),
            (Opcode::MULH, 0x00000000, 0x00000000, 0x00000000),
            (Opcode::MULH, 0x00000000, 0x00000001, 0x00000001),
            (Opcode::MULH, 0x00000000, 0x00000003, 0x00000007),
            (Opcode::MULH, 0x00000000, 0x00000000, 0xffff8000),
            (Opcode::MULH, 0x00000000, 0x80000000, 0x00000000),
            (Opcode::MULH, 0x00000000, 0x80000000, 0x00000000),
            (Opcode::MULH, 0xffff0081, 0xaaaaaaab, 0x0002fe7d),
            (Opcode::MULH, 0xffff0081, 0x0002fe7d, 0xaaaaaaab),
            (Opcode::MULH, 0x00010000, 0xff000000, 0xff000000),
            (Opcode::MULH, 0x00000000, 0xffffffff, 0xffffffff),
            (Opcode::MULH, 0xffffffff, 0xffffffff, 0x00000001),
            (Opcode::MULH, 0xffffffff, 0x00000001, 0xffffffff),
        ];

        // Table entries first, then filler events up to 1000 in total.
        let filler_count = 1000 - mul_instructions.len();
        let mul_events: Vec<AluEvent> = mul_instructions
            .iter()
            .map(|&(opcode, a, b, c)| AluEvent::new(0, 0, 0, opcode, a, b, c))
            .chain((0..filler_count).map(|_| AluEvent::new(0, 0, 0, Opcode::MUL, 1, 1, 1)))
            .collect();

        let mut shard = ExecutionRecord::default();
        shard.mul_events = mul_events;

        let chip = MulChip;
        let mut sink = ExecutionRecord::default();
        let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&shard, &mut sink);
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);

        // Verification uses a fresh challenger.
        let mut challenger = config.challenger();
        verify(&config, &chip, &mut challenger, &proof).unwrap();
    }
}