Skip to main content

sp1_recursion_machine/chips/
public_values.rs

1use crate::builder::SP1RecursionAirBuilder;
2use slop_air::{Air, AirBuilder, BaseAir, PairBuilder};
3use slop_algebra::PrimeField32;
4use slop_matrix::Matrix;
5use sp1_derive::AlignedBorrow;
6use sp1_hypercube::air::MachineAir;
7use sp1_primitives::SP1Field;
8use sp1_recursion_executor::{
9    ExecutionRecord, Instruction, RecursionProgram, RecursionPublicValues, DIGEST_SIZE,
10    RECURSIVE_PROOF_NUM_PV_ELTS,
11};
12use std::{
13    borrow::{Borrow, BorrowMut},
14    mem::MaybeUninit,
15};
16
17use super::mem::MemoryAccessColsChips;
18use crate::chips::mem::MemoryAccessCols;
19
20pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
21pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
22    core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();
23
24pub const PUB_VALUES_LOG_HEIGHT: usize = 4;
25
/// Stateless chip (reported under the name `"PublicValues"`) that ties the public values
/// digest held in recursion memory to the digest exposed in the proof's public values.
#[derive(Clone, Default)]
pub struct PublicValuesChip;
28
29/// The preprocessed columns for the CommitPVHash instruction.
30#[derive(AlignedBorrow, Debug, Clone, Copy)]
31#[repr(C)]
32pub struct PublicValuesPreprocessedCols<T: Copy> {
33    pub pv_idx: [T; DIGEST_SIZE],
34    pub pv_mem: MemoryAccessColsChips<T>,
35}
36
37/// The cols for a CommitPVHash invocation.
38#[derive(AlignedBorrow, Debug, Clone, Copy)]
39#[repr(C)]
40pub struct PublicValuesCols<T: Copy> {
41    pub pv_element: T,
42}
43
44impl<F> BaseAir<F> for PublicValuesChip {
45    fn width(&self) -> usize {
46        NUM_PUBLIC_VALUES_COLS
47    }
48}
49
50impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
51    type Record = ExecutionRecord<F>;
52
53    type Program = RecursionProgram<F>;
54
55    fn name(&self) -> &'static str {
56        "PublicValues"
57    }
58
59    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
60        // This is a no-op.
61    }
62
63    fn preprocessed_width(&self) -> usize {
64        NUM_PUBLIC_VALUES_PREPROCESSED_COLS
65    }
66
67    fn num_rows(&self, _: &Self::Record) -> Option<usize> {
68        Some(1 << PUB_VALUES_LOG_HEIGHT)
69    }
70
71    fn preprocessed_num_rows(&self, _program: &Self::Program) -> Option<usize> {
72        Some(1 << PUB_VALUES_LOG_HEIGHT)
73    }
74
75    fn preprocessed_num_rows_with_instrs_len(&self, _: &Self::Program, _: usize) -> Option<usize> {
76        Some(1 << PUB_VALUES_LOG_HEIGHT)
77    }
78
79    fn generate_preprocessed_trace_into(
80        &self,
81        program: &Self::Program,
82        buffer: &mut [MaybeUninit<F>],
83    ) {
84        assert_eq!(
85            std::any::TypeId::of::<F>(),
86            std::any::TypeId::of::<SP1Field>(),
87            "generate_preprocessed_trace only supports SP1Field field"
88        );
89
90        let padded_nb_rows = self.preprocessed_num_rows(program).unwrap();
91
92        unsafe {
93            let padding_size = padded_nb_rows * NUM_PUBLIC_VALUES_PREPROCESSED_COLS;
94            core::ptr::write_bytes(buffer.as_mut_ptr(), 0, padding_size);
95        }
96
97        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
98        let values = unsafe {
99            core::slice::from_raw_parts_mut(
100                buffer_ptr,
101                padded_nb_rows * NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
102            )
103        };
104
105        let commit_pv_hash_instrs = program
106            .inner
107            .iter()
108            .filter_map(|instruction| {
109                if let Instruction::CommitPublicValues(instr) = instruction.inner() {
110                    Some(instr)
111                } else {
112                    None
113                }
114            })
115            .collect::<Vec<_>>();
116
117        if commit_pv_hash_instrs.len() != 1 {
118            tracing::warn!("Expected exactly one CommitPVHash instruction.");
119        }
120
121        // We only take 1 commit pv hash instruction, since our air only checks for one public
122        // values hash.
123        for instr in commit_pv_hash_instrs.iter().take(1) {
124            for (i, addr) in instr.pv_addrs.digest.iter().enumerate() {
125                let start = i * NUM_PUBLIC_VALUES_PREPROCESSED_COLS;
126                let end = (i + 1) * NUM_PUBLIC_VALUES_PREPROCESSED_COLS;
127                let cols: &mut PublicValuesPreprocessedCols<F> = values[start..end].borrow_mut();
128                cols.pv_idx[i] = F::one();
129                cols.pv_mem = MemoryAccessCols { addr: *addr, mult: F::one() };
130            }
131        }
132    }
133
134    fn generate_trace_into(
135        &self,
136        input: &ExecutionRecord<F>,
137        _: &mut ExecutionRecord<F>,
138        buffer: &mut [MaybeUninit<F>],
139    ) {
140        assert_eq!(
141            std::any::TypeId::of::<F>(),
142            std::any::TypeId::of::<SP1Field>(),
143            "generate_trace_into only supports SP1Field"
144        );
145        let padded_nb_rows = <PublicValuesChip as MachineAir<F>>::num_rows(self, input).unwrap();
146
147        unsafe {
148            let padding_size = padded_nb_rows * NUM_PUBLIC_VALUES_COLS;
149            core::ptr::write_bytes(buffer.as_mut_ptr(), 0, padding_size);
150        }
151
152        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
153        let values = unsafe {
154            core::slice::from_raw_parts_mut(buffer_ptr, padded_nb_rows * NUM_PUBLIC_VALUES_COLS)
155        };
156
157        for event in input.commit_pv_hash_events.iter().take(1) {
158            for (idx, element) in event.public_values.digest.iter().enumerate() {
159                let start = idx * NUM_PUBLIC_VALUES_COLS;
160                let end = (idx + 1) * NUM_PUBLIC_VALUES_COLS;
161                let cols: &mut PublicValuesCols<F> = values[start..end].borrow_mut();
162                cols.pv_element = *element;
163            }
164        }
165    }
166
167    fn included(&self, _record: &Self::Record) -> bool {
168        true
169    }
170}
171
172impl<AB> Air<AB> for PublicValuesChip
173where
174    AB: SP1RecursionAirBuilder + PairBuilder,
175{
176    fn eval(&self, builder: &mut AB) {
177        let main = builder.main();
178        let local = main.row_slice(0);
179        let local: &PublicValuesCols<AB::Var> = (*local).borrow();
180        let prepr = builder.preprocessed();
181        let local_prepr = prepr.row_slice(0);
182        let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
183        let pv = builder.public_values();
184        let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
185            core::array::from_fn(|i| pv[i].into());
186        let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
187
188        // Constrain mem read for the public value element.
189        builder.receive_single(local_prepr.pv_mem.addr, local.pv_element, local_prepr.pv_mem.mult);
190
191        for (i, pv_elm) in public_values.digest.iter().enumerate() {
192            builder.when(local_prepr.pv_idx[i]).assert_eq(pv_elm.clone(), local.pv_element);
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        chips::{public_values::PublicValuesChip, test_fixtures},
        test::test_recursion_linear_program,
    };
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use slop_algebra::AbstractField;

    use slop_challenger::IopCtx;
    use slop_matrix::Matrix;
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::air::MachineAir;
    use sp1_primitives::SP1GlobalContext;
    use sp1_recursion_executor::{
        instruction as instr, ExecutionRecord, MemAccessKind, RecursionPublicValues, DIGEST_SIZE,
        NUM_PV_ELMS_TO_HASH, RECURSIVE_PROOF_NUM_PV_ELTS,
    };
    use std::{array, borrow::Borrow};

    /// Runs a linear program that writes every public value element to memory and then commits
    /// the public values.
    #[tokio::test]
    async fn prove_koalabear_circuit_public_values() {
        setup_logger();
        type F = <SP1GlobalContext as IopCtx>::F;

        // Deterministic pseudo-random values, one per public value element.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] =
            array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 16)));
        // Element `i` lives at address `i`.
        let public_values_a: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|i| i as u32);

        // Allocate the memory for the public values hash: only the digest slots get
        // multiplicity one, every other slot gets zero.
        let digest_slots = NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE;
        let mut instructions: Vec<_> = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
            .map(|i| {
                instr::mem_block(
                    MemAccessKind::Write,
                    digest_slots.contains(&i) as u32,
                    public_values_a[i],
                    random_pv_elms[i].into(),
                )
            })
            .collect();

        let pv_addrs: &RecursionPublicValues<u32> = public_values_a.as_slice().borrow();
        instructions.push(instr::commit_public_values(pv_addrs));

        test_recursion_linear_program(instructions).await;
    }

    /// The main trace built from the fixture shard has the fixed height of 16 rows.
    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let mut scratch_output = ExecutionRecord::default();
        let main_trace = PublicValuesChip.generate_trace(shard, &mut scratch_output);
        assert_eq!(main_trace.height(), 16);
    }

    /// The preprocessed trace built from the fixture program has the fixed height of 16 rows.
    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let fixture = test_fixtures::program_with_input().await;
        let prep_trace = PublicValuesChip.generate_preprocessed_trace(&fixture.0).unwrap();
        assert_eq!(prep_trace.height(), 16);
    }
}