use std::{array, borrow::BorrowMut};

use hashbrown::HashMap;
use itertools::Itertools;
use p3_field::{PrimeField, PrimeField32};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::prelude::{
    IntoParallelRefMutIterator, ParallelBridge, ParallelIterator, ParallelSlice,
};
use sp1_core_executor::{
    events::{ByteLookupEvent, ByteRecord, CpuEvent, LookupId, MemoryRecordEnum},
    syscalls::SyscallCode,
    ByteOpcode::{self, U16Range},
    CoreShape, ExecutionRecord, Opcode, Program,
    Register::X0,
};
use sp1_primitives::consts::WORD_SIZE;
use sp1_stark::{air::MachineAir, Word};
use tracing::instrument;

use super::{
    columns::{CPU_COL_MAP, NUM_CPU_COLS},
    CpuChip,
};
use crate::{cpu::columns::CpuCols, memory::MemoryCols, utils::zeroed_f_vec};
impl<F: PrimeField32> MachineAir<F> for CpuChip {
type Record = ExecutionRecord;
type Program = Program;
fn name(&self) -> String {
"CPU".to_string()
}
fn generate_trace(
&self,
input: &ExecutionRecord,
_: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
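        // Fill in one trace row per CPU event, processing chunks of events in parallel.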
let mut values = zeroed_f_vec(input.cpu_events.len() * NUM_CPU_COLS);
let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1);
values.chunks_mut(chunk_size * NUM_CPU_COLS).enumerate().par_bridge().for_each(
|(i, rows)| {
rows.chunks_mut(NUM_CPU_COLS).enumerate().for_each(|(j, row)| {
let idx = i * chunk_size + j;
let cols: &mut CpuCols<F> = row.borrow_mut();
let mut byte_lookup_events = Vec::new();
self.event_to_row(
&input.cpu_events[idx],
&input.nonce_lookup,
cols,
&mut byte_lookup_events,
);
});
},
);
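
        // Convert the trace to a row major matrix.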
let mut trace = RowMajorMatrix::new(values, NUM_CPU_COLS);
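
        // Pad the trace to a power of two.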
        self.pad_to_power_of_two::<F>(&input.shape, &mut trace.values);
trace
}
#[instrument(name = "generate cpu dependencies", level = "debug", skip_all)]
fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1);
let blu_events: Vec<_> = input
.cpu_events
.par_chunks(chunk_size)
.map(|ops: &[CpuEvent]| {
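                // `blu` maps shard -> (byte lookup event -> multiplicity).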
let mut blu: HashMap<u32, HashMap<ByteLookupEvent, usize>> = HashMap::new();
ops.iter().for_each(|op| {
let mut row = [F::zero(); NUM_CPU_COLS];
let cols: &mut CpuCols<F> = row.as_mut_slice().borrow_mut();
self.event_to_row::<F>(op, &HashMap::new(), cols, &mut blu);
});
blu
})
.collect::<Vec<_>>();
output.add_sharded_byte_lookup_events(blu_events.iter().collect_vec());
}
fn included(&self, shard: &Self::Record) -> bool {
if let Some(shape) = shard.shape.as_ref() {
shape.included::<F, _>(self)
} else {
shard.contains_cpu()
}
}
}
impl CpuChip {
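    /// Create a row from an event.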
fn event_to_row<F: PrimeField32>(
&self,
event: &CpuEvent,
nonce_lookup: &HashMap<LookupId, u32>,
cols: &mut CpuCols<F>,
blu_events: &mut impl ByteRecord,
) {
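        // Populate the shard and clk columns.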
self.populate_shard_clk(cols, event, blu_events);
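
        // Populate the nonce.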
cols.nonce = F::from_canonical_u32(
nonce_lookup.get(&event.alu_lookup_id).copied().unwrap_or_default(),
);
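
        // Populate basic fields.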
cols.pc = F::from_canonical_u32(event.pc);
cols.next_pc = F::from_canonical_u32(event.next_pc);
cols.instruction.populate(event.instruction);
cols.selectors.populate(event.instruction);
*cols.op_a_access.value_mut() = event.a.into();
*cols.op_b_access.value_mut() = event.b.into();
*cols.op_c_access.value_mut() = event.c.into();
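
        // Populate memory accesses for a, b, and c.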
if let Some(record) = event.a_record {
cols.op_a_access.populate(record, blu_events);
}
if let Some(MemoryRecordEnum::Read(record)) = event.b_record {
cols.op_b_access.populate(record, blu_events);
}
if let Some(MemoryRecordEnum::Read(record)) = event.c_record {
cols.op_c_access.populate(record, blu_events);
}
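
        // Populate range checks for a.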
let a_bytes = cols
.op_a_access
.access
.value
.0
.iter()
.map(|x| x.as_canonical_u32())
.collect::<Vec<_>>();
blu_events.add_byte_lookup_event(ByteLookupEvent {
shard: event.shard,
opcode: ByteOpcode::U8Range,
a1: 0,
a2: 0,
b: a_bytes[0] as u8,
c: a_bytes[1] as u8,
});
blu_events.add_byte_lookup_event(ByteLookupEvent {
shard: event.shard,
opcode: ByteOpcode::U8Range,
a1: 0,
a2: 0,
b: a_bytes[2] as u8,
c: a_bytes[3] as u8,
});
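
        // Populate memory accesses for reading from memory.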
assert_eq!(event.memory_record.is_some(), event.memory.is_some());
let memory_columns = cols.opcode_specific_columns.memory_mut();
if let Some(record) = event.memory_record {
memory_columns.memory_access.populate(record, blu_events)
}
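
        // Populate memory, branch, jump, and auipc specific fields.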
self.populate_memory(cols, event, blu_events, nonce_lookup);
self.populate_branch(cols, event, nonce_lookup);
self.populate_jump(cols, event, nonce_lookup);
self.populate_auipc(cols, event, nonce_lookup);
let is_halt = self.populate_ecall(cols, event, nonce_lookup);
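
        // The instruction is sequential when it is neither a branch, a jump, nor a halt.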
cols.is_sequential_instr = F::from_bool(
!event.instruction.is_branch_instruction()
&& !event.instruction.is_jump_instruction()
&& !is_halt,
);
cols.is_real = F::one();
}
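
    /// Populates the shard and clk related columns.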
fn populate_shard_clk<F: PrimeField>(
&self,
cols: &mut CpuCols<F>,
event: &CpuEvent,
blu_events: &mut impl ByteRecord,
) {
cols.shard = F::from_canonical_u32(event.shard);
cols.clk = F::from_canonical_u32(event.clk);
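
        // Decompose clk into a 16-bit limb and an 8-bit limb so each piece can be
        // range checked with byte lookups. The decomposition encodes (assuming clk
        // fits in 24 bits): clk == clk_16bit_limb + (clk_8bit_limb << 16).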
let clk_16bit_limb = (event.clk & 0xffff) as u16;
let clk_8bit_limb = ((event.clk >> 16) & 0xff) as u8;
cols.clk_16bit_limb = F::from_canonical_u16(clk_16bit_limb);
cols.clk_8bit_limb = F::from_canonical_u8(clk_8bit_limb);
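
        // Range check the shard as 16 bits, the clk 16-bit limb, and the clk 8-bit limb.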
blu_events.add_byte_lookup_event(ByteLookupEvent::new(
event.shard,
U16Range,
event.shard as u16,
0,
0,
0,
));
blu_events.add_byte_lookup_event(ByteLookupEvent::new(
event.shard,
U16Range,
clk_16bit_limb,
0,
0,
0,
));
blu_events.add_byte_lookup_event(ByteLookupEvent::new(
event.shard,
ByteOpcode::U8Range,
0,
0,
0,
            clk_8bit_limb,
));
}
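
    /// Populates columns related to memory.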
fn populate_memory<F: PrimeField>(
&self,
cols: &mut CpuCols<F>,
event: &CpuEvent,
blu_events: &mut impl ByteRecord,
nonce_lookup: &HashMap<LookupId, u32>,
) {
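        // Memory columns are only populated for load and store instructions.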
if !matches!(
event.instruction.opcode,
Opcode::LB
| Opcode::LH
| Opcode::LW
| Opcode::LBU
| Opcode::LHU
| Opcode::SB
| Opcode::SH
| Opcode::SW
) {
return;
}
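
        // Populate addr_word and addr_aligned columns.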
let memory_columns = cols.opcode_specific_columns.memory_mut();
let memory_addr = event.b.wrapping_add(event.c);
let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32;
memory_columns.addr_word = memory_addr.into();
memory_columns.addr_word_range_checker.populate(memory_addr);
memory_columns.addr_aligned = F::from_canonical_u32(aligned_addr);
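
        // Populate the aa_least_sig_byte_decomp columns; the address is word
        // aligned, so bits 0 and 1 are zero and only bits 2..8 are stored.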
assert!(aligned_addr % 4 == 0);
let aligned_addr_ls_byte = (aligned_addr & 0x000000FF) as u8;
let bits: [bool; 8] = array::from_fn(|i| aligned_addr_ls_byte & (1 << i) != 0);
memory_columns.aa_least_sig_byte_decomp = array::from_fn(|i| F::from_bool(bits[i + 2]));
memory_columns.addr_word_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.memory_add_lookup_id).copied().unwrap_or_default(),
);
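
        // Populate memory offsets.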
let addr_offset = (memory_addr % WORD_SIZE as u32) as u8;
memory_columns.addr_offset = F::from_canonical_u8(addr_offset);
memory_columns.offset_is_one = F::from_bool(addr_offset == 1);
memory_columns.offset_is_two = F::from_bool(addr_offset == 2);
memory_columns.offset_is_three = F::from_bool(addr_offset == 3);
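
        // If it is a load instruction, set the unsigned_mem_val column.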
let mem_value = event.memory_record.unwrap().value();
if matches!(
event.instruction.opcode,
Opcode::LB | Opcode::LBU | Opcode::LH | Opcode::LHU | Opcode::LW
) {
match event.instruction.opcode {
Opcode::LB | Opcode::LBU => {
cols.unsigned_mem_val =
(mem_value.to_le_bytes()[addr_offset as usize] as u32).into();
}
Opcode::LH | Opcode::LHU => {
let value = match (addr_offset >> 1) % 2 {
0 => mem_value & 0x0000FFFF,
1 => (mem_value & 0xFFFF0000) >> 16,
_ => unreachable!(),
};
cols.unsigned_mem_val = value.into();
}
Opcode::LW => {
cols.unsigned_mem_val = mem_value.into();
}
_ => unreachable!(),
}
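
            // For signed load instructions (LB, LH), check whether the loaded value
            // is negative by decomposing its most significant byte.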
if matches!(event.instruction.opcode, Opcode::LB | Opcode::LH) {
let most_sig_mem_value_byte = if matches!(event.instruction.opcode, Opcode::LB) {
cols.unsigned_mem_val.to_u32().to_le_bytes()[0]
} else {
cols.unsigned_mem_val.to_u32().to_le_bytes()[1]
};
for i in (0..8).rev() {
memory_columns.most_sig_byte_decomp[i] =
F::from_canonical_u8(most_sig_mem_value_byte >> i & 0x01);
}
if memory_columns.most_sig_byte_decomp[7] == F::one() {
cols.mem_value_is_neg_not_x0 =
F::from_bool(event.instruction.op_a != (X0 as u32));
cols.unsigned_mem_val_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.memory_sub_lookup_id).copied().unwrap_or_default(),
);
}
}
cols.mem_value_is_pos_not_x0 = F::from_bool(
((matches!(event.instruction.opcode, Opcode::LB | Opcode::LH)
&& (memory_columns.most_sig_byte_decomp[7] == F::zero()))
|| matches!(event.instruction.opcode, Opcode::LBU | Opcode::LHU | Opcode::LW))
&& event.instruction.op_a != (X0 as u32),
);
}
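
        // Add byte lookup events to range check each byte of the memory address.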
let addr_bytes = memory_addr.to_le_bytes();
for byte_pair in addr_bytes.chunks_exact(2) {
blu_events.add_byte_lookup_event(ByteLookupEvent {
shard: event.shard,
opcode: ByteOpcode::U8Range,
a1: 0,
a2: 0,
b: byte_pair[0],
c: byte_pair[1],
});
}
}
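
    /// Populates columns related to branching.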
fn populate_branch<F: PrimeField>(
&self,
cols: &mut CpuCols<F>,
event: &CpuEvent,
nonce_lookup: &HashMap<LookupId, u32>,
) {
if event.instruction.is_branch_instruction() {
let branch_columns = cols.opcode_specific_columns.branch_mut();
let a_eq_b = event.a == event.b;
let use_signed_comparison =
matches!(event.instruction.opcode, Opcode::BLT | Opcode::BGE);
let a_lt_b = if use_signed_comparison {
(event.a as i32) < (event.b as i32)
} else {
event.a < event.b
};
let a_gt_b = if use_signed_comparison {
(event.a as i32) > (event.b as i32)
} else {
event.a > event.b
};
branch_columns.a_lt_b_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.branch_lt_lookup_id).copied().unwrap_or_default(),
);
branch_columns.a_gt_b_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.branch_gt_lookup_id).copied().unwrap_or_default(),
);
branch_columns.a_eq_b = F::from_bool(a_eq_b);
branch_columns.a_lt_b = F::from_bool(a_lt_b);
branch_columns.a_gt_b = F::from_bool(a_gt_b);
let branching = match event.instruction.opcode {
Opcode::BEQ => a_eq_b,
Opcode::BNE => !a_eq_b,
Opcode::BLT | Opcode::BLTU => a_lt_b,
Opcode::BGE | Opcode::BGEU => a_eq_b || a_gt_b,
_ => unreachable!(),
};
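
            // Compute the branch target pc.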
let next_pc = event.pc.wrapping_add(event.c);
branch_columns.pc = Word::from(event.pc);
branch_columns.next_pc = Word::from(next_pc);
branch_columns.pc_range_checker.populate(event.pc);
branch_columns.next_pc_range_checker.populate(next_pc);
if branching {
cols.branching = F::one();
branch_columns.next_pc_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.branch_add_lookup_id).copied().unwrap_or_default(),
);
} else {
cols.not_branching = F::one();
}
}
}
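
    /// Populates columns related to jumping.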
fn populate_jump<F: PrimeField>(
&self,
cols: &mut CpuCols<F>,
event: &CpuEvent,
nonce_lookup: &HashMap<LookupId, u32>,
) {
if event.instruction.is_jump_instruction() {
let jump_columns = cols.opcode_specific_columns.jump_mut();
match event.instruction.opcode {
Opcode::JAL => {
let next_pc = event.pc.wrapping_add(event.b);
jump_columns.op_a_range_checker.populate(event.a);
jump_columns.pc = Word::from(event.pc);
jump_columns.pc_range_checker.populate(event.pc);
jump_columns.next_pc = Word::from(next_pc);
jump_columns.next_pc_range_checker.populate(next_pc);
jump_columns.jal_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.jump_jal_lookup_id).copied().unwrap_or_default(),
);
}
Opcode::JALR => {
let next_pc = event.b.wrapping_add(event.c);
jump_columns.op_a_range_checker.populate(event.a);
jump_columns.next_pc = Word::from(next_pc);
jump_columns.next_pc_range_checker.populate(next_pc);
jump_columns.jalr_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.jump_jalr_lookup_id).copied().unwrap_or_default(),
);
}
_ => unreachable!(),
}
}
}
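
    /// Populates columns related to AUIPC.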
fn populate_auipc<F: PrimeField>(
&self,
cols: &mut CpuCols<F>,
event: &CpuEvent,
nonce_lookup: &HashMap<LookupId, u32>,
) {
if matches!(event.instruction.opcode, Opcode::AUIPC) {
let auipc_columns = cols.opcode_specific_columns.auipc_mut();
auipc_columns.pc = Word::from(event.pc);
auipc_columns.pc_range_checker.populate(event.pc);
auipc_columns.auipc_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.auipc_lookup_id).copied().unwrap_or_default(),
);
}
}
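
    /// Populates columns related to ECALL. Returns whether the syscall is a halt.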
fn populate_ecall<F: PrimeField>(
&self,
cols: &mut CpuCols<F>,
event: &CpuEvent,
nonce_lookup: &HashMap<LookupId, u32>,
) -> bool {
let mut is_halt = false;
if cols.selectors.is_ecall == F::one() {
let ecall_cols = cols.opcode_specific_columns.ecall_mut();
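            // `send_to_table` comes from op_a_access.prev_value[1]; the syscall id
            // is in prev_value[0].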
cols.ecall_mul_send_to_table = cols.selectors.is_ecall * cols.op_a_access.prev_value[1];
let syscall_id = cols.op_a_access.prev_value[0];
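
            // Populate the is-zero selector for each recognized syscall code.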
ecall_cols.is_enter_unconstrained.populate_from_field_element(
syscall_id - F::from_canonical_u32(SyscallCode::ENTER_UNCONSTRAINED.syscall_id()),
);
ecall_cols.is_hint_len.populate_from_field_element(
syscall_id - F::from_canonical_u32(SyscallCode::HINT_LEN.syscall_id()),
);
ecall_cols.is_halt.populate_from_field_element(
syscall_id - F::from_canonical_u32(SyscallCode::HALT.syscall_id()),
);
ecall_cols.is_commit.populate_from_field_element(
syscall_id - F::from_canonical_u32(SyscallCode::COMMIT.syscall_id()),
);
ecall_cols.is_commit_deferred_proofs.populate_from_field_element(
syscall_id
- F::from_canonical_u32(SyscallCode::COMMIT_DEFERRED_PROOFS.syscall_id()),
);
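
            // If the syscall is COMMIT or COMMIT_DEFERRED_PROOFS, set the index bitmap.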
if syscall_id == F::from_canonical_u32(SyscallCode::COMMIT.syscall_id())
|| syscall_id
== F::from_canonical_u32(SyscallCode::COMMIT_DEFERRED_PROOFS.syscall_id())
{
let digest_idx = cols.op_b_access.value().to_u32() as usize;
ecall_cols.index_bitmap[digest_idx] = F::one();
}
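
            // Write the syscall nonce.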
ecall_cols.syscall_nonce = F::from_canonical_u32(
nonce_lookup.get(&event.syscall_lookup_id).copied().unwrap_or_default(),
);
is_halt = syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id());
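
            // For the HALT and COMMIT_DEFERRED_PROOFS syscalls, one operand must be
            // range checked.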
if is_halt {
ecall_cols.operand_to_check = event.b.into();
ecall_cols.operand_range_check_cols.populate(event.b);
cols.ecall_range_check_operand = F::one();
}
if syscall_id == F::from_canonical_u32(SyscallCode::COMMIT_DEFERRED_PROOFS.syscall_id())
{
ecall_cols.operand_to_check = event.c.into();
ecall_cols.operand_range_check_cols.populate(event.c);
cols.ecall_range_check_operand = F::one();
}
}
is_halt
}
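
    /// Pads the trace to a power of two, or to the height fixed by the shard shape.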
fn pad_to_power_of_two<F: PrimeField32>(&self, shape: &Option<CoreShape>, values: &mut Vec<F>) {
let n_real_rows = values.len() / NUM_CPU_COLS;
let padded_nb_rows = if let Some(shape) = shape {
1 << shape.inner[&MachineAir::<F>::name(self)]
} else if n_real_rows < 16 {
16
} else {
n_real_rows.next_power_of_two()
};
values.resize(padded_nb_rows * NUM_CPU_COLS, F::zero());
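
        // Interpret values as a slice of arrays of length `NUM_CPU_COLS`. This is
        // sound because the buffer length is a multiple of `NUM_CPU_COLS`.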
let rows = unsafe {
core::slice::from_raw_parts_mut(
values.as_mut_ptr() as *mut [F; NUM_CPU_COLS],
values.len() / NUM_CPU_COLS,
)
};
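
        // In each padded row, set the imm_b and imm_c selector flags.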
rows[n_real_rows..].par_iter_mut().for_each(|padded_row| {
padded_row[CPU_COL_MAP.selectors.imm_b] = F::one();
padded_row[CPU_COL_MAP.selectors.imm_c] = F::one();
});
}
}