use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use hashbrown::HashMap;
use itertools::Itertools;
use p3_air::AirBuilder;
use p3_air::{Air, BaseAir};
use p3_field::{AbstractField, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator, ParallelSlice};
use sp1_derive::AlignedBorrow;
use crate::air::MachineAir;
use crate::air::{SP1AirBuilder, Word};
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::utils::pad_to_power_of_two;
use super::AluEvent;
/// Number of columns in the bitwise trace.
///
/// Computed as the byte size of [`BitwiseCols`] instantiated with `u8`. Because the struct is
/// `#[repr(C)]` and every cell is a `T`, this equals the number of `T` cells in one row.
pub const NUM_BITWISE_COLS: usize = size_of::<BitwiseCols<u8>>();
/// Chip that proves the bitwise ALU operations XOR, OR and AND.
///
/// Its trace rows are built from `ExecutionRecord::bitwise_events`, and its constraints live in
/// the `Air` impl below.
#[derive(Default)]
pub struct BitwiseChip;
/// Column layout of one row of the bitwise trace.
///
/// The field order defines the column order of the trace (`#[repr(C)]` together with
/// `AlignedBorrow`, which reinterprets a row slice as this struct). Do not reorder the fields.
#[derive(AlignedBorrow, Default, Clone, Copy)]
#[repr(C)]
pub struct BitwiseCols<T> {
    /// Shard of the originating ALU event.
    pub shard: T,
    /// Channel of the originating ALU event.
    pub channel: T,
    /// Row index; `generate_trace` sets it on every row, including padding rows.
    pub nonce: T,
    /// Result operand of the event (`event.a`).
    pub a: Word<T>,
    /// First input operand (`event.b`).
    pub b: Word<T>,
    /// Second input operand (`event.c`).
    pub c: Word<T>,
    /// Selector: 1 when the row's opcode is XOR.
    pub is_xor: T,
    /// Selector: 1 when the row's opcode is OR.
    pub is_or: T,
    /// Selector: 1 when the row's opcode is AND.
    pub is_and: T,
}
impl<F: PrimeField> MachineAir<F> for BitwiseChip {
    type Record = ExecutionRecord;

    type Program = Program;

    /// Identifier of this chip.
    fn name(&self) -> String {
        String::from("Bitwise")
    }

    /// Builds the main trace: one row per bitwise event, padded to a power-of-two height.
    ///
    /// After padding, every row (padding included) has its `nonce` set to its row index.
    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        // Fill each row independently. Byte lookups produced here go into a throwaway buffer,
        // because `generate_dependencies` is the place that records them.
        let row_arrays: Vec<[F; NUM_BITWISE_COLS]> = input
            .bitwise_events
            .par_iter()
            .map(|event| {
                let mut cells = [F::zero(); NUM_BITWISE_COLS];
                let cols: &mut BitwiseCols<F> = cells.as_mut_slice().borrow_mut();
                let mut scratch_lookups = Vec::new();
                self.event_to_row(event, cols, &mut scratch_lookups);
                cells
            })
            .collect();

        let flat_values: Vec<F> = row_arrays.into_iter().flatten().collect();
        let mut trace = RowMajorMatrix::new(flat_values, NUM_BITWISE_COLS);
        pad_to_power_of_two::<NUM_BITWISE_COLS, F>(&mut trace.values);

        // Stamp each row with its index so the AIR's "nonce increments by one" constraint holds.
        for (row_index, row) in trace.values.chunks_exact_mut(NUM_BITWISE_COLS).enumerate() {
            let cols: &mut BitwiseCols<F> = row.borrow_mut();
            cols.nonce = F::from_canonical_usize(row_index);
        }

        trace
    }

    /// Records the byte lookups required by every bitwise event into `output`.
    ///
    /// Events are processed in parallel, in roughly one chunk per CPU with at least one event
    /// per chunk. Each chunk builds its own sharded lookup map.
    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
        let events = &input.bitwise_events;
        let chunk_size = (events.len() / num_cpus::get()).max(1);

        let per_chunk_lookups = events
            .par_chunks(chunk_size)
            .map(|chunk| {
                let mut lookups: HashMap<u32, HashMap<ByteLookupEvent, usize>> = HashMap::new();
                for event in chunk {
                    // The row itself is discarded; only the lookups it emits are kept.
                    let mut scratch_cells = [F::zero(); NUM_BITWISE_COLS];
                    let cols: &mut BitwiseCols<F> = scratch_cells.as_mut_slice().borrow_mut();
                    self.event_to_row(event, cols, &mut lookups);
                }
                lookups
            })
            .collect::<Vec<_>>();

        output.add_sharded_byte_lookup_events(per_chunk_lookups.iter().collect::<Vec<_>>());
    }

    /// The chip takes part in a shard only if the shard has at least one bitwise event.
    fn included(&self, shard: &Self::Record) -> bool {
        !shard.bitwise_events.is_empty()
    }
}
impl BitwiseChip {
    /// Writes the trace row for one bitwise ALU `event` into `cols`, and appends to `blu` one
    /// byte lookup per little-endian byte position of the operands.
    ///
    /// Each selector (`is_xor`, `is_or`, `is_and`) is set from an equality test against the
    /// matching `Opcode`.
    fn event_to_row<F: PrimeField>(
        &self,
        event: &AluEvent,
        cols: &mut BitwiseCols<F>,
        blu: &mut impl ByteRecord,
    ) {
        // Operand and context columns.
        cols.shard = F::from_canonical_u32(event.shard);
        cols.channel = F::from_canonical_u8(event.channel);
        cols.a = Word::from(event.a);
        cols.b = Word::from(event.b);
        cols.c = Word::from(event.c);

        // Opcode selectors.
        cols.is_xor = F::from_bool(event.opcode == Opcode::XOR);
        cols.is_or = F::from_bool(event.opcode == Opcode::OR);
        cols.is_and = F::from_bool(event.opcode == Opcode::AND);

        // One lookup per byte position: (a_i, b_i, c_i) under the byte-level opcode.
        let result_bytes = event.a.to_le_bytes();
        let lhs_bytes = event.b.to_le_bytes();
        let rhs_bytes = event.c.to_le_bytes();
        result_bytes
            .into_iter()
            .zip(lhs_bytes)
            .zip(rhs_bytes)
            .map(|((res_byte, lhs_byte), rhs_byte)| ByteLookupEvent {
                shard: event.shard,
                channel: event.channel,
                opcode: ByteOpcode::from(event.opcode),
                a1: u16::from(res_byte),
                a2: 0,
                b: lhs_byte,
                c: rhs_byte,
            })
            .for_each(|lookup| blu.add_byte_lookup_event(lookup));
    }
}
impl<F> BaseAir<F> for BitwiseChip {
fn width(&self) -> usize {
NUM_BITWISE_COLS
}
}
impl<AB> Air<AB> for BitwiseChip
where
    AB: SP1AirBuilder,
{
    /// Constrains the bitwise trace.
    ///
    /// For each row this:
    /// - pins `nonce` to 0 on the first row and to `local.nonce + 1` on every transition;
    /// - calls `send_byte` once per byte limb of `a`, `b`, `c`, using the `ByteOpcode` chosen by
    ///   the selector flags;
    /// - calls `receive_alu` with `(opcode, a, b, c, shard, channel, nonce)`, using the CPU-level
    ///   `Opcode` chosen by the selector flags;
    /// - asserts that each selector is boolean and that their sum is boolean, so at most one is
    ///   set.
    ///
    /// Both `send_byte` and `receive_alu` use multiplicity `is_xor + is_or + is_and`. A row with
    /// no selector set therefore participates with multiplicity zero.
    ///
    /// NOTE(review): keep the statement order unchanged. The order of lookups and assertions
    /// presumably determines the interaction/constraint layout, so reordering could change
    /// proofs. Confirm before refactoring.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &BitwiseCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next: &BitwiseCols<AB::Var> = (*next).borrow();
        // `nonce` is a row counter: 0 on the first row, incremented by one per row.
        builder.when_first_row().assert_zero(local.nonce);
        builder
            .when_transition()
            .assert_eq(local.nonce + AB::Expr::one(), next.nonce);
        // Selector-weighted sum. With at most one selector set (asserted below), this is the
        // selected byte opcode, or 0 on rows with no selector.
        let opcode = local.is_xor * ByteOpcode::XOR.as_field::<AB::F>()
            + local.is_or * ByteOpcode::OR.as_field::<AB::F>()
            + local.is_and * ByteOpcode::AND.as_field::<AB::F>();
        // 1 on rows with a selected operation, 0 otherwise.
        let mult = local.is_xor + local.is_or + local.is_and;
        // One byte lookup per limb. Presumably the byte table checks that a_i = b_i <op> c_i;
        // confirm against the byte chip.
        for ((a, b), c) in local.a.into_iter().zip(local.b).zip(local.c) {
            builder.send_byte(
                opcode.clone(),
                a,
                b,
                c,
                local.shard,
                local.channel,
                mult.clone(),
            );
        }
        // Same selector trick, but with the CPU-level `Opcode` encoding expected by `receive_alu`.
        let cpu_opcode = local.is_xor * Opcode::XOR.as_field::<AB::F>()
            + local.is_or * Opcode::OR.as_field::<AB::F>()
            + local.is_and * Opcode::AND.as_field::<AB::F>();
        builder.receive_alu(
            cpu_opcode,
            local.a,
            local.b,
            local.c,
            local.shard,
            local.channel,
            local.nonce,
            local.is_xor + local.is_or + local.is_and,
        );
        // Each selector is 0/1, and so is their sum, hence at most one selector is set.
        let is_real = local.is_xor + local.is_or + local.is_and;
        builder.assert_bool(local.is_xor);
        builder.assert_bool(local.is_or);
        builder.assert_bool(local.is_and);
        builder.assert_bool(is_real);
    }
}
#[cfg(test)]
mod tests {
    use core::borrow::Borrow;

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::Matrix;

    use super::{BitwiseChip, BitwiseCols, NUM_BITWISE_COLS};
    use crate::air::MachineAir;
    use crate::alu::AluEvent;
    use crate::runtime::{ExecutionRecord, Opcode};
    use crate::stark::StarkGenericConfig;
    use crate::utils::BabyBearPoseidon2;
    use crate::utils::{uni_stark_prove as prove, uni_stark_verify as verify};

    /// Checks the trace shape and per-row columns for a single XOR event.
    ///
    /// Previously this test only printed the trace and asserted nothing.
    #[test]
    fn generate_trace() {
        let mut shard = ExecutionRecord::default();
        shard.bitwise_events = vec![AluEvent::new(0, 0, 0, Opcode::XOR, 25, 10, 19)];
        let chip = BitwiseChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());

        assert_eq!(trace.width(), NUM_BITWISE_COLS);
        assert!(trace.height().is_power_of_two());
        assert_eq!(trace.values.len(), trace.height() * NUM_BITWISE_COLS);

        for (i, row) in trace.values.chunks_exact(NUM_BITWISE_COLS).enumerate() {
            let cols: &BitwiseCols<BabyBear> = row.borrow();
            // Every row, padding included, carries its own index as the nonce.
            assert_eq!(cols.nonce, BabyBear::from_canonical_usize(i));
            // Only the single real (XOR) row has a selector set; padding rows have none.
            let expected_xor = if i == 0 { BabyBear::one() } else { BabyBear::zero() };
            assert_eq!(cols.is_xor, expected_xor);
            assert_eq!(cols.is_or, BabyBear::zero());
            assert_eq!(cols.is_and, BabyBear::zero());
        }
    }

    /// End-to-end prove/verify over a mix of XOR, OR and AND events.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();
        let mut shard = ExecutionRecord::default();
        shard.bitwise_events = [
            AluEvent::new(0, 0, 0, Opcode::XOR, 25, 10, 19),
            AluEvent::new(0, 1, 0, Opcode::OR, 27, 10, 19),
            AluEvent::new(0, 0, 0, Opcode::AND, 2, 10, 19),
        ]
        .repeat(1000);
        let chip = BitwiseChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);
        let mut challenger = config.challenger();
        verify(&config, &chip, &mut challenger, &proof).unwrap();
    }
}