sp1_recursion_core/chips/public_values.rs

1use crate::{
2    air::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS},
3    builder::SP1RecursionAirBuilder,
4    runtime::Instruction,
5    CommitPublicValuesEvent, CommitPublicValuesInstr, ExecutionRecord, DIGEST_SIZE,
6};
7use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
8use p3_baby_bear::BabyBear;
9use p3_field::{AbstractField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use sp1_core_machine::utils::pad_rows_fixed;
12use sp1_derive::AlignedBorrow;
13use sp1_stark::air::MachineAir;
14use std::borrow::{Borrow, BorrowMut};
15
16use super::mem::MemoryAccessColsChips;
17
18pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
19pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
20    core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();
21
/// Log2 of the fixed height of both the main and preprocessed traces of this chip.
///
/// Both traces are zero-padded to exactly `1 << PUB_VALUES_LOG_HEIGHT` rows (= 16 rows).
pub const PUB_VALUES_LOG_HEIGHT: usize = 4;
23
/// Chip that constrains the committed public values digest of a recursion proof.
#[derive(Default, Debug, Clone, Copy)]
pub struct PublicValuesChip;
26
/// The preprocessed columns for the CommitPVHash instruction.
///
/// Field order defines the column layout (`#[repr(C)]`), so fields must not be reordered.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct PublicValuesPreprocessedCols<T: Copy> {
    /// Per-digest-index selectors: when `pv_idx[i]` is set on a row, that row's
    /// `pv_element` is constrained to equal entry `i` of the public values digest.
    pub pv_idx: [T; DIGEST_SIZE],
    /// Address and multiplicity of the memory access for this row's public value element.
    pub pv_mem: MemoryAccessColsChips<T>,
}
34
/// The main-trace columns for a CommitPVHash invocation.
///
/// Field order defines the column layout (`#[repr(C)]`), so fields must not be reordered.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct PublicValuesCols<T: Copy> {
    /// The public values digest element handled by this row.
    pub pv_element: T,
}
41
42impl<F> BaseAir<F> for PublicValuesChip {
43    fn width(&self) -> usize {
44        NUM_PUBLIC_VALUES_COLS
45    }
46}
47
48impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
49    type Record = ExecutionRecord<F>;
50
51    type Program = crate::RecursionProgram<F>;
52
53    fn name(&self) -> String {
54        "PublicValues".to_string()
55    }
56
57    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
58        // This is a no-op.
59    }
60
61    fn preprocessed_width(&self) -> usize {
62        NUM_PUBLIC_VALUES_PREPROCESSED_COLS
63    }
64
65    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
66        assert_eq!(
67            std::any::TypeId::of::<F>(),
68            std::any::TypeId::of::<BabyBear>(),
69            "generate_preprocessed_trace only supports BabyBear field"
70        );
71
72        let mut rows: Vec<[BabyBear; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
73        let commit_pv_hash_instrs: Vec<&Box<CommitPublicValuesInstr<BabyBear>>> = program
74            .inner
75            .iter()
76            .filter_map(|instruction| {
77                if let Instruction::CommitPublicValues(instr) = instruction {
78                    Some(unsafe {
79                        std::mem::transmute::<
80                            &Box<CommitPublicValuesInstr<F>>,
81                            &Box<CommitPublicValuesInstr<BabyBear>>,
82                        >(instr)
83                    })
84                } else {
85                    None
86                }
87            })
88            .collect::<Vec<_>>();
89
90        if commit_pv_hash_instrs.len() != 1 {
91            tracing::warn!("Expected exactly one CommitPVHash instruction.");
92        }
93
94        // We only take 1 commit pv hash instruction, since our air only checks for one public
95        // values hash.
96        for instr in commit_pv_hash_instrs.iter().take(1) {
97            for i in 0..DIGEST_SIZE {
98                let mut row = [BabyBear::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
99                let cols: &mut PublicValuesPreprocessedCols<BabyBear> =
100                    row.as_mut_slice().borrow_mut();
101                unsafe {
102                    crate::sys::public_values_instr_to_row_babybear(instr, i, cols);
103                }
104                rows.push(row);
105            }
106        }
107
108        // Pad the preprocessed rows to 8 rows.
109        // gpu code breaks for small traces
110        pad_rows_fixed(
111            &mut rows,
112            || [BabyBear::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
113            Some(PUB_VALUES_LOG_HEIGHT),
114        );
115
116        let trace = RowMajorMatrix::new(
117            unsafe {
118                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
119                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
120                )
121            },
122            NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
123        );
124        Some(trace)
125    }
126
127    fn generate_trace(
128        &self,
129        input: &ExecutionRecord<F>,
130        _: &mut ExecutionRecord<F>,
131    ) -> RowMajorMatrix<F> {
132        assert_eq!(
133            std::any::TypeId::of::<F>(),
134            std::any::TypeId::of::<BabyBear>(),
135            "generate_trace only supports BabyBear field"
136        );
137
138        if input.commit_pv_hash_events.len() != 1 {
139            tracing::warn!("Expected exactly one CommitPVHash event.");
140        }
141
142        let mut rows: Vec<[BabyBear; NUM_PUBLIC_VALUES_COLS]> = Vec::new();
143
144        // We only take 1 commit pv hash instruction, since our air only checks for one public
145        // values hash.
146        for event in input.commit_pv_hash_events.iter().take(1) {
147            let bb_event = unsafe {
148                std::mem::transmute::<&CommitPublicValuesEvent<F>, &CommitPublicValuesEvent<BabyBear>>(
149                    event,
150                )
151            };
152            for i in 0..DIGEST_SIZE {
153                let mut row = [BabyBear::zero(); NUM_PUBLIC_VALUES_COLS];
154                let cols: &mut PublicValuesCols<BabyBear> = row.as_mut_slice().borrow_mut();
155                unsafe {
156                    crate::sys::public_values_event_to_row_babybear(bb_event, i, cols);
157                }
158                rows.push(row);
159            }
160        }
161
162        // Pad the trace to 8 rows.
163        pad_rows_fixed(
164            &mut rows,
165            || [BabyBear::zero(); NUM_PUBLIC_VALUES_COLS],
166            Some(PUB_VALUES_LOG_HEIGHT),
167        );
168
169        // Convert the trace to a row major matrix.
170        RowMajorMatrix::new(
171            unsafe {
172                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
173                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
174                )
175            },
176            NUM_PUBLIC_VALUES_COLS,
177        )
178    }
179
180    fn included(&self, _record: &Self::Record) -> bool {
181        true
182    }
183}
184
185impl<AB> Air<AB> for PublicValuesChip
186where
187    AB: SP1RecursionAirBuilder + PairBuilder,
188{
189    fn eval(&self, builder: &mut AB) {
190        let main = builder.main();
191        let local = main.row_slice(0);
192        let local: &PublicValuesCols<AB::Var> = (*local).borrow();
193        let prepr = builder.preprocessed();
194        let local_prepr = prepr.row_slice(0);
195        let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
196        let pv = builder.public_values();
197        let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
198            core::array::from_fn(|i| pv[i].into());
199        let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
200
201        // Constrain mem read for the public value element.
202        builder.send_single(local_prepr.pv_mem.addr, local.pv_element, local_prepr.pv_mem.mult);
203
204        for (i, pv_elm) in public_values.digest.iter().enumerate() {
205            // Ensure that the public value element is the same for all rows within a fri fold
206            // invocation.
207            builder.when(local_prepr.pv_idx[i]).assert_eq(pv_elm.clone(), local.pv_element);
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH, RECURSIVE_PROOF_NUM_PV_ELTS},
        chips::{
            mem::MemoryAccessCols,
            public_values::{
                PublicValuesChip, PublicValuesCols, PublicValuesPreprocessedCols,
                NUM_PUBLIC_VALUES_COLS, NUM_PUBLIC_VALUES_PREPROCESSED_COLS, PUB_VALUES_LOG_HEIGHT,
            },
            test_fixtures,
        },
        machine::tests::test_recursion_linear_program,
        runtime::{instruction as instr, ExecutionRecord},
        stark::BabyBearPoseidon2Outer,
        Instruction, MemAccessKind, RecursionProgram, DIGEST_SIZE,
    };
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::{dense::RowMajorMatrix, Matrix};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::{pad_rows_fixed, setup_logger};
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::{
        array,
        borrow::{Borrow, BorrowMut},
    };

    #[test]
    fn prove_babybear_circuit_public_values() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|_| random_felt());
        let public_values_a: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|i| i as u32);

        let mut instructions = Vec::new();
        // Write every public value element to memory. Only the digest elements (indices
        // NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE) get multiplicity 1; all
        // other elements get multiplicity 0.
        for i in 0..RECURSIVE_PROOF_NUM_PV_ELTS {
            let mult = (NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE).contains(&i);
            instructions.push(instr::mem_block(
                MemAccessKind::Write,
                mult as u32,
                public_values_a[i],
                random_pv_elms[i].into(),
            ));
        }
        let public_values_a: &RecursionPublicValues<u32> = public_values_a.as_slice().borrow();
        instructions.push(instr::commit_public_values(public_values_a));

        test_recursion_linear_program(instructions);
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_public_values_preprocessed_trace() {
        let program = test_fixtures::program();

        let chip = PublicValuesChip;
        let trace = chip.generate_preprocessed_trace(&program).unwrap();
        println!("{:?}", trace.values);
    }

    /// Pure-Rust reference for `PublicValuesChip::generate_trace`.
    fn generate_trace_reference(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        if input.commit_pv_hash_events.len() != 1 {
            tracing::warn!("Expected exactly one CommitPVHash event.");
        }

        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_COLS]> = Vec::new();

        // We only take 1 commit pv hash event, since our air only checks for one public
        // values hash.
        for event in input.commit_pv_hash_events.iter().take(1) {
            for element in event.public_values.digest.iter() {
                let mut row = [F::zero(); NUM_PUBLIC_VALUES_COLS];
                let cols: &mut PublicValuesCols<F> = row.as_mut_slice().borrow_mut();

                cols.pv_element = *element;
                rows.push(row);
            }
        }

        // Pad the trace to a fixed `1 << PUB_VALUES_LOG_HEIGHT` (= 16) rows.
        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_PUBLIC_VALUES_COLS],
            Some(PUB_VALUES_LOG_HEIGHT),
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_PUBLIC_VALUES_COLS)
    }

    #[test]
    fn test_generate_trace() {
        let shard = test_fixtures::shard();
        let trace = PublicValuesChip.generate_trace(&shard, &mut ExecutionRecord::default());
        assert_eq!(trace.height(), 1 << PUB_VALUES_LOG_HEIGHT);

        assert_eq!(trace, generate_trace_reference(&shard, &mut ExecutionRecord::default()));
    }

    /// Pure-Rust reference for `PublicValuesChip::generate_preprocessed_trace`.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
        let commit_pv_hash_instrs = program
            .inner
            .iter()
            .filter_map(|instruction| {
                if let Instruction::CommitPublicValues(instr) = instruction {
                    Some(instr)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        if commit_pv_hash_instrs.len() != 1 {
            tracing::warn!("Expected exactly one CommitPVHash instruction.");
        }

        // We only take 1 commit pv hash instruction; each digest address becomes one row whose
        // selector `pv_idx[i]` is set.
        for instr in commit_pv_hash_instrs.iter().take(1) {
            for (i, addr) in instr.pv_addrs.digest.iter().enumerate() {
                let mut row = [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
                let cols: &mut PublicValuesPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                cols.pv_idx[i] = F::one();
                cols.pv_mem = MemoryAccessCols { addr: *addr, mult: F::neg_one() };
                rows.push(row);
            }
        }

        // Pad the preprocessed rows to a fixed `1 << PUB_VALUES_LOG_HEIGHT` (= 16) rows.
        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
            Some(PUB_VALUES_LOG_HEIGHT),
        );

        RowMajorMatrix::new(
            rows.into_iter().flatten().collect(),
            NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
        )
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn test_generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = PublicValuesChip.generate_preprocessed_trace(&program).unwrap();
        assert_eq!(trace.height(), 1 << PUB_VALUES_LOG_HEIGHT);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }
}