sp1_recursion_core/chips/
public_values.rs

1use crate::{
2    air::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS},
3    builder::SP1RecursionAirBuilder,
4    ExecutionRecord, DIGEST_SIZE,
5};
6use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
7use p3_field::PrimeField32;
8use p3_matrix::{dense::RowMajorMatrix, Matrix};
9use sp1_derive::AlignedBorrow;
10use sp1_stark::air::MachineAir;
11use std::borrow::Borrow;
12
13#[cfg(feature = "sys")]
14use {
15    crate::{runtime::Instruction, CommitPublicValuesEvent, CommitPublicValuesInstr},
16    p3_baby_bear::BabyBear,
17    p3_field::AbstractField,
18    sp1_core_machine::utils::pad_rows_fixed,
19    std::borrow::BorrowMut,
20};
21
22use super::mem::MemoryAccessColsChips;
23
24pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
25pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
26    core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();
27
28pub const PUB_VALUES_LOG_HEIGHT: usize = 4;
29
/// Chip that exposes the recursion public values digest.
///
/// Each row reads one public value element from memory. The row whose preprocessed selector
/// is set is constrained to equal the matching element of the public values `digest`.
#[derive(Default)]
pub struct PublicValuesChip;
32
33/// The preprocessed columns for the CommitPVHash instruction.
34#[derive(AlignedBorrow, Debug, Clone, Copy)]
35#[repr(C)]
36pub struct PublicValuesPreprocessedCols<T: Copy> {
37    pub pv_idx: [T; DIGEST_SIZE],
38    pub pv_mem: MemoryAccessColsChips<T>,
39}
40
41/// The cols for a CommitPVHash invocation.
42#[derive(AlignedBorrow, Debug, Clone, Copy)]
43#[repr(C)]
44pub struct PublicValuesCols<T: Copy> {
45    pub pv_element: T,
46}
47
48impl<F> BaseAir<F> for PublicValuesChip {
49    fn width(&self) -> usize {
50        NUM_PUBLIC_VALUES_COLS
51    }
52}
53
54impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
55    type Record = ExecutionRecord<F>;
56
57    type Program = crate::RecursionProgram<F>;
58
59    fn name(&self) -> String {
60        "PublicValues".to_string()
61    }
62
63    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
64        // This is a no-op.
65    }
66
67    fn preprocessed_width(&self) -> usize {
68        NUM_PUBLIC_VALUES_PREPROCESSED_COLS
69    }
70
71    #[cfg(not(feature = "sys"))]
72    fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
73        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
74    }
75
76    #[cfg(feature = "sys")]
77    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
78        assert_eq!(
79            std::any::TypeId::of::<F>(),
80            std::any::TypeId::of::<BabyBear>(),
81            "generate_preprocessed_trace only supports BabyBear field"
82        );
83
84        let mut rows: Vec<[BabyBear; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
85        let commit_pv_hash_instrs: Vec<&Box<CommitPublicValuesInstr<BabyBear>>> = program
86            .inner
87            .iter()
88            .filter_map(|instruction| {
89                if let Instruction::CommitPublicValues(instr) = instruction {
90                    Some(unsafe {
91                        std::mem::transmute::<
92                            &Box<CommitPublicValuesInstr<F>>,
93                            &Box<CommitPublicValuesInstr<BabyBear>>,
94                        >(instr)
95                    })
96                } else {
97                    None
98                }
99            })
100            .collect::<Vec<_>>();
101
102        if commit_pv_hash_instrs.len() != 1 {
103            tracing::warn!("Expected exactly one CommitPVHash instruction.");
104        }
105
106        // We only take 1 commit pv hash instruction, since our air only checks for one public
107        // values hash.
108        for instr in commit_pv_hash_instrs.iter().take(1) {
109            for i in 0..DIGEST_SIZE {
110                let mut row = [BabyBear::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
111                let cols: &mut PublicValuesPreprocessedCols<BabyBear> =
112                    row.as_mut_slice().borrow_mut();
113                unsafe {
114                    crate::sys::public_values_instr_to_row_babybear(instr, i, cols);
115                }
116                rows.push(row);
117            }
118        }
119
120        // Pad the preprocessed rows to 8 rows.
121        // gpu code breaks for small traces
122        pad_rows_fixed(
123            &mut rows,
124            || [BabyBear::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
125            Some(PUB_VALUES_LOG_HEIGHT),
126        );
127
128        let trace = RowMajorMatrix::new(
129            unsafe {
130                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
131                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
132                )
133            },
134            NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
135        );
136        Some(trace)
137    }
138
139    #[cfg(not(feature = "sys"))]
140    fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
141        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
142    }
143
144    #[cfg(feature = "sys")]
145    fn generate_trace(
146        &self,
147        input: &ExecutionRecord<F>,
148        _: &mut ExecutionRecord<F>,
149    ) -> RowMajorMatrix<F> {
150        assert_eq!(
151            std::any::TypeId::of::<F>(),
152            std::any::TypeId::of::<BabyBear>(),
153            "generate_trace only supports BabyBear field"
154        );
155
156        if input.commit_pv_hash_events.len() != 1 {
157            tracing::warn!("Expected exactly one CommitPVHash event.");
158        }
159
160        let mut rows: Vec<[BabyBear; NUM_PUBLIC_VALUES_COLS]> = Vec::new();
161
162        // We only take 1 commit pv hash instruction, since our air only checks for one public
163        // values hash.
164        for event in input.commit_pv_hash_events.iter().take(1) {
165            let bb_event = unsafe {
166                std::mem::transmute::<&CommitPublicValuesEvent<F>, &CommitPublicValuesEvent<BabyBear>>(
167                    event,
168                )
169            };
170            for i in 0..DIGEST_SIZE {
171                let mut row = [BabyBear::zero(); NUM_PUBLIC_VALUES_COLS];
172                let cols: &mut PublicValuesCols<BabyBear> = row.as_mut_slice().borrow_mut();
173                unsafe {
174                    crate::sys::public_values_event_to_row_babybear(bb_event, i, cols);
175                }
176                rows.push(row);
177            }
178        }
179
180        // Pad the trace to 8 rows.
181        pad_rows_fixed(
182            &mut rows,
183            || [BabyBear::zero(); NUM_PUBLIC_VALUES_COLS],
184            Some(PUB_VALUES_LOG_HEIGHT),
185        );
186
187        // Convert the trace to a row major matrix.
188        RowMajorMatrix::new(
189            unsafe {
190                std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
191                    rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
192                )
193            },
194            NUM_PUBLIC_VALUES_COLS,
195        )
196    }
197
198    fn included(&self, _record: &Self::Record) -> bool {
199        true
200    }
201}
202
203impl<AB> Air<AB> for PublicValuesChip
204where
205    AB: SP1RecursionAirBuilder + PairBuilder,
206{
207    fn eval(&self, builder: &mut AB) {
208        let main = builder.main();
209        let local = main.row_slice(0);
210        let local: &PublicValuesCols<AB::Var> = (*local).borrow();
211        let prepr = builder.preprocessed();
212        let local_prepr = prepr.row_slice(0);
213        let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
214        let pv = builder.public_values();
215        let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
216            core::array::from_fn(|i| pv[i].into());
217        let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
218
219        // Constrain mem read for the public value element.
220        builder.send_single(local_prepr.pv_mem.addr, local.pv_element, local_prepr.pv_mem.mult);
221
222        for (i, pv_elm) in public_values.digest.iter().enumerate() {
223            // Ensure that the public value element is the same for all rows within a fri fold
224            // invocation.
225            builder.when(local_prepr.pv_idx[i]).assert_eq(pv_elm.clone(), local.pv_element);
226        }
227    }
228}
229
#[cfg(all(test, feature = "sys"))]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH, RECURSIVE_PROOF_NUM_PV_ELTS},
        chips::{
            mem::MemoryAccessCols,
            public_values::{
                PublicValuesChip, PublicValuesCols, PublicValuesPreprocessedCols,
                NUM_PUBLIC_VALUES_COLS, NUM_PUBLIC_VALUES_PREPROCESSED_COLS, PUB_VALUES_LOG_HEIGHT,
            },
            test_fixtures,
        },
        machine::tests::test_recursion_linear_program,
        runtime::{instruction as instr, ExecutionRecord},
        stark::BabyBearPoseidon2Outer,
        Instruction, MemAccessKind, RecursionProgram, DIGEST_SIZE,
    };
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::{dense::RowMajorMatrix, Matrix};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::{pad_rows_fixed, setup_logger};
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::{
        array,
        borrow::{Borrow, BorrowMut},
    };

    #[test]
    fn prove_babybear_circuit_public_values() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|_| random_felt());
        let public_values_a: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|i| i as u32);

        let mut instructions = Vec::new();
        // Write every public value to memory (address `i` holds element `i`). Only the
        // `DIGEST_SIZE` elements starting at `NUM_PV_ELMS_TO_HASH` get multiplicity 1, since
        // they are the ones read back. NOTE(review): presumably this range is the `digest`
        // field of `RecursionPublicValues`.
        for i in 0..RECURSIVE_PROOF_NUM_PV_ELTS {
            let mult = (NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE).contains(&i);
            instructions.push(instr::mem_block(
                MemAccessKind::Write,
                u32::from(mult),
                public_values_a[i],
                random_pv_elms[i].into(),
            ));
        }
        let public_values_a: &RecursionPublicValues<u32> = public_values_a.as_slice().borrow();
        instructions.push(instr::commit_public_values(public_values_a));

        test_recursion_linear_program(instructions);
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_public_values_preprocessed_trace() {
        let program = test_fixtures::program();

        let chip = PublicValuesChip;
        let trace = chip.generate_preprocessed_trace(&program).unwrap();
        println!("{:?}", trace.values);
    }

    /// Pure-Rust reference for `PublicValuesChip::generate_trace`.
    ///
    /// Builds one row per digest element of the first event, then zero-pads to
    /// `1 << PUB_VALUES_LOG_HEIGHT` rows.
    fn generate_trace_reference(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        if input.commit_pv_hash_events.len() != 1 {
            tracing::warn!("Expected exactly one CommitPVHash event.");
        }

        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_COLS]> = Vec::new();

        // We only take 1 commit pv hash event, since our air only checks for one public
        // values hash.
        if let Some(event) = input.commit_pv_hash_events.first() {
            for element in event.public_values.digest.iter() {
                let mut row = [F::zero(); NUM_PUBLIC_VALUES_COLS];
                let cols: &mut PublicValuesCols<F> = row.as_mut_slice().borrow_mut();

                cols.pv_element = *element;
                rows.push(row);
            }
        }

        // Pad the trace to `1 << PUB_VALUES_LOG_HEIGHT` (= 16) rows.
        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_PUBLIC_VALUES_COLS],
            Some(PUB_VALUES_LOG_HEIGHT),
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_PUBLIC_VALUES_COLS)
    }

    #[test]
    fn test_generate_trace() {
        let shard = test_fixtures::shard();
        let trace = PublicValuesChip.generate_trace(&shard, &mut ExecutionRecord::default());
        assert_eq!(trace.height(), 16);

        assert_eq!(trace, generate_trace_reference(&shard, &mut ExecutionRecord::default()));
    }

    /// Pure-Rust reference for `PublicValuesChip::generate_preprocessed_trace`.
    ///
    /// Row `i` selects digest position `i` and reads the digest address with multiplicity -1.
    /// The trace is then zero-padded to `1 << PUB_VALUES_LOG_HEIGHT` rows.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
        let commit_pv_hash_instrs = program
            .inner
            .iter()
            .filter_map(|instruction| {
                if let Instruction::CommitPublicValues(instr) = instruction {
                    Some(instr)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        if commit_pv_hash_instrs.len() != 1 {
            tracing::warn!("Expected exactly one CommitPVHash instruction.");
        }

        // We only take 1 commit pv hash instruction.
        if let Some(instr) = commit_pv_hash_instrs.first() {
            for (i, addr) in instr.pv_addrs.digest.iter().enumerate() {
                let mut row = [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
                let cols: &mut PublicValuesPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                cols.pv_idx[i] = F::one();
                cols.pv_mem = MemoryAccessCols { addr: *addr, mult: F::neg_one() };
                rows.push(row);
            }
        }

        // Pad the preprocessed rows to `1 << PUB_VALUES_LOG_HEIGHT` (= 16) rows.
        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
            Some(PUB_VALUES_LOG_HEIGHT),
        );

        RowMajorMatrix::new(
            rows.into_iter().flatten().collect(),
            NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
        )
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn test_generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = PublicValuesChip.generate_preprocessed_trace(&program).unwrap();
        assert_eq!(trace.height(), 16);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }
}