sp1_recursion_machine/chips/alu_base.rs

use crate::builder::SP1RecursionAirBuilder;
use core::borrow::Borrow;
use slop_air::{Air, AirBuilder, BaseAir, PairBuilder};
use slop_algebra::{Field, PrimeField32};
use slop_matrix::Matrix;
use slop_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::{air::MachineAir, next_multiple_of_32};
use sp1_primitives::SP1Field;
use sp1_recursion_executor::{
    Address, BaseAluInstr, BaseAluIo, BaseAluOpcode, ExecutionRecord, Instruction, RecursionProgram,
};
use std::{borrow::BorrowMut, iter::zip, mem::MaybeUninit};

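// Number of ALU operations packed into each trace row; widening this would
// trade more columns per row for fewer rows.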
pub const NUM_BASE_ALU_ENTRIES_PER_ROW: usize = 1;

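/// Chip implementing the base-field ALU operations (add, sub, mul, div) of the recursion VM.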
#[derive(Default, Clone)]
pub struct BaseAluChip;

pub const NUM_BASE_ALU_COLS: usize = core::mem::size_of::<BaseAluCols<u8>>();

#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluCols<F: Copy> {
    pub values: [BaseAluValueCols<F>; NUM_BASE_ALU_ENTRIES_PER_ROW],
}

pub const NUM_BASE_ALU_VALUE_COLS: usize = core::mem::size_of::<BaseAluValueCols<u8>>();

#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluValueCols<F: Copy> {
    pub vals: BaseAluIo<F>,
}

pub const NUM_BASE_ALU_PREPROCESSED_COLS: usize =
    core::mem::size_of::<BaseAluPreprocessedCols<u8>>();

#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluPreprocessedCols<F: Copy> {
    pub accesses: [BaseAluAccessCols<F>; NUM_BASE_ALU_ENTRIES_PER_ROW],
}

pub const NUM_BASE_ALU_ACCESS_COLS: usize = core::mem::size_of::<BaseAluAccessCols<u8>>();

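/// Preprocessed columns for one ALU access: the operand/output addresses, a
/// one-hot selector over the four opcodes, and the multiplicity of the output write.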
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluAccessCols<F: Copy> {
    pub addrs: BaseAluIo<Address<F>>,
    pub is_add: F,
    pub is_sub: F,
    pub is_mul: F,
    pub is_div: F,
    pub mult: F,
}

impl<F: Field> BaseAir<F> for BaseAluChip {
    fn width(&self) -> usize {
        NUM_BASE_ALU_COLS
    }
}

impl<F: PrimeField32> MachineAir<F> for BaseAluChip {
    type Record = ExecutionRecord<F>;

    type Program = RecursionProgram<F>;

    fn name(&self) -> &'static str {
        "BaseAlu"
    }

    fn preprocessed_width(&self) -> usize {
        NUM_BASE_ALU_PREPROCESSED_COLS
    }

    fn preprocessed_num_rows(&self, program: &Self::Program) -> Option<usize> {
        let instrs_len = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction.inner() {
                Instruction::BaseAlu(x) => Some(x),
                _ => None,
            })
            .count();
        self.preprocessed_num_rows_with_instrs_len(program, instrs_len)
    }

    fn preprocessed_num_rows_with_instrs_len(
        &self,
        program: &Self::Program,
        instrs_len: usize,
    ) -> Option<usize> {
        let height = program.shape.as_ref().and_then(|shape| shape.height(self));
        let nb_rows = instrs_len.div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW);
        Some(next_multiple_of_32(nb_rows, height))
    }

    fn generate_preprocessed_trace_into(
        &self,
        program: &Self::Program,
        buffer: &mut [MaybeUninit<F>],
    ) {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<SP1Field>(),
            "generate_preprocessed_trace only supports SP1Field"
        );

        let instrs = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction.inner() {
                Instruction::BaseAlu(x) => Some(x),
                _ => None,
            })
            .collect::<Vec<_>>();
        let padded_nb_rows =
            self.preprocessed_num_rows_with_instrs_len(program, instrs.len()).unwrap();

        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
        let values = unsafe {
            core::slice::from_raw_parts_mut(
                buffer_ptr,
                padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS,
            )
        };

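        // Zero the padding region past the populated instructions; this relies
        // on the all-zero byte pattern being a valid representation of F::zero().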
        unsafe {
            let padding_start = instrs.len() * NUM_BASE_ALU_ACCESS_COLS;
            let padding_size = padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS - padding_start;
            if padding_size > 0 {
                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
            }
        }

        let populate_len = instrs.len() * NUM_BASE_ALU_ACCESS_COLS;
        values[..populate_len].par_chunks_mut(NUM_BASE_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let BaseAluInstr { opcode, mult, addrs } = instr;
                let access: &mut BaseAluAccessCols<_> = row.borrow_mut();
                *access = BaseAluAccessCols {
                    addrs: addrs.to_owned(),
                    is_add: F::from_bool(false),
                    is_sub: F::from_bool(false),
                    is_mul: F::from_bool(false),
                    is_div: F::from_bool(false),
                    mult: mult.to_owned(),
                };
                let target_flag = match opcode {
                    BaseAluOpcode::AddF => &mut access.is_add,
                    BaseAluOpcode::SubF => &mut access.is_sub,
                    BaseAluOpcode::MulF => &mut access.is_mul,
                    BaseAluOpcode::DivF => &mut access.is_div,
                };
                *target_flag = F::from_bool(true);
            },
        );
    }

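    // This chip produces no dependent events, so dependency generation is a no-op.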
    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {}

    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let height = input.program.shape.as_ref().and_then(|shape| shape.height(self));
        let nb_rows = input.base_alu_events.len().div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW);
        Some(next_multiple_of_32(nb_rows, height))
    }

    fn generate_trace_into(
        &self,
        input: &ExecutionRecord<F>,
        _: &mut ExecutionRecord<F>,
        buffer: &mut [MaybeUninit<F>],
    ) {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<SP1Field>(),
            "generate_trace_into only supports SP1Field"
        );

        let events = &input.base_alu_events;
        let padded_nb_rows = self.num_rows(input).unwrap();
        let num_event_rows = events.len();

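        // Zero any padding rows up front; an all-zero row has every selector
        // deselected, so the padded rows satisfy the AIR constraints.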
        unsafe {
            let padding_start = num_event_rows * NUM_BASE_ALU_COLS;
            let padding_size = (padded_nb_rows - num_event_rows) * NUM_BASE_ALU_COLS;
            if padding_size > 0 {
                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
            }
        }

        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
        let values = unsafe {
            core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_BASE_ALU_COLS)
        };

        let populate_len = events.len() * NUM_BASE_ALU_VALUE_COLS;
        values[..populate_len].par_chunks_mut(NUM_BASE_ALU_VALUE_COLS).zip_eq(events).for_each(
            |(row, &vals)| {
                let cols: &mut BaseAluValueCols<_> = row.borrow_mut();
                *cols = BaseAluValueCols { vals };
            },
        );
    }

    fn included(&self, _record: &Self::Record) -> bool {
        true
    }
}

impl<AB> Air<AB> for BaseAluChip
where
    AB: SP1RecursionAirBuilder + PairBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &BaseAluCols<AB::Var> = (*local).borrow();
        let prep = builder.preprocessed();
        let prep_local = prep.row_slice(0);
        let prep_local: &BaseAluPreprocessedCols<AB::Var> = (*prep_local).borrow();

        for (
            BaseAluValueCols { vals: BaseAluIo { out, in1, in2 } },
            BaseAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult },
        ) in zip(local.values, prep_local.accesses)
        {
            let is_real = is_add + is_sub + is_mul + is_div;
            builder.assert_bool(is_real.clone());

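            // Each op is checked in a rearranged form that needs no negation or
            // field inversion: sub as in1 = in2 + out, div as in2 * out = in1.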
            builder.when(is_add).assert_eq(in1 + in2, out);
            builder.when(is_sub).assert_eq(in1, in2 + out);
            builder.when(is_mul).assert_eq(out, in1 * in2);
            builder.when(is_div).assert_eq(in2 * out, in1);

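            // Memory argument: consume the two operand reads when the row is
            // real, and emit the output with its program-specified multiplicity.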
            builder.receive_single(addrs.in1, in1, is_real.clone());
            builder.receive_single(addrs.in2, in2, is_real);

            builder.send_single(addrs.out, out, mult);
        }
    }
}

#[cfg(test)]
mod tests {

    use rand::prelude::*;
    use sp1_recursion_executor::{instruction as instr, BaseAluOpcode, MemAccessKind};

    use crate::{chips::test_fixtures, test::test_recursion_linear_program};

    use super::*;

    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let trace = BaseAluChip.generate_trace(shard, &mut ExecutionRecord::default());
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let program = &test_fixtures::program_with_input().await.0;
        let trace = BaseAluChip.generate_preprocessed_trace(program).unwrap();
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    #[tokio::test]
    pub async fn four_ops() {
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> SP1Field { rng.sample(rand::distributions::Standard) };
        let mut addr = 0;

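        // Each iteration writes two operands, runs all four ALU ops on them,
        // and verifies every result with a follow-up memory read.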
        let instructions = (0..1000)
            .flat_map(|_| {
                let quot = random_felt();
                let in2 = random_felt();
                let in1 = in2 * quot;
                let alloc_size = 6;
                let a = (0..alloc_size).map(|x| x + addr).collect::<Vec<_>>();
                addr += alloc_size;
                [
                    instr::mem_single(MemAccessKind::Write, 4, a[0], in1),
                    instr::mem_single(MemAccessKind::Write, 4, a[1], in2),
                    instr::base_alu(BaseAluOpcode::AddF, 1, a[2], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[2], in1 + in2),
                    instr::base_alu(BaseAluOpcode::SubF, 1, a[3], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[3], in1 - in2),
                    instr::base_alu(BaseAluOpcode::MulF, 1, a[4], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[4], in1 * in2),
                    instr::base_alu(BaseAluOpcode::DivF, 1, a[5], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[5], quot),
                ]
            })
            .collect::<Vec<Instruction<SP1Field>>>();

        test_recursion_linear_program(instructions).await;
    }
}