sp1_recursion_core_v2/chips/
public_values.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use sp1_core_machine::utils::pad_rows_fixed;
7use sp1_derive::AlignedBorrow;
8use sp1_recursion_core::air::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};
9use sp1_stark::air::MachineAir;
10
11use crate::{
12    builder::SP1RecursionAirBuilder,
13    runtime::{Instruction, RecursionProgram},
14    ExecutionRecord,
15};
16
17use crate::DIGEST_SIZE;
18
19use super::mem::MemoryAccessCols;
20
21pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
22pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
23    core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();
24
/// Stateless chip whose trace holds the digest elements of the single public-values commit that
/// the AIR checks against the proof's public values.
#[derive(Default)]
pub struct PublicValuesChip {}
27
28/// The preprocessed columns for the CommitPVHash instruction.
29#[derive(AlignedBorrow, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct PublicValuesPreprocessedCols<T: Copy> {
32    pub pv_idx: [T; DIGEST_SIZE],
33    pub pv_mem: MemoryAccessCols<T>,
34}
35
36/// The cols for a CommitPVHash invocation.
37#[derive(AlignedBorrow, Debug, Clone, Copy)]
38#[repr(C)]
39pub struct PublicValuesCols<T: Copy> {
40    pub pv_element: T,
41}
42
43impl<F> BaseAir<F> for PublicValuesChip {
44    fn width(&self) -> usize {
45        NUM_PUBLIC_VALUES_COLS
46    }
47}
48
49impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
50    type Record = ExecutionRecord<F>;
51
52    type Program = RecursionProgram<F>;
53
54    fn name(&self) -> String {
55        "PublicValues".to_string()
56    }
57
58    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
59        // This is a no-op.
60    }
61
62    fn preprocessed_width(&self) -> usize {
63        NUM_PUBLIC_VALUES_PREPROCESSED_COLS
64    }
65
66    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
67        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
68        let commit_pv_hash_instrs = program
69            .instructions
70            .iter()
71            .filter_map(|instruction| {
72                if let Instruction::CommitPublicValues(instr) = instruction {
73                    Some(instr)
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<_>>();
79
80        if commit_pv_hash_instrs.len() != 1 {
81            tracing::warn!("Expected exactly one CommitPVHash instruction.");
82        }
83
84        // We only take 1 commit pv hash instruction, since our air only checks for one public
85        // values hash.
86        for instr in commit_pv_hash_instrs.iter().take(1) {
87            for (i, addr) in instr.pv_addrs.digest.iter().enumerate() {
88                let mut row = [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
89                let cols: &mut PublicValuesPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
90                cols.pv_idx[i] = F::one();
91                cols.pv_mem = MemoryAccessCols { addr: *addr, mult: F::neg_one() };
92                rows.push(row);
93            }
94        }
95
96        // Pad the preprocessed rows to 8 rows.
97        pad_rows_fixed(&mut rows, || [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS], Some(3));
98
99        let trace = RowMajorMatrix::new(
100            rows.into_iter().flatten().collect(),
101            NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
102        );
103        Some(trace)
104    }
105
106    fn generate_trace(
107        &self,
108        input: &ExecutionRecord<F>,
109        _: &mut ExecutionRecord<F>,
110    ) -> RowMajorMatrix<F> {
111        if input.commit_pv_hash_events.len() != 1 {
112            tracing::warn!("Expected exactly one CommitPVHash event.");
113        }
114
115        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_COLS]> = Vec::new();
116
117        // We only take 1 commit pv hash instruction, since our air only checks for one public
118        // values hash.
119        for event in input.commit_pv_hash_events.iter().take(1) {
120            for element in event.public_values.digest.iter() {
121                let mut row = [F::zero(); NUM_PUBLIC_VALUES_COLS];
122                let cols: &mut PublicValuesCols<F> = row.as_mut_slice().borrow_mut();
123
124                cols.pv_element = *element;
125                rows.push(row);
126            }
127        }
128
129        // Pad the trace to 8 rows.
130        pad_rows_fixed(&mut rows, || [F::zero(); NUM_PUBLIC_VALUES_COLS], Some(3));
131
132        // Convert the trace to a row major matrix.
133        RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_PUBLIC_VALUES_COLS)
134    }
135
136    fn included(&self, _record: &Self::Record) -> bool {
137        true
138    }
139}
140
141impl<AB> Air<AB> for PublicValuesChip
142where
143    AB: SP1RecursionAirBuilder + PairBuilder,
144{
145    fn eval(&self, builder: &mut AB) {
146        let main = builder.main();
147        let local = main.row_slice(0);
148        let local: &PublicValuesCols<AB::Var> = (*local).borrow();
149        let prepr = builder.preprocessed();
150        let local_prepr = prepr.row_slice(0);
151        let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
152        let pv = builder.public_values();
153        let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
154            core::array::from_fn(|i| pv[i].into());
155        let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
156
157        // Constrain mem read for the public value element.
158        builder.send_single(local_prepr.pv_mem.addr, local.pv_element, local_prepr.pv_mem.mult);
159
160        for (i, pv_elm) in public_values.digest.iter().enumerate() {
161            // Ensure that the public value element is the same for all rows within a fri fold
162            // invocation.
163            builder.when(local_prepr.pv_idx[i]).assert_eq(pv_elm.clone(), local.pv_element);
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::setup_logger;
    use sp1_recursion_core::{
        air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH, RECURSIVE_PROOF_NUM_PV_ELTS},
        stark::config::BabyBearPoseidon2Outer,
    };
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::{array, borrow::Borrow};

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use crate::{
        chips::public_values::PublicValuesChip,
        machine::tests::run_recursion_test_machines,
        runtime::{instruction as instr, ExecutionRecord},
        CommitPublicValuesEvent, MemAccessKind, RecursionProgram, DIGEST_SIZE,
    };

    /// Writes random public values to memory, commits them, and runs the program through the
    /// recursion test machines.
    #[test]
    fn prove_babybear_circuit_public_values() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        // Deterministic random value for every public value slot.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] =
            array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 16)));

        // Slot `i` is stored at address `base_addr + i`.
        let base_addr = 0u32;
        let pv_addrs: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] =
            array::from_fn(|i| i as u32 + base_addr);

        // Only the digest slots get a multiplicity of one on their memory write; every other
        // slot is written with multiplicity zero.
        let digest_slots = NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE;

        let mut instructions = Vec::new();
        instructions.extend(pv_addrs.iter().zip(random_pv_elms.iter()).enumerate().map(
            |(i, (&addr, &value))| {
                instr::mem_block(
                    MemAccessKind::Write,
                    digest_slots.contains(&i) as u32,
                    addr,
                    value.into(),
                )
            },
        ));

        // Commit the public values located at the addresses written above.
        let pv_addr_layout: &RecursionPublicValues<u32> = pv_addrs.as_slice().borrow();
        instructions.push(instr::commit_public_values(pv_addr_layout));

        run_recursion_test_machines(RecursionProgram { instructions, ..Default::default() });
    }

    /// Generates the main trace for a single commit event and prints its values.
    #[test]
    fn generate_public_values_circuit_trace() {
        type F = BabyBear;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let random_felts: [F; RECURSIVE_PROOF_NUM_PV_ELTS] =
            array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 16)));
        let random_public_values: &RecursionPublicValues<F> = random_felts.as_slice().borrow();

        let event = CommitPublicValuesEvent { public_values: *random_public_values };
        let shard = ExecutionRecord { commit_pv_hash_events: vec![event], ..Default::default() };

        let chip = PublicValuesChip::default();
        let mut output = ExecutionRecord::default();
        let trace: RowMajorMatrix<F> = chip.generate_trace(&shard, &mut output);
        println!("{:?}", trace.values)
    }
}