sp1_recursion_core/chips/select.rs

1use core::borrow::Borrow;
2use p3_air::{Air, BaseAir, PairBuilder};
3use p3_field::{AbstractField, Field, PrimeField32};
4use p3_matrix::{dense::RowMajorMatrix, Matrix};
5use sp1_core_machine::utils::next_power_of_two;
6use sp1_derive::AlignedBorrow;
7use sp1_stark::air::MachineAir;
8
9#[cfg(feature = "sys")]
10use {p3_baby_bear::BabyBear, p3_maybe_rayon::prelude::*, std::borrow::BorrowMut};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
/// Chip for the recursion `Select` instruction: given a selector `bit` and two inputs, it emits
/// the inputs unchanged when the bit is 0 and swapped when the bit is 1.
#[derive(Debug, Default, Clone, Copy)]
pub struct SelectChip;
16
17pub const SELECT_COLS: usize = core::mem::size_of::<SelectCols<u8>>();
18
19#[derive(AlignedBorrow, Debug, Clone, Copy)]
20#[repr(C)]
21pub struct SelectCols<F: Copy> {
22    pub vals: SelectIo<F>,
23}
24
25pub const SELECT_PREPROCESSED_COLS: usize = core::mem::size_of::<SelectPreprocessedCols<u8>>();
26
27#[derive(AlignedBorrow, Debug, Clone, Copy)]
28#[repr(C)]
29pub struct SelectPreprocessedCols<F: Copy> {
30    pub is_real: F,
31    pub addrs: SelectIo<Address<F>>,
32    pub mult1: F,
33    pub mult2: F,
34}
35
36impl<F: Field> BaseAir<F> for SelectChip {
37    fn width(&self) -> usize {
38        SELECT_COLS
39    }
40}
41
42impl<F: PrimeField32> MachineAir<F> for SelectChip {
43    type Record = ExecutionRecord<F>;
44
45    type Program = crate::RecursionProgram<F>;
46
47    fn name(&self) -> String {
48        "Select".to_string()
49    }
50
51    fn preprocessed_width(&self) -> usize {
52        SELECT_PREPROCESSED_COLS
53    }
54
55    fn preprocessed_num_rows(&self, program: &Self::Program, instrs_len: usize) -> Option<usize> {
56        let fixed_log2_rows = program.fixed_log2_rows(self);
57        Some(match fixed_log2_rows {
58            Some(log2_rows) => 1 << log2_rows,
59            None => next_power_of_two(instrs_len, None),
60        })
61    }
62
63    #[cfg(not(feature = "sys"))]
64    fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
65        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
66    }
67
68    #[cfg(feature = "sys")]
69    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
70        assert_eq!(
71            std::any::TypeId::of::<F>(),
72            std::any::TypeId::of::<BabyBear>(),
73            "generate_preprocessed_trace only supports BabyBear field"
74        );
75
76        let instrs = unsafe {
77            std::mem::transmute::<Vec<&SelectInstr<F>>, Vec<&SelectInstr<BabyBear>>>(
78                program
79                    .inner
80                    .iter()
81                    .filter_map(|instruction| match instruction {
82                        Instruction::Select(x) => Some(x),
83                        _ => None,
84                    })
85                    .collect::<Vec<_>>(),
86            )
87        };
88        let padded_nb_rows = self.preprocessed_num_rows(program, instrs.len()).unwrap();
89        let mut values = vec![BabyBear::zero(); padded_nb_rows * SELECT_PREPROCESSED_COLS];
90
91        // Generate the trace rows & corresponding records for each chunk of events in parallel.
92        let populate_len = instrs.len() * SELECT_PREPROCESSED_COLS;
93        values[..populate_len].par_chunks_mut(SELECT_PREPROCESSED_COLS).zip_eq(instrs).for_each(
94            |(row, instr)| {
95                let cols: &mut SelectPreprocessedCols<_> = row.borrow_mut();
96                unsafe {
97                    crate::sys::select_instr_to_row_babybear(instr, cols);
98                }
99            },
100        );
101
102        // Convert the trace to a row major matrix.
103        Some(RowMajorMatrix::new(
104            unsafe { std::mem::transmute::<Vec<BabyBear>, Vec<F>>(values) },
105            SELECT_PREPROCESSED_COLS,
106        ))
107    }
108
109    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
110        // This is a no-op.
111    }
112
113    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
114        let events = &input.select_events;
115        Some(next_power_of_two(events.len(), input.fixed_log2_rows(self)))
116    }
117
118    #[cfg(not(feature = "sys"))]
119    fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
120        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
121    }
122
123    #[cfg(feature = "sys")]
124    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
125        assert_eq!(
126            std::any::TypeId::of::<F>(),
127            std::any::TypeId::of::<BabyBear>(),
128            "generate_trace only supports BabyBear field"
129        );
130
131        let events = unsafe {
132            std::mem::transmute::<&Vec<SelectIo<F>>, &Vec<SelectIo<BabyBear>>>(&input.select_events)
133        };
134        let padded_nb_rows = self.num_rows(input).unwrap();
135        let mut values = vec![BabyBear::zero(); padded_nb_rows * SELECT_COLS];
136
137        // Generate the trace rows & corresponding records for each chunk of events in parallel.
138        let populate_len = events.len() * SELECT_COLS;
139        values[..populate_len].par_chunks_mut(SELECT_COLS).zip_eq(events).for_each(
140            |(row, &vals)| {
141                let cols: &mut SelectCols<_> = row.borrow_mut();
142                unsafe {
143                    crate::sys::select_event_to_row_babybear(&vals, cols);
144                }
145            },
146        );
147
148        // Convert the trace to a row major matrix.
149        RowMajorMatrix::new(
150            unsafe { std::mem::transmute::<Vec<BabyBear>, Vec<_>>(values) },
151            SELECT_COLS,
152        )
153    }
154
155    fn included(&self, _record: &Self::Record) -> bool {
156        true
157    }
158
159    fn local_only(&self) -> bool {
160        true
161    }
162}
163
164impl<AB> Air<AB> for SelectChip
165where
166    AB: SP1RecursionAirBuilder + PairBuilder,
167{
168    fn eval(&self, builder: &mut AB) {
169        let main = builder.main();
170        let local = main.row_slice(0);
171        let local: &SelectCols<AB::Var> = (*local).borrow();
172        let prep = builder.preprocessed();
173        let prep_local = prep.row_slice(0);
174        let prep_local: &SelectPreprocessedCols<AB::Var> = (*prep_local).borrow();
175
176        builder.receive_single(prep_local.addrs.bit, local.vals.bit, prep_local.is_real);
177        builder.receive_single(prep_local.addrs.in1, local.vals.in1, prep_local.is_real);
178        builder.receive_single(prep_local.addrs.in2, local.vals.in2, prep_local.is_real);
179        builder.send_single(prep_local.addrs.out1, local.vals.out1, prep_local.mult1);
180        builder.send_single(prep_local.addrs.out2, local.vals.out2, prep_local.mult2);
181        builder.assert_eq(
182            local.vals.out1,
183            local.vals.bit * local.vals.in2 + (AB::Expr::one() - local.vals.bit) * local.vals.in1,
184        );
185        builder.assert_eq(
186            local.vals.out2,
187            local.vals.bit * local.vals.in1 + (AB::Expr::one() - local.vals.bit) * local.vals.in2,
188        );
189    }
190}
191
#[cfg(all(test, feature = "sys"))]
mod tests {
    use crate::{chips::test_fixtures, runtime::instruction as instr};
    use machine::tests::test_recursion_linear_program;
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    use super::*;

    /// Proves a program of 1000 random selects, each preceded by the memory writes that feed it
    /// and followed by reads that check both outputs.
    #[test]
    pub fn prove_select() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut next_addr = 0;
        let mut instructions: Vec<Instruction<F>> = Vec::new();

        for _ in 0..1000 {
            let in1: F = rng.sample(rand::distributions::Standard);
            let in2: F = rng.sample(rand::distributions::Standard);
            let bit = F::from_bool(rng.gen_bool(0.5));
            assert_eq!(bit * (bit - F::one()), F::zero());

            // A set bit swaps the inputs; a clear bit passes them through.
            let (out1, out2) = if bit == F::one() { (in2, in1) } else { (in1, in2) };

            // Five consecutive cells: bit, out1, out2, in1, in2.
            let base = next_addr;
            next_addr += 5;
            instructions.extend([
                instr::mem_single(MemAccessKind::Write, 1, base, bit),
                instr::mem_single(MemAccessKind::Write, 1, base + 3, in1),
                instr::mem_single(MemAccessKind::Write, 1, base + 4, in2),
                instr::select(1, 1, base, base + 1, base + 2, base + 3, base + 4),
                instr::mem_single(MemAccessKind::Read, 1, base + 1, out1),
                instr::mem_single(MemAccessKind::Read, 1, base + 2, out2),
            ]);
        }

        test_recursion_linear_program(instructions);
    }

    /// Pure-Rust reference for `SelectChip::generate_trace`: copies each event into its row and
    /// leaves the padding rows zeroed.
    fn generate_trace_reference(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        let events = &input.select_events;
        let height = SelectChip.num_rows(input).unwrap();
        let mut values = vec![BabyBear::zero(); height * SELECT_COLS];

        let filled = events.len() * SELECT_COLS;
        values[..filled].par_chunks_mut(SELECT_COLS).zip_eq(events).for_each(|(row, event)| {
            let cols: &mut SelectCols<BabyBear> = row.borrow_mut();
            cols.vals = *event;
        });

        RowMajorMatrix::new(values, SELECT_COLS)
    }

    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut output = test_fixtures::default_execution_record();

        let trace = SelectChip.generate_trace(&shard, &mut output);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        let expected = generate_trace_reference(&shard, &mut output);
        assert_eq!(trace, expected);
    }

    /// Pure-Rust reference for `SelectChip::generate_preprocessed_trace`: one row per select
    /// instruction (marked real, with its addresses and multiplicities), zero padding after.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        let instrs: Vec<&SelectInstr<BabyBear>> = program
            .inner
            .iter()
            .filter_map(|instruction| {
                if let Instruction::Select(x) = instruction {
                    Some(x)
                } else {
                    None
                }
            })
            .collect();
        let height = SelectChip.preprocessed_num_rows(program, instrs.len()).unwrap();
        let mut values = vec![BabyBear::zero(); height * SELECT_PREPROCESSED_COLS];

        let filled = instrs.len() * SELECT_PREPROCESSED_COLS;
        values[..filled].par_chunks_mut(SELECT_PREPROCESSED_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let SelectInstr { addrs, mult1, mult2 } = instr;
                let cols: &mut SelectPreprocessedCols<BabyBear> = row.borrow_mut();
                *cols = SelectPreprocessedCols {
                    is_real: BabyBear::one(),
                    addrs: *addrs,
                    mult1: *mult1,
                    mult2: *mult2,
                };
            },
        );

        RowMajorMatrix::new(values, SELECT_PREPROCESSED_COLS)
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();

        let trace = SelectChip.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        let expected = generate_preprocessed_trace_reference(&program);
        assert_eq!(trace, expected);
    }
}