use hashbrown::HashMap;
use itertools::Itertools;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice};
use std::borrow::BorrowMut;
use crate::{
air::MachineAir,
bytes::{event::ByteRecord, ByteLookupEvent},
runtime::{ExecutionRecord, Program},
};
use super::{ShaExtendChip, ShaExtendCols, ShaExtendEvent, NUM_SHA_EXTEND_COLS};
impl<F: PrimeField32> MachineAir<F> for ShaExtendChip {
type Record = ExecutionRecord;
type Program = Program;
fn name(&self) -> String {
"ShaExtend".to_string()
}
fn generate_trace(
&self,
input: &ExecutionRecord,
_: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let rows = Vec::new();
let mut new_byte_lookup_events = Vec::new();
let mut wrapped_rows = Some(rows);
for i in 0..input.sha_extend_events.len() {
self.event_to_rows(
&input.sha_extend_events[i],
&mut wrapped_rows,
&mut new_byte_lookup_events,
);
}
let mut rows = wrapped_rows.unwrap();
let nb_rows = rows.len();
let mut padded_nb_rows = nb_rows.next_power_of_two();
if padded_nb_rows == 2 || padded_nb_rows == 1 {
padded_nb_rows = 4;
}
for i in nb_rows..padded_nb_rows {
let mut row = [F::zero(); NUM_SHA_EXTEND_COLS];
let cols: &mut ShaExtendCols<F> = row.as_mut_slice().borrow_mut();
cols.populate_flags(i);
rows.push(row);
}
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_SHA_EXTEND_COLS,
);
for i in 0..trace.height() {
let cols: &mut ShaExtendCols<F> =
trace.values[i * NUM_SHA_EXTEND_COLS..(i + 1) * NUM_SHA_EXTEND_COLS].borrow_mut();
cols.nonce = F::from_canonical_usize(i);
}
trace
}
fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
let chunk_size = std::cmp::max(input.sha_extend_events.len() / num_cpus::get(), 1);
let blu_batches = input
.sha_extend_events
.par_chunks(chunk_size)
.map(|events| {
let mut blu: HashMap<u32, HashMap<ByteLookupEvent, usize>> = HashMap::new();
events.iter().for_each(|event| {
self.event_to_rows::<F>(event, &mut None, &mut blu);
});
blu
})
.collect::<Vec<_>>();
output.add_sharded_byte_lookup_events(blu_batches.iter().collect_vec());
}
fn included(&self, shard: &Self::Record) -> bool {
!shard.sha_extend_events.is_empty()
}
}
impl ShaExtendChip {
fn event_to_rows<F: PrimeField32>(
&self,
event: &ShaExtendEvent,
rows: &mut Option<Vec<[F; NUM_SHA_EXTEND_COLS]>>,
blu: &mut impl ByteRecord,
) {
let shard = event.shard;
for j in 0..48usize {
let mut row = [F::zero(); NUM_SHA_EXTEND_COLS];
let cols: &mut ShaExtendCols<F> = row.as_mut_slice().borrow_mut();
cols.is_real = F::one();
cols.populate_flags(j);
cols.shard = F::from_canonical_u32(event.shard);
cols.channel = F::from_canonical_u8(event.channel);
cols.clk = F::from_canonical_u32(event.clk);
cols.w_ptr = F::from_canonical_u32(event.w_ptr);
cols.w_i_minus_15
.populate(event.channel, event.w_i_minus_15_reads[j], blu);
cols.w_i_minus_2
.populate(event.channel, event.w_i_minus_2_reads[j], blu);
cols.w_i_minus_16
.populate(event.channel, event.w_i_minus_16_reads[j], blu);
cols.w_i_minus_7
.populate(event.channel, event.w_i_minus_7_reads[j], blu);
let w_i_minus_15 = event.w_i_minus_15_reads[j].value;
let w_i_minus_15_rr_7 =
cols.w_i_minus_15_rr_7
.populate(blu, shard, event.channel, w_i_minus_15, 7);
let w_i_minus_15_rr_18 =
cols.w_i_minus_15_rr_18
.populate(blu, shard, event.channel, w_i_minus_15, 18);
let w_i_minus_15_rs_3 =
cols.w_i_minus_15_rs_3
.populate(blu, shard, event.channel, w_i_minus_15, 3);
let s0_intermediate = cols.s0_intermediate.populate(
blu,
shard,
event.channel,
w_i_minus_15_rr_7,
w_i_minus_15_rr_18,
);
let s0 = cols.s0.populate(
blu,
shard,
event.channel,
s0_intermediate,
w_i_minus_15_rs_3,
);
let w_i_minus_2 = event.w_i_minus_2_reads[j].value;
let w_i_minus_2_rr_17 =
cols.w_i_minus_2_rr_17
.populate(blu, shard, event.channel, w_i_minus_2, 17);
let w_i_minus_2_rr_19 =
cols.w_i_minus_2_rr_19
.populate(blu, shard, event.channel, w_i_minus_2, 19);
let w_i_minus_2_rs_10 =
cols.w_i_minus_2_rs_10
.populate(blu, shard, event.channel, w_i_minus_2, 10);
let s1_intermediate = cols.s1_intermediate.populate(
blu,
shard,
event.channel,
w_i_minus_2_rr_17,
w_i_minus_2_rr_19,
);
let s1 = cols.s1.populate(
blu,
shard,
event.channel,
s1_intermediate,
w_i_minus_2_rs_10,
);
let w_i_minus_7 = event.w_i_minus_7_reads[j].value;
let w_i_minus_16 = event.w_i_minus_16_reads[j].value;
cols.s2
.populate(blu, shard, event.channel, w_i_minus_16, s0, w_i_minus_7, s1);
cols.w_i.populate(event.channel, event.w_i_writes[j], blu);
if rows.as_ref().is_some() {
rows.as_mut().unwrap().push(row);
}
}
}
}