sp1_recursion_machine/chips/
public_values.rs1use crate::builder::SP1RecursionAirBuilder;
2use slop_air::{Air, AirBuilder, BaseAir, PairBuilder};
3use slop_algebra::PrimeField32;
4use slop_matrix::Matrix;
5use sp1_derive::AlignedBorrow;
6use sp1_hypercube::air::MachineAir;
7use sp1_primitives::SP1Field;
8use sp1_recursion_executor::{
9 ExecutionRecord, Instruction, RecursionProgram, RecursionPublicValues, DIGEST_SIZE,
10 RECURSIVE_PROOF_NUM_PV_ELTS,
11};
12use std::{
13 borrow::{Borrow, BorrowMut},
14 mem::MaybeUninit,
15};
16
17use super::mem::MemoryAccessColsChips;
18use crate::chips::mem::MemoryAccessCols;
19
20pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
21pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
22 core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();
23
24pub const PUB_VALUES_LOG_HEIGHT: usize = 4;
25
/// Chip that reads the recursion proof's public-values digest out of memory and checks each
/// word against the corresponding entry of the proof's public values.
#[derive(Clone, Default)]
pub struct PublicValuesChip;
28
29#[derive(AlignedBorrow, Debug, Clone, Copy)]
31#[repr(C)]
32pub struct PublicValuesPreprocessedCols<T: Copy> {
33 pub pv_idx: [T; DIGEST_SIZE],
34 pub pv_mem: MemoryAccessColsChips<T>,
35}
36
37#[derive(AlignedBorrow, Debug, Clone, Copy)]
39#[repr(C)]
40pub struct PublicValuesCols<T: Copy> {
41 pub pv_element: T,
42}
43
44impl<F> BaseAir<F> for PublicValuesChip {
45 fn width(&self) -> usize {
46 NUM_PUBLIC_VALUES_COLS
47 }
48}
49
50impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
51 type Record = ExecutionRecord<F>;
52
53 type Program = RecursionProgram<F>;
54
55 fn name(&self) -> &'static str {
56 "PublicValues"
57 }
58
59 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
60 }
62
63 fn preprocessed_width(&self) -> usize {
64 NUM_PUBLIC_VALUES_PREPROCESSED_COLS
65 }
66
67 fn num_rows(&self, _: &Self::Record) -> Option<usize> {
68 Some(1 << PUB_VALUES_LOG_HEIGHT)
69 }
70
71 fn preprocessed_num_rows(&self, _program: &Self::Program) -> Option<usize> {
72 Some(1 << PUB_VALUES_LOG_HEIGHT)
73 }
74
75 fn preprocessed_num_rows_with_instrs_len(&self, _: &Self::Program, _: usize) -> Option<usize> {
76 Some(1 << PUB_VALUES_LOG_HEIGHT)
77 }
78
79 fn generate_preprocessed_trace_into(
80 &self,
81 program: &Self::Program,
82 buffer: &mut [MaybeUninit<F>],
83 ) {
84 assert_eq!(
85 std::any::TypeId::of::<F>(),
86 std::any::TypeId::of::<SP1Field>(),
87 "generate_preprocessed_trace only supports SP1Field field"
88 );
89
90 let padded_nb_rows = self.preprocessed_num_rows(program).unwrap();
91
92 unsafe {
93 let padding_size = padded_nb_rows * NUM_PUBLIC_VALUES_PREPROCESSED_COLS;
94 core::ptr::write_bytes(buffer.as_mut_ptr(), 0, padding_size);
95 }
96
97 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
98 let values = unsafe {
99 core::slice::from_raw_parts_mut(
100 buffer_ptr,
101 padded_nb_rows * NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
102 )
103 };
104
105 let commit_pv_hash_instrs = program
106 .inner
107 .iter()
108 .filter_map(|instruction| {
109 if let Instruction::CommitPublicValues(instr) = instruction.inner() {
110 Some(instr)
111 } else {
112 None
113 }
114 })
115 .collect::<Vec<_>>();
116
117 if commit_pv_hash_instrs.len() != 1 {
118 tracing::warn!("Expected exactly one CommitPVHash instruction.");
119 }
120
121 for instr in commit_pv_hash_instrs.iter().take(1) {
124 for (i, addr) in instr.pv_addrs.digest.iter().enumerate() {
125 let start = i * NUM_PUBLIC_VALUES_PREPROCESSED_COLS;
126 let end = (i + 1) * NUM_PUBLIC_VALUES_PREPROCESSED_COLS;
127 let cols: &mut PublicValuesPreprocessedCols<F> = values[start..end].borrow_mut();
128 cols.pv_idx[i] = F::one();
129 cols.pv_mem = MemoryAccessCols { addr: *addr, mult: F::one() };
130 }
131 }
132 }
133
134 fn generate_trace_into(
135 &self,
136 input: &ExecutionRecord<F>,
137 _: &mut ExecutionRecord<F>,
138 buffer: &mut [MaybeUninit<F>],
139 ) {
140 assert_eq!(
141 std::any::TypeId::of::<F>(),
142 std::any::TypeId::of::<SP1Field>(),
143 "generate_trace_into only supports SP1Field"
144 );
145 let padded_nb_rows = <PublicValuesChip as MachineAir<F>>::num_rows(self, input).unwrap();
146
147 unsafe {
148 let padding_size = padded_nb_rows * NUM_PUBLIC_VALUES_COLS;
149 core::ptr::write_bytes(buffer.as_mut_ptr(), 0, padding_size);
150 }
151
152 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
153 let values = unsafe {
154 core::slice::from_raw_parts_mut(buffer_ptr, padded_nb_rows * NUM_PUBLIC_VALUES_COLS)
155 };
156
157 for event in input.commit_pv_hash_events.iter().take(1) {
158 for (idx, element) in event.public_values.digest.iter().enumerate() {
159 let start = idx * NUM_PUBLIC_VALUES_COLS;
160 let end = (idx + 1) * NUM_PUBLIC_VALUES_COLS;
161 let cols: &mut PublicValuesCols<F> = values[start..end].borrow_mut();
162 cols.pv_element = *element;
163 }
164 }
165 }
166
167 fn included(&self, _record: &Self::Record) -> bool {
168 true
169 }
170}
171
172impl<AB> Air<AB> for PublicValuesChip
173where
174 AB: SP1RecursionAirBuilder + PairBuilder,
175{
176 fn eval(&self, builder: &mut AB) {
177 let main = builder.main();
178 let local = main.row_slice(0);
179 let local: &PublicValuesCols<AB::Var> = (*local).borrow();
180 let prepr = builder.preprocessed();
181 let local_prepr = prepr.row_slice(0);
182 let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
183 let pv = builder.public_values();
184 let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
185 core::array::from_fn(|i| pv[i].into());
186 let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
187
188 builder.receive_single(local_prepr.pv_mem.addr, local.pv_element, local_prepr.pv_mem.mult);
190
191 for (i, pv_elm) in public_values.digest.iter().enumerate() {
192 builder.when(local_prepr.pv_idx[i]).assert_eq(pv_elm.clone(), local.pv_element);
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        chips::{public_values::PublicValuesChip, test_fixtures},
        test::test_recursion_linear_program,
    };
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use slop_algebra::AbstractField;

    use slop_challenger::IopCtx;
    use slop_matrix::Matrix;
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::air::MachineAir;
    use sp1_primitives::SP1GlobalContext;
    use sp1_recursion_executor::{
        instruction as instr, ExecutionRecord, MemAccessKind, RecursionPublicValues, DIGEST_SIZE,
        NUM_PV_ELMS_TO_HASH, RECURSIVE_PROOF_NUM_PV_ELTS,
    };
    use std::{array, borrow::Borrow};

    /// Writes random public values to addresses `0..RECURSIVE_PROOF_NUM_PV_ELTS`, commits them,
    /// and proves the resulting program end to end.
    #[tokio::test]
    async fn prove_koalabear_circuit_public_values() {
        setup_logger();
        type F = <SP1GlobalContext as IopCtx>::F;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|_| random_felt());
        let addrs: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|i| i as u32);

        // Only the words in this range are received (once each) by the public values chip, so
        // only their writes carry multiplicity 1. Presumably this is where `digest` sits inside
        // `RecursionPublicValues` — confirm if the struct layout changes.
        let digest_range = NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE;
        let mut instructions = addrs
            .iter()
            .zip(random_pv_elms.iter())
            .enumerate()
            .map(|(i, (&addr, &val))| {
                let mult = u32::from(digest_range.contains(&i));
                instr::mem_block(MemAccessKind::Write, mult, addr, val.into())
            })
            .collect::<Vec<_>>();

        let pv_addrs: &RecursionPublicValues<u32> = addrs.as_slice().borrow();
        instructions.push(instr::commit_public_values(pv_addrs));

        test_recursion_linear_program(instructions).await;
    }

    /// The main trace always has the fixed padded height.
    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let main_trace = PublicValuesChip.generate_trace(shard, &mut ExecutionRecord::default());
        assert_eq!(main_trace.height(), 16);
    }

    /// The preprocessed trace always has the fixed padded height.
    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let program = &test_fixtures::program_with_input().await.0;
        let preprocessed_trace = PublicValuesChip.generate_preprocessed_trace(program).unwrap();
        assert_eq!(preprocessed_trace.height(), 16);
    }
}