sp1_recursion_core/chips/
public_values.rs1use crate::{
2 air::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS},
3 builder::SP1RecursionAirBuilder,
4 runtime::Instruction,
5 CommitPublicValuesEvent, CommitPublicValuesInstr, ExecutionRecord, DIGEST_SIZE,
6};
7use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
8use p3_baby_bear::BabyBear;
9use p3_field::{AbstractField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use sp1_core_machine::utils::pad_rows_fixed;
12use sp1_derive::AlignedBorrow;
13use sp1_stark::air::MachineAir;
14use std::borrow::{Borrow, BorrowMut};
15
16use super::mem::MemoryAccessColsChips;
17
18pub const NUM_PUBLIC_VALUES_COLS: usize = core::mem::size_of::<PublicValuesCols<u8>>();
19pub const NUM_PUBLIC_VALUES_PREPROCESSED_COLS: usize =
20 core::mem::size_of::<PublicValuesPreprocessedCols<u8>>();
21
22pub const PUB_VALUES_LOG_HEIGHT: usize = 4;
23
/// Chip that reads the public-values digest out of memory.
///
/// It constrains each digest element to match the `digest` field of the recursion proof's
/// public values.
#[derive(Default)]
pub struct PublicValuesChip;
26
27#[derive(AlignedBorrow, Debug, Clone, Copy)]
29#[repr(C)]
30pub struct PublicValuesPreprocessedCols<T: Copy> {
31 pub pv_idx: [T; DIGEST_SIZE],
32 pub pv_mem: MemoryAccessColsChips<T>,
33}
34
35#[derive(AlignedBorrow, Debug, Clone, Copy)]
37#[repr(C)]
38pub struct PublicValuesCols<T: Copy> {
39 pub pv_element: T,
40}
41
42impl<F> BaseAir<F> for PublicValuesChip {
43 fn width(&self) -> usize {
44 NUM_PUBLIC_VALUES_COLS
45 }
46}
47
48impl<F: PrimeField32> MachineAir<F> for PublicValuesChip {
49 type Record = ExecutionRecord<F>;
50
51 type Program = crate::RecursionProgram<F>;
52
53 fn name(&self) -> String {
54 "PublicValues".to_string()
55 }
56
57 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
58 }
60
61 fn preprocessed_width(&self) -> usize {
62 NUM_PUBLIC_VALUES_PREPROCESSED_COLS
63 }
64
65 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
66 assert_eq!(
67 std::any::TypeId::of::<F>(),
68 std::any::TypeId::of::<BabyBear>(),
69 "generate_preprocessed_trace only supports BabyBear field"
70 );
71
72 let mut rows: Vec<[BabyBear; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
73 let commit_pv_hash_instrs: Vec<&Box<CommitPublicValuesInstr<BabyBear>>> = program
74 .inner
75 .iter()
76 .filter_map(|instruction| {
77 if let Instruction::CommitPublicValues(instr) = instruction {
78 Some(unsafe {
79 std::mem::transmute::<
80 &Box<CommitPublicValuesInstr<F>>,
81 &Box<CommitPublicValuesInstr<BabyBear>>,
82 >(instr)
83 })
84 } else {
85 None
86 }
87 })
88 .collect::<Vec<_>>();
89
90 if commit_pv_hash_instrs.len() != 1 {
91 tracing::warn!("Expected exactly one CommitPVHash instruction.");
92 }
93
94 for instr in commit_pv_hash_instrs.iter().take(1) {
97 for i in 0..DIGEST_SIZE {
98 let mut row = [BabyBear::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
99 let cols: &mut PublicValuesPreprocessedCols<BabyBear> =
100 row.as_mut_slice().borrow_mut();
101 unsafe {
102 crate::sys::public_values_instr_to_row_babybear(instr, i, cols);
103 }
104 rows.push(row);
105 }
106 }
107
108 pad_rows_fixed(
111 &mut rows,
112 || [BabyBear::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
113 Some(PUB_VALUES_LOG_HEIGHT),
114 );
115
116 let trace = RowMajorMatrix::new(
117 unsafe {
118 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
119 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
120 )
121 },
122 NUM_PUBLIC_VALUES_PREPROCESSED_COLS,
123 );
124 Some(trace)
125 }
126
127 fn generate_trace(
128 &self,
129 input: &ExecutionRecord<F>,
130 _: &mut ExecutionRecord<F>,
131 ) -> RowMajorMatrix<F> {
132 assert_eq!(
133 std::any::TypeId::of::<F>(),
134 std::any::TypeId::of::<BabyBear>(),
135 "generate_trace only supports BabyBear field"
136 );
137
138 if input.commit_pv_hash_events.len() != 1 {
139 tracing::warn!("Expected exactly one CommitPVHash event.");
140 }
141
142 let mut rows: Vec<[BabyBear; NUM_PUBLIC_VALUES_COLS]> = Vec::new();
143
144 for event in input.commit_pv_hash_events.iter().take(1) {
147 let bb_event = unsafe {
148 std::mem::transmute::<&CommitPublicValuesEvent<F>, &CommitPublicValuesEvent<BabyBear>>(
149 event,
150 )
151 };
152 for i in 0..DIGEST_SIZE {
153 let mut row = [BabyBear::zero(); NUM_PUBLIC_VALUES_COLS];
154 let cols: &mut PublicValuesCols<BabyBear> = row.as_mut_slice().borrow_mut();
155 unsafe {
156 crate::sys::public_values_event_to_row_babybear(bb_event, i, cols);
157 }
158 rows.push(row);
159 }
160 }
161
162 pad_rows_fixed(
164 &mut rows,
165 || [BabyBear::zero(); NUM_PUBLIC_VALUES_COLS],
166 Some(PUB_VALUES_LOG_HEIGHT),
167 );
168
169 RowMajorMatrix::new(
171 unsafe {
172 std::mem::transmute::<Vec<BabyBear>, Vec<F>>(
173 rows.into_iter().flatten().collect::<Vec<BabyBear>>(),
174 )
175 },
176 NUM_PUBLIC_VALUES_COLS,
177 )
178 }
179
180 fn included(&self, _record: &Self::Record) -> bool {
181 true
182 }
183}
184
185impl<AB> Air<AB> for PublicValuesChip
186where
187 AB: SP1RecursionAirBuilder + PairBuilder,
188{
189 fn eval(&self, builder: &mut AB) {
190 let main = builder.main();
191 let local = main.row_slice(0);
192 let local: &PublicValuesCols<AB::Var> = (*local).borrow();
193 let prepr = builder.preprocessed();
194 let local_prepr = prepr.row_slice(0);
195 let local_prepr: &PublicValuesPreprocessedCols<AB::Var> = (*local_prepr).borrow();
196 let pv = builder.public_values();
197 let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
198 core::array::from_fn(|i| pv[i].into());
199 let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
200
201 builder.send_single(local_prepr.pv_mem.addr, local.pv_element, local_prepr.pv_mem.mult);
203
204 for (i, pv_elm) in public_values.digest.iter().enumerate() {
205 builder.when(local_prepr.pv_idx[i]).assert_eq(pv_elm.clone(), local.pv_element);
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use crate::{
        air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH, RECURSIVE_PROOF_NUM_PV_ELTS},
        chips::{
            mem::MemoryAccessCols,
            public_values::{
                PublicValuesChip, PublicValuesCols, PublicValuesPreprocessedCols,
                NUM_PUBLIC_VALUES_COLS, NUM_PUBLIC_VALUES_PREPROCESSED_COLS, PUB_VALUES_LOG_HEIGHT,
            },
            test_fixtures,
        },
        machine::tests::test_recursion_linear_program,
        runtime::{instruction as instr, ExecutionRecord},
        stark::BabyBearPoseidon2Outer,
        Instruction, MemAccessKind, RecursionProgram, DIGEST_SIZE,
    };
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::{dense::RowMajorMatrix, Matrix};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_core_machine::utils::{pad_rows_fixed, setup_logger};
    use sp1_stark::{air::MachineAir, StarkGenericConfig};
    use std::{
        array,
        borrow::{Borrow, BorrowMut},
    };

    /// Height every trace of the chip is padded to.
    const PADDED_HEIGHT: usize = 1 << PUB_VALUES_LOG_HEIGHT;

    #[test]
    fn prove_babybear_circuit_public_values() {
        setup_logger();
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> F { F::from_canonical_u32(rng.gen_range(0..1 << 16)) };
        let random_pv_elms: [F; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|_| random_felt());
        let public_values_a: [u32; RECURSIVE_PROOF_NUM_PV_ELTS] = array::from_fn(|i| i as u32);

        // Write every public value to memory.
        // Only the slots in this range get multiplicity 1; presumably they hold the digest
        // that the chip reads back.
        let digest_slots = NUM_PV_ELMS_TO_HASH..NUM_PV_ELMS_TO_HASH + DIGEST_SIZE;
        let mut instructions: Vec<_> = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
            .map(|idx| {
                instr::mem_block(
                    MemAccessKind::Write,
                    digest_slots.contains(&idx) as u32,
                    public_values_a[idx],
                    random_pv_elms[idx].into(),
                )
            })
            .collect();

        let public_values_a: &RecursionPublicValues<u32> = public_values_a.as_slice().borrow();
        instructions.push(instr::commit_public_values(public_values_a));

        test_recursion_linear_program(instructions);
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_public_values_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = PublicValuesChip.generate_preprocessed_trace(&program).unwrap();
        println!("{:?}", trace.values);
    }

    /// Straightforward reimplementation of `generate_trace`, used as the expected value.
    fn generate_trace_reference(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        if input.commit_pv_hash_events.len() != 1 {
            tracing::warn!("Expected exactly one CommitPVHash event.");
        }

        // One row per digest element of the first event.
        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_COLS]> = input
            .commit_pv_hash_events
            .iter()
            .take(1)
            .flat_map(|event| event.public_values.digest.iter())
            .map(|&element| {
                let mut row = [F::zero(); NUM_PUBLIC_VALUES_COLS];
                let cols: &mut PublicValuesCols<F> = row.as_mut_slice().borrow_mut();
                cols.pv_element = element;
                row
            })
            .collect();

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_PUBLIC_VALUES_COLS],
            Some(PUB_VALUES_LOG_HEIGHT),
        );

        let flat: Vec<F> = rows.into_iter().flatten().collect();
        RowMajorMatrix::new(flat, NUM_PUBLIC_VALUES_COLS)
    }

    #[test]
    fn test_generate_trace() {
        let shard = test_fixtures::shard();
        let actual = PublicValuesChip.generate_trace(&shard, &mut ExecutionRecord::default());
        assert_eq!(actual.height(), PADDED_HEIGHT);

        let expected = generate_trace_reference(&shard, &mut ExecutionRecord::default());
        assert_eq!(actual, expected);
    }

    /// Straightforward reimplementation of `generate_preprocessed_trace`.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let commit_instrs: Vec<_> = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::CommitPublicValues(commit) => Some(commit),
                _ => None,
            })
            .collect();

        if commit_instrs.len() != 1 {
            tracing::warn!("Expected exactly one CommitPVHash instruction.");
        }

        // One row per digest address of the first instruction.
        // Each row gets a one-hot selector at its position and a memory read with
        // multiplicity -1.
        let mut rows: Vec<[F; NUM_PUBLIC_VALUES_PREPROCESSED_COLS]> = Vec::new();
        if let Some(commit) = commit_instrs.first() {
            for (slot, addr) in commit.pv_addrs.digest.iter().enumerate() {
                let mut row = [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS];
                let cols: &mut PublicValuesPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
                cols.pv_idx[slot] = F::one();
                cols.pv_mem = MemoryAccessCols { addr: *addr, mult: F::neg_one() };
                rows.push(row);
            }
        }

        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_PUBLIC_VALUES_PREPROCESSED_COLS],
            Some(PUB_VALUES_LOG_HEIGHT),
        );

        let flat: Vec<F> = rows.into_iter().flatten().collect();
        RowMajorMatrix::new(flat, NUM_PUBLIC_VALUES_PREPROCESSED_COLS)
    }

    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn test_generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let actual = PublicValuesChip.generate_preprocessed_trace(&program).unwrap();
        assert_eq!(actual.height(), PADDED_HEIGHT);

        assert_eq!(actual, generate_preprocessed_trace_reference(&program));
    }
}