mod utils;
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use hashbrown::HashMap;
use itertools::Itertools;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice};
use sp1_derive::AlignedBorrow;
use crate::air::MachineAir;
use crate::air::{SP1AirBuilder, Word};
use crate::alu::sr::utils::{nb_bits_to_shift, nb_bytes_to_shift};
use crate::bytes::event::ByteRecord;
use crate::bytes::utils::shr_carry;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::disassembler::WORD_SIZE;
use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::utils::pad_to_power_of_two;
use super::AluEvent;
/// The number of field-element columns in a `ShiftRightCols` row.
pub const NUM_SHIFT_RIGHT_COLS: usize = size_of::<ShiftRightCols<u8>>();

/// Width in bytes of the double-word used for intermediate shift results: the value is
/// sign-extended to 2 * WORD_SIZE bytes so SRA fill bytes have somewhere to live.
const LONG_WORD_SIZE: usize = 2 * WORD_SIZE;

/// The number of bits in a byte.
const BYTE_SIZE: usize = 8;
/// A chip that implements the right-shift ALU operations (`SRL` and `SRA`).
#[derive(Default)]
pub struct ShiftRightChip;
/// The column layout for one row of the shift-right chip's trace.
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct ShiftRightCols<T> {
    /// The shard number of the event.
    pub shard: T,

    /// The byte-lookup channel of the event.
    pub channel: T,

    /// The row index, constrained to increment by one per row.
    pub nonce: T,

    /// The output operand (the shifted result).
    pub a: Word<T>,

    /// The value being shifted.
    pub b: Word<T>,

    /// The shift amount; only the low 5 bits of `c[0]` determine the shift.
    pub c: Word<T>,

    /// One-hot selector of the within-byte bit-shift amount (0..BYTE_SIZE).
    pub shift_by_n_bits: [T; BYTE_SIZE],

    /// One-hot selector of the whole-byte shift amount (0..WORD_SIZE).
    pub shift_by_n_bytes: [T; WORD_SIZE],

    /// The sign-extended `b` shifted right by the selected number of whole bytes.
    pub byte_shift_result: [T; LONG_WORD_SIZE],

    /// `byte_shift_result` further shifted right by the selected number of bits.
    pub bit_shift_result: [T; LONG_WORD_SIZE],

    /// Per-byte carry output of the `ShrCarry` byte lookup.
    pub shr_carry_output_carry: [T; LONG_WORD_SIZE],

    /// Per-byte shifted-byte output of the `ShrCarry` byte lookup.
    pub shr_carry_output_shifted_byte: [T; LONG_WORD_SIZE],

    /// The most significant bit of `b` (its sign bit), proved via an MSB byte lookup.
    pub b_msb: T,

    /// Bit decomposition of the least significant byte of `c`.
    pub c_least_sig_byte: [T; BYTE_SIZE],

    /// Flag: this row is an SRL operation.
    pub is_srl: T,

    /// Flag: this row is an SRA operation.
    pub is_sra: T,

    /// Flag: this row corresponds to a real (non-padding) event.
    pub is_real: T,
}
impl<F: PrimeField> MachineAir<F> for ShiftRightChip {
    type Record = ExecutionRecord;

    type Program = Program;

    fn name(&self) -> String {
        "ShiftRight".to_string()
    }

    /// Generates the trace matrix for all shift-right events in `input`.
    ///
    /// Byte lookup events produced while populating rows are discarded here; they are
    /// regenerated and recorded in `generate_dependencies`.
    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        // Generate one row per event. Iterate by reference: cloning the whole event
        // vector (as the previous version did) was a needless allocation + copy.
        let mut rows: Vec<[F; NUM_SHIFT_RIGHT_COLS]> = Vec::new();
        for event in input.shift_right_events.iter() {
            assert!(event.opcode == Opcode::SRL || event.opcode == Opcode::SRA);
            let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS];
            let cols: &mut ShiftRightCols<F> = row.as_mut_slice().borrow_mut();
            // Byte lookups are thrown away here; see `generate_dependencies`.
            let mut blu = Vec::new();
            self.event_to_row(event, cols, &mut blu);
            rows.push(row);
        }

        // Convert the trace to a row-major matrix.
        let mut trace = RowMajorMatrix::new(
            rows.into_iter().flatten().collect::<Vec<_>>(),
            NUM_SHIFT_RIGHT_COLS,
        );

        // Pad the trace to a power-of-two height.
        pad_to_power_of_two::<NUM_SHIFT_RIGHT_COLS, F>(&mut trace.values);

        // Padding rows must still satisfy the one-hot selector constraints: exactly one
        // `shift_by_n_bits` flag and one `shift_by_n_bytes` flag must be set, so the
        // template marks a "shift by 0 bits / 0 bytes".
        let padded_row_template = {
            let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS];
            let cols: &mut ShiftRightCols<F> = row.as_mut_slice().borrow_mut();
            cols.shift_by_n_bits[0] = F::one();
            cols.shift_by_n_bytes[0] = F::one();
            row
        };
        for i in input.shift_right_events.len() * NUM_SHIFT_RIGHT_COLS..trace.values.len() {
            trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS];
        }

        // Write the nonce (row index) into every row, including padding rows, so the
        // AIR's incrementing-nonce constraint holds.
        for i in 0..trace.height() {
            let cols: &mut ShiftRightCols<F> =
                trace.values[i * NUM_SHIFT_RIGHT_COLS..(i + 1) * NUM_SHIFT_RIGHT_COLS].borrow_mut();
            cols.nonce = F::from_canonical_usize(i);
        }

        trace
    }

    /// Regenerates the byte lookup events for all shift-right events, sharded and
    /// batched in parallel, and records them into `output`.
    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
        let chunk_size = std::cmp::max(input.shift_right_events.len() / num_cpus::get(), 1);

        let blu_batches = input
            .shift_right_events
            .par_chunks(chunk_size)
            .map(|events| {
                // Map of shard -> (byte lookup event -> multiplicity).
                let mut blu: HashMap<u32, HashMap<ByteLookupEvent, usize>> = HashMap::new();
                events.iter().for_each(|event| {
                    // The row itself is scratch space; only the lookups are kept.
                    let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS];
                    let cols: &mut ShiftRightCols<F> = row.as_mut_slice().borrow_mut();
                    self.event_to_row(event, cols, &mut blu);
                });
                blu
            })
            .collect::<Vec<_>>();

        output.add_sharded_byte_lookup_events(blu_batches.iter().collect_vec());
    }

    /// The chip is included in a shard iff it has at least one shift-right event.
    fn included(&self, shard: &Self::Record) -> bool {
        !shard.shift_right_events.is_empty()
    }
}
impl ShiftRightChip {
    /// Populates a trace row (`cols`) from a single SRL/SRA `AluEvent`, recording the
    /// byte lookup events (MSB, ShrCarry, u8 range checks) that the AIR constraints
    /// will consume.
    fn event_to_row<F: PrimeField>(
        &self,
        event: &AluEvent,
        cols: &mut ShiftRightCols<F>,
        blu: &mut impl ByteRecord,
    ) {
        // Initialize the basic operand columns and flags.
        {
            cols.shard = F::from_canonical_u32(event.shard);
            cols.channel = F::from_canonical_u8(event.channel);
            cols.a = Word::from(event.a);
            cols.b = Word::from(event.b);
            cols.c = Word::from(event.c);
            // Sign bit of b; selects 0xff vs 0x00 fill bytes for SRA in the AIR.
            cols.b_msb = F::from_canonical_u32((event.b >> 31) & 1);
            cols.is_srl = F::from_bool(event.opcode == Opcode::SRL);
            cols.is_sra = F::from_bool(event.opcode == Opcode::SRA);
            cols.is_real = F::one();
            // Bit-decompose the least significant byte of c (the shift amount source).
            for i in 0..BYTE_SIZE {
                cols.c_least_sig_byte[i] = F::from_canonical_u32((event.c >> i) & 1);
            }
            // Record the MSB lookup proving b_msb is indeed b's top bit.
            let most_significant_byte = event.b.to_le_bytes()[WORD_SIZE - 1];
            blu.add_byte_lookup_events(vec![ByteLookupEvent {
                shard: event.shard,
                channel: event.channel,
                opcode: ByteOpcode::MSB,
                a1: ((most_significant_byte >> 7) & 1) as u16,
                a2: 0,
                b: most_significant_byte,
                c: 0,
            }]);
        }
        // Split the shift amount into whole bytes and remaining bits.
        let num_bytes_to_shift = nb_bytes_to_shift(event.c);
        let num_bits_to_shift = nb_bits_to_shift(event.c);
        // Byte-shift phase: shift the sign-extended value right by whole bytes.
        let mut byte_shift_result = [0u8; LONG_WORD_SIZE];
        {
            // One-hot selector for the byte-shift amount.
            for i in 0..WORD_SIZE {
                cols.shift_by_n_bytes[i] = F::from_bool(num_bytes_to_shift == i);
            }
            // Sign-extend b to 64 bits: arithmetic extension for SRA, zero for SRL.
            let sign_extended_b = {
                if event.opcode == Opcode::SRA {
                    ((event.b as i32) as i64).to_le_bytes()
                } else {
                    (event.b as u64).to_le_bytes()
                }
            };
            // Drop the low `num_bytes_to_shift` bytes; positions past the end stay 0.
            for i in 0..LONG_WORD_SIZE {
                if i + num_bytes_to_shift < LONG_WORD_SIZE {
                    byte_shift_result[i] = sign_extended_b[i + num_bytes_to_shift];
                }
            }
            cols.byte_shift_result = byte_shift_result.map(F::from_canonical_u8);
        }
        // Bit-shift phase: shift each byte right by the remaining bits, propagating a
        // carry down from the next-higher byte.
        {
            // One-hot selector for the bit-shift amount.
            for i in 0..BYTE_SIZE {
                cols.shift_by_n_bits[i] = F::from_bool(num_bits_to_shift == i);
            }
            // A carry from byte i+1 contributes carry * 2^(8 - n) to byte i.
            let carry_multiplier = 1 << (8 - num_bits_to_shift);
            let mut last_carry = 0u32;
            let mut bit_shift_result = [0u8; LONG_WORD_SIZE];
            let mut shr_carry_output_carry = [0u8; LONG_WORD_SIZE];
            let mut shr_carry_output_shifted_byte = [0u8; LONG_WORD_SIZE];
            // Iterate from the most significant byte downward so each byte sees the
            // carry produced by the byte above it.
            for i in (0..LONG_WORD_SIZE).rev() {
                let (shift, carry) = shr_carry(byte_shift_result[i], num_bits_to_shift as u8);
                // Record the ShrCarry lookup for this byte; must match the AIR's
                // `send_byte_pair` arguments exactly.
                let byte_event = ByteLookupEvent {
                    shard: event.shard,
                    channel: event.channel,
                    opcode: ByteOpcode::ShrCarry,
                    a1: shift as u16,
                    a2: carry,
                    b: byte_shift_result[i],
                    c: num_bits_to_shift as u8,
                };
                blu.add_byte_lookup_event(byte_event);
                shr_carry_output_carry[i] = carry;
                shr_carry_output_shifted_byte[i] = shift;
                bit_shift_result[i] = ((shift as u32 + last_carry * carry_multiplier) & 0xff) as u8;
                last_carry = carry as u32;
            }
            cols.bit_shift_result = bit_shift_result.map(F::from_canonical_u8);
            cols.shr_carry_output_carry = shr_carry_output_carry.map(F::from_canonical_u8);
            cols.shr_carry_output_shifted_byte =
                shr_carry_output_shifted_byte.map(F::from_canonical_u8);
            // Sanity check: the low word of the bit-shift result must equal the output.
            for i in 0..WORD_SIZE {
                debug_assert_eq!(cols.a[i], cols.bit_shift_result[i].clone());
            }
            // Range check all intermediate byte arrays as valid u8s.
            blu.add_u8_range_checks(event.shard, event.channel, &byte_shift_result);
            blu.add_u8_range_checks(event.shard, event.channel, &bit_shift_result);
            blu.add_u8_range_checks(event.shard, event.channel, &shr_carry_output_carry);
            blu.add_u8_range_checks(event.shard, event.channel, &shr_carry_output_shifted_byte);
        }
    }
}
impl<F> BaseAir<F> for ShiftRightChip {
    /// The trace width: one column per field element of `ShiftRightCols`.
    fn width(&self) -> usize {
        NUM_SHIFT_RIGHT_COLS
    }
}
impl<AB> Air<AB> for ShiftRightChip
where
    AB: SP1AirBuilder,
{
    /// Evaluates the shift-right constraints: nonce ordering, shift-amount
    /// decomposition, byte-shift and bit-shift correctness, booleanity, u8 range
    /// checks, and the ALU bus receive.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &ShiftRightCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next: &ShiftRightCols<AB::Var> = (*next).borrow();
        let zero: AB::Expr = AB::F::zero().into();
        let one: AB::Expr = AB::F::one().into();

        // The nonce starts at 0 and increments by 1 every row.
        builder.when_first_row().assert_zero(local.nonce);
        builder
            .when_transition()
            .assert_eq(local.nonce + AB::Expr::one(), next.nonce);

        // Check that b_msb is the most significant bit of b via an MSB byte lookup.
        {
            let byte = local.b[WORD_SIZE - 1];
            let opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32);
            let msb = local.b_msb;
            builder.send_byte(
                opcode,
                msb,
                byte,
                zero.clone(),
                local.shard,
                local.channel,
                local.is_real,
            );
        }

        // Derive the shift amounts from c's bit decomposition and tie them to the
        // one-hot selector columns.
        {
            // The bit decomposition must recompose to c's least significant byte.
            let mut c_byte_sum = AB::Expr::zero();
            for i in 0..BYTE_SIZE {
                let val: AB::Expr = AB::F::from_canonical_u32(1 << i).into();
                c_byte_sum += val * local.c_least_sig_byte[i];
            }
            builder.assert_eq(c_byte_sum, local.c[0]);

            // Bits 0..2 of c give the within-byte bit-shift amount.
            let mut num_bits_to_shift = AB::Expr::zero();
            for i in 0..3 {
                num_bits_to_shift += local.c_least_sig_byte[i] * AB::F::from_canonical_u32(1 << i);
            }
            // Any set shift_by_n_bits[i] flag must agree with that amount...
            for i in 0..BYTE_SIZE {
                builder
                    .when(local.shift_by_n_bits[i])
                    .assert_eq(num_bits_to_shift.clone(), AB::F::from_canonical_usize(i));
            }
            // ...and exactly one flag is set.
            builder.assert_eq(
                local
                    .shift_by_n_bits
                    .iter()
                    .fold(zero.clone(), |acc, &x| acc + x),
                one.clone(),
            );

            // Bits 3..4 of c give the whole-byte shift amount; same one-hot checks.
            let num_bytes_to_shift = local.c_least_sig_byte[3]
                + local.c_least_sig_byte[4] * AB::F::from_canonical_u32(2);
            for i in 0..WORD_SIZE {
                builder
                    .when(local.shift_by_n_bytes[i])
                    .assert_eq(num_bytes_to_shift.clone(), AB::F::from_canonical_usize(i));
            }
            builder.assert_eq(
                local
                    .shift_by_n_bytes
                    .iter()
                    .fold(zero.clone(), |acc, &x| acc + x),
                one.clone(),
            );
        }

        // Byte-shift: byte_shift_result must equal the sign-extended b shifted right
        // by the selected number of whole bytes.
        {
            // SRA on a negative b fills high bytes with 0xff; otherwise 0.
            let leading_byte = local.is_sra * local.b_msb * AB::Expr::from_canonical_u8(0xff);
            let mut sign_extended_b: Vec<AB::Expr> = vec![];
            for i in 0..WORD_SIZE {
                sign_extended_b.push(local.b[i].into());
            }
            for _ in 0..WORD_SIZE {
                sign_extended_b.push(leading_byte.clone());
            }

            // Under the selected byte-shift flag, each result byte equals the
            // correspondingly offset byte of the sign-extended value.
            for num_bytes_to_shift in 0..WORD_SIZE {
                for i in 0..(LONG_WORD_SIZE - num_bytes_to_shift) {
                    builder
                        .when(local.shift_by_n_bytes[num_bytes_to_shift])
                        .assert_eq(
                            local.byte_shift_result[i],
                            sign_extended_b[i + num_bytes_to_shift].clone(),
                        );
                }
            }
        }

        // Bit-shift: verify each byte's ShrCarry lookup, then recombine shifted bytes
        // with the carry from the next-higher byte.
        {
            // carry_multiplier = 2^(8 - n), selected by the one-hot bit-shift flags.
            let mut carry_multiplier = AB::Expr::from_canonical_u8(0);
            for i in 0..BYTE_SIZE {
                carry_multiplier +=
                    AB::Expr::from_canonical_u32(1u32 << (8 - i)) * local.shift_by_n_bits[i];
            }

            // Recompute the bit-shift amount (bits 0..2 of c) for the lookup argument.
            let mut num_bits_to_shift = AB::Expr::zero();
            for i in 0..3 {
                num_bits_to_shift += local.c_least_sig_byte[i] * AB::F::from_canonical_u32(1 << i);
            }

            // Each (shifted_byte, carry) pair must come from a ShrCarry byte lookup;
            // arguments must mirror the events recorded in `event_to_row`.
            for i in (0..LONG_WORD_SIZE).rev() {
                builder.send_byte_pair(
                    AB::F::from_canonical_u32(ByteOpcode::ShrCarry as u32),
                    local.shr_carry_output_shifted_byte[i],
                    local.shr_carry_output_carry[i],
                    local.byte_shift_result[i],
                    num_bits_to_shift.clone(),
                    local.shard,
                    local.channel,
                    local.is_real,
                );
            }

            // bit_shift_result[i] = shifted_byte[i] + carry[i + 1] * 2^(8 - n);
            // the top byte receives no carry.
            for i in (0..LONG_WORD_SIZE).rev() {
                let mut v: AB::Expr = local.shr_carry_output_shifted_byte[i].into();
                if i + 1 < LONG_WORD_SIZE {
                    v += local.shr_carry_output_carry[i + 1] * carry_multiplier.clone();
                }
                builder.assert_eq(v, local.bit_shift_result[i]);
            }
        }

        // The low WORD_SIZE bytes of the bit-shift result are the output a.
        {
            for i in 0..WORD_SIZE {
                builder.assert_eq(local.a[i], local.bit_shift_result[i]);
            }
        }

        // Booleanity of all flag and bit columns.
        {
            let flags = [local.is_srl, local.is_sra, local.is_real, local.b_msb];
            for flag in flags.iter() {
                builder.assert_bool(*flag);
            }
            for shift_by_n_byte in local.shift_by_n_bytes.iter() {
                builder.assert_bool(*shift_by_n_byte);
            }
            for shift_by_n_bit in local.shift_by_n_bits.iter() {
                builder.assert_bool(*shift_by_n_bit);
            }
            for bit in local.c_least_sig_byte.iter() {
                builder.assert_bool(*bit);
            }
        }

        // Range check every intermediate byte array to be valid u8s.
        {
            let long_words = [
                local.byte_shift_result,
                local.bit_shift_result,
                local.shr_carry_output_carry,
                local.shr_carry_output_shifted_byte,
            ];
            for long_word in long_words.iter() {
                builder.slice_range_check_u8(long_word, local.shard, local.channel, local.is_real);
            }
        }

        // Exactly one of SRL/SRA is set on real rows; padding rows have neither.
        builder.assert_bool(local.is_srl);
        builder.assert_bool(local.is_sra);
        builder.assert_bool(local.is_real);
        builder.assert_eq(local.is_srl + local.is_sra, local.is_real);

        // Receive the ALU event this row proves from the ALU bus.
        builder.receive_alu(
            local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32)
                + local.is_sra * AB::F::from_canonical_u32(Opcode::SRA as u32),
            local.a,
            local.b,
            local.c,
            local.shard,
            local.channel,
            local.nonce,
            local.is_real,
        );
    }
}
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;

    use super::ShiftRightChip;
    use crate::{
        air::MachineAir,
        alu::AluEvent,
        runtime::{ExecutionRecord, Opcode},
        stark::StarkGenericConfig,
        utils::{uni_stark_prove as prove, uni_stark_verify as verify, BabyBearPoseidon2},
    };

    /// Smoke test: a trace can be generated from a single SRL event.
    #[test]
    fn generate_trace() {
        let mut shard = ExecutionRecord::default();
        shard.shift_right_events = vec![AluEvent::new(0, 0, 0, Opcode::SRL, 6, 12, 1)];
        let trace: RowMajorMatrix<BabyBear> =
            ShiftRightChip::default().generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values)
    }

    /// End-to-end prove/verify over a table of shift cases covering SRL and SRA,
    /// small/large shift amounts, and shift amounts with high bits set.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();

        // Each entry is (opcode, expected result a, input b, shift amount c).
        let shifts = vec![
            (Opcode::SRL, 0xffff8000, 0xffff8000, 0),
            (Opcode::SRL, 0x7fffc000, 0xffff8000, 1),
            (Opcode::SRL, 0x01ffff00, 0xffff8000, 7),
            (Opcode::SRL, 0x0003fffe, 0xffff8000, 14),
            (Opcode::SRL, 0x0001ffff, 0xffff8001, 15),
            (Opcode::SRL, 0xffffffff, 0xffffffff, 0),
            (Opcode::SRL, 0x7fffffff, 0xffffffff, 1),
            (Opcode::SRL, 0x01ffffff, 0xffffffff, 7),
            (Opcode::SRL, 0x0003ffff, 0xffffffff, 14),
            (Opcode::SRL, 0x00000001, 0xffffffff, 31),
            (Opcode::SRL, 0x21212121, 0x21212121, 0),
            (Opcode::SRL, 0x10909090, 0x21212121, 1),
            (Opcode::SRL, 0x00424242, 0x21212121, 7),
            (Opcode::SRL, 0x00008484, 0x21212121, 14),
            (Opcode::SRL, 0x00000000, 0x21212121, 31),
            (Opcode::SRL, 0x21212121, 0x21212121, 0xffffffe0),
            (Opcode::SRL, 0x10909090, 0x21212121, 0xffffffe1),
            (Opcode::SRL, 0x00424242, 0x21212121, 0xffffffe7),
            (Opcode::SRL, 0x00008484, 0x21212121, 0xffffffee),
            (Opcode::SRL, 0x00000000, 0x21212121, 0xffffffff),
            (Opcode::SRA, 0x00000000, 0x00000000, 0),
            (Opcode::SRA, 0xc0000000, 0x80000000, 1),
            (Opcode::SRA, 0xff000000, 0x80000000, 7),
            (Opcode::SRA, 0xfffe0000, 0x80000000, 14),
            (Opcode::SRA, 0xffffffff, 0x80000001, 31),
            (Opcode::SRA, 0x7fffffff, 0x7fffffff, 0),
            (Opcode::SRA, 0x3fffffff, 0x7fffffff, 1),
            (Opcode::SRA, 0x00ffffff, 0x7fffffff, 7),
            (Opcode::SRA, 0x0001ffff, 0x7fffffff, 14),
            (Opcode::SRA, 0x00000000, 0x7fffffff, 31),
            (Opcode::SRA, 0x81818181, 0x81818181, 0),
            (Opcode::SRA, 0xc0c0c0c0, 0x81818181, 1),
            (Opcode::SRA, 0xff030303, 0x81818181, 7),
            (Opcode::SRA, 0xfffe0606, 0x81818181, 14),
            (Opcode::SRA, 0xffffffff, 0x81818181, 31),
        ];

        let mut shard = ExecutionRecord::default();
        shard.shift_right_events = shifts
            .iter()
            .map(|&(opcode, a, b, c)| AluEvent::new(0, 0, 0, opcode, a, b, c))
            .collect();

        let chip = ShiftRightChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());

        let mut challenger = config.challenger();
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);

        let mut challenger = config.challenger();
        verify(&config, &chip, &mut challenger, &proof).unwrap();
    }
}