1mod shapeable;
2
3pub use shapeable::*;
4
5use std::{collections::BTreeMap, marker::PhantomData, str::FromStr};
6
7use hashbrown::HashMap;
8use itertools::Itertools;
9use num::Integer;
10use p3_baby_bear::BabyBear;
11use p3_field::PrimeField32;
12use p3_util::log2_ceil_usize;
13use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program, RiscvAirId};
14use sp1_stark::{
15 air::MachineAir,
16 shape::{OrderedShape, Shape, ShapeCluster},
17};
18use thiserror::Error;
19
20use super::riscv::riscv_chips::{ByteChip, ProgramChip, SyscallChip};
21use crate::{
22 global::GlobalChip,
23 memory::{MemoryLocalChip, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW},
24 riscv::RiscvAir,
25};
26
/// Raw JSON bytes of the maximal core shapes, embedded into the binary at compile time.
///
/// `CoreShapeConfig::default` deserializes this as a map from log2 shard size to a list of
/// maximal shapes and derives one cluster of allowed log2 heights from each of them.
const MAXIMAL_SHAPES: &[u8] = include_bytes!("maximal_shapes.json");

/// Raw JSON bytes of the small shapes, embedded into the binary at compile time.
///
/// `CoreShapeConfig::default` deserializes this as a list of shapes; each one becomes a cluster
/// with exactly one allowed log2 height per chip.
const SMALL_SHAPES: &[u8] = include_bytes!("small_shapes.json");
38
/// Describes which chip shapes (log2 heights per chip) are allowed for RISC-V shards and how
/// expensive each chip is, so that a record can be lifted to the cheapest allowed shape.
#[derive(Debug)]
pub struct CoreShapeConfig<F: PrimeField32> {
    /// Allowed log2 heights of the preprocessed chips; consulted by `fix_preprocessed_shape`.
    partial_preprocessed_shapes: ShapeCluster<RiscvAirId>,
    /// Clusters of allowed log2 heights for core shards, keyed by log2 shard size. `find_shape`
    /// searches every entry whose key is at least the record's log2 shard size.
    partial_core_shapes: BTreeMap<usize, Vec<ShapeCluster<RiscvAirId>>>,
    /// Allowed log2 heights for global-memory shards.
    partial_memory_shapes: ShapeCluster<RiscvAirId>,
    /// Per precompile chip: `(memory events per row, allowed log2 heights of the chip)`.
    partial_precompile_shapes: HashMap<RiscvAirId, (usize, Vec<usize>)>,
    /// Clusters searched for packed-core shards (core heights plus memory heights).
    partial_small_shapes: Vec<ShapeCluster<RiscvAirId>>,
    /// Per-chip cost factor; multiplied by the chip height in `estimate_lde_size`.
    costs: HashMap<RiscvAirId, usize>,
    /// Ties the configuration to the field `F` without storing a value of it.
    _data: PhantomData<F>,
}
50
51impl<F: PrimeField32> CoreShapeConfig<F> {
52 pub fn fix_preprocessed_shape(&self, program: &mut Program) -> Result<(), CoreShapeError> {
54 if program.preprocessed_shape.is_some() {
56 return Err(CoreShapeError::PreprocessedShapeAlreadyFixed);
57 }
58
59 let preprocessed_heights = RiscvAir::<F>::preprocessed_heights(program);
61 let preprocessed_shape = self
62 .partial_preprocessed_shapes
63 .find_shape(&preprocessed_heights)
64 .ok_or(CoreShapeError::PreprocessedShapeError)?;
65
66 program.preprocessed_shape = Some(preprocessed_shape);
68
69 Ok(())
70 }
71
72 pub fn fix_shape(&self, record: &mut ExecutionRecord) -> Result<(), CoreShapeError> {
74 if record.program.preprocessed_shape.is_none() {
75 return Err(CoreShapeError::PreprocessedShapeMissing);
76 }
77 if record.shape.is_some() {
78 return Err(CoreShapeError::ShapeAlreadyFixed);
79 }
80
81 record.shape.clone_from(&record.program.preprocessed_shape);
84
85 let shape = self.find_shape(record)?;
86 record.shape.as_mut().unwrap().extend(shape);
87 Ok(())
88 }
89
90 pub fn find_shape<R: Shapeable>(
92 &self,
93 record: &R,
94 ) -> Result<Shape<RiscvAirId>, CoreShapeError> {
95 match record.kind() {
96 ShardKind::PackedCore => {
99 let mut heights = record.core_heights();
101 heights.extend(record.memory_heights());
102
103 let (cluster_index, shape, _) = self
104 .minimal_cluster_shape(self.partial_small_shapes.iter().enumerate(), &heights)
105 .ok_or_else(|| {
106 CoreShapeError::ShapeError(
108 heights
109 .iter()
110 .map(|(air, height)| (air.to_string(), log2_ceil_usize(*height)))
111 .collect(),
112 )
113 })?;
114
115 let shard = record.shard();
116 tracing::debug!("Shard Lifted: Index={}, Cluster={}", shard, cluster_index);
117 for (air, height) in heights.iter() {
118 if shape.contains(air) {
119 tracing::debug!(
120 "Chip {:<20}: {:<3} -> {:<3}",
121 air,
122 log2_ceil_usize(*height),
123 shape.log2_height(air).unwrap(),
124 );
125 }
126 }
127 Ok(shape)
128 }
129 ShardKind::Core => {
130 let heights = record.core_heights();
134
135 let log2_shard_size = record.log2_shard_size();
138
139 let (cluster_index, shape, _) = self
140 .minimal_cluster_shape(
141 self.partial_core_shapes
142 .range(log2_shard_size..)
143 .flat_map(|(_, clusters)| clusters.iter().enumerate()),
144 &heights,
145 )
146 .ok_or_else(|| CoreShapeError::ShapeError(record.debug_stats()))?;
148
149 let shard = record.shard();
150 tracing::debug!("Shard Lifted: Index={}, Cluster={}", shard, cluster_index);
151
152 for (air, height) in heights.iter() {
153 if shape.contains(air) {
154 tracing::debug!(
155 "Chip {:<20}: {:<3} -> {:<3}",
156 air,
157 log2_ceil_usize(*height),
158 shape.log2_height(air).unwrap(),
159 );
160 }
161 }
162 Ok(shape)
163 }
164 ShardKind::GlobalMemory => {
165 let heights = record.memory_heights();
168 let shape = self
169 .partial_memory_shapes
170 .find_shape(&heights)
171 .ok_or(CoreShapeError::ShapeError(record.debug_stats()))?;
172 Ok(shape)
173 }
174 ShardKind::Precompile => {
175 for (&air, (memory_events_per_row, allowed_log2_heights)) in
177 self.partial_precompile_shapes.iter()
178 {
179 let Some((height, num_memory_local_events, num_global_events)) =
181 record.precompile_heights().find_map(|x| (x.0 == air).then_some(x.1))
182 else {
183 continue;
184 };
185 for allowed_log2_height in allowed_log2_heights {
186 let allowed_height = 1 << allowed_log2_height;
187 if height <= allowed_height {
188 for shape in self.get_precompile_shapes(
189 air,
190 *memory_events_per_row,
191 *allowed_log2_height,
192 ) {
193 let mem_events_height = shape[2].1;
194 let global_events_height = shape[3].1;
195 if num_memory_local_events
196 .div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW) <=
197 (1 << mem_events_height) &&
198 num_global_events <= (1 << global_events_height)
199 {
200 let mut actual_shape: Shape<RiscvAirId> = Shape::default();
201 actual_shape.extend(
202 shape
203 .iter()
204 .map(|x| (RiscvAirId::from_str(&x.0).unwrap(), x.1)),
205 );
206 return Ok(actual_shape);
207 }
208 }
209 }
210 }
211 tracing::error!(
212 "Cannot find shape for precompile {:?}, height {:?}, and mem events {:?}",
213 air,
214 height,
215 num_memory_local_events
216 );
217 return Err(CoreShapeError::ShapeError(record.debug_stats()));
218 }
219 Err(CoreShapeError::PrecompileNotIncluded(record.debug_stats()))
220 }
221 }
222 }
223
224 pub fn minimal_cluster_shape<'a, N, I>(
227 &self,
228 indexed_shape_clusters: I,
229 heights: &[(RiscvAirId, usize)],
230 ) -> Option<(N, Shape<RiscvAirId>, usize)>
231 where
232 I: IntoIterator<Item = (N, &'a ShapeCluster<RiscvAirId>)>,
233 {
234 indexed_shape_clusters
236 .into_iter()
237 .filter_map(|(i, cluster)| {
238 let shape = cluster.find_shape(heights)?;
239 let area = self.estimate_lde_size(&shape);
240 Some((i, shape, area))
241 })
242 .min_by_key(|x| x.2) }
244
245 fn get_precompile_shapes(
247 &self,
248 air_id: RiscvAirId,
249 memory_events_per_row: usize,
250 allowed_log2_height: usize,
251 ) -> Vec<[(String, usize); 4]> {
252 (1..=4 * air_id.rows_per_event())
254 .rev()
255 .map(|rows_per_event| {
256 let num_local_mem_events =
257 ((1 << allowed_log2_height) * memory_events_per_row).div_ceil(rows_per_event);
258 [
259 (air_id.to_string(), allowed_log2_height),
260 (
261 RiscvAir::<F>::SyscallPrecompile(SyscallChip::precompile()).name(),
262 ((1 << allowed_log2_height)
263 .div_ceil(&air_id.rows_per_event())
264 .next_power_of_two()
265 .ilog2() as usize)
266 .max(4),
267 ),
268 (
269 RiscvAir::<F>::MemoryLocal(MemoryLocalChip::new()).name(),
270 (num_local_mem_events
271 .div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
272 .next_power_of_two()
273 .ilog2() as usize)
274 .max(4),
275 ),
276 (
277 RiscvAir::<F>::Global(GlobalChip).name(),
278 ((2 * num_local_mem_events +
279 (1 << allowed_log2_height).div_ceil(&air_id.rows_per_event()))
280 .next_power_of_two()
281 .ilog2() as usize)
282 .max(4),
283 ),
284 ]
285 })
286 .filter(|shape| shape[3].1 <= 22)
287 .collect::<Vec<_>>()
288 }
289
290 fn generate_all_shapes_from_allowed_log_heights(
291 allowed_log_heights: impl IntoIterator<Item = (String, Vec<Option<usize>>)>,
292 ) -> impl Iterator<Item = OrderedShape> {
293 allowed_log_heights
294 .into_iter()
295 .map(|(name, heights)| heights.into_iter().map(move |height| (name.clone(), height)))
296 .multi_cartesian_product()
297 .map(|iter| {
298 iter.into_iter()
299 .filter_map(|(name, maybe_height)| {
300 maybe_height.map(|log_height| (name, log_height))
301 })
302 .collect::<OrderedShape>()
303 })
304 }
305
/// Returns a lazy iterator over every shape this configuration can produce.
///
/// The iterator yields, in this order:
/// 1. for each core cluster and then each small-shape cluster, every combination of its
///    allowed heights together with the allowed preprocessed heights;
/// 2. every combination of the allowed memory heights together with the preprocessed heights;
/// 3. every combination of preprocessed heights chained with every precompile shape.
///
/// The same shape may be yielded more than once; nothing here removes duplicates.
pub fn all_shapes(&self) -> impl Iterator<Item = OrderedShape> + '_ {
    // Allowed heights of the preprocessed chips, keyed by chip name.
    let preprocessed_heights = self
        .partial_preprocessed_shapes
        .iter()
        .map(|(air, heights)| (air.to_string(), heights.clone()))
        .collect::<HashMap<_, _>>();

    // Memory shards also contain the preprocessed chips.
    let mut memory_heights = self
        .partial_memory_shapes
        .iter()
        .map(|(air, heights)| (air.to_string(), heights.clone()))
        .collect::<HashMap<_, _>>();
    memory_heights.extend(preprocessed_heights.clone());

    // Every precompile shape for every configured precompile and allowed height, still without
    // the preprocessed chips.
    let precompile_only_shapes = self.partial_precompile_shapes.iter().flat_map(
        move |(&air, (mem_events_per_row, allowed_log_heights))| {
            allowed_log_heights.iter().flat_map(move |allowed_log_height| {
                self.get_precompile_shapes(air, *mem_events_per_row, *allowed_log_height)
            })
        },
    );

    // Pair each preprocessed combination with each precompile shape. The inner iterator is
    // cloned so that it can be replayed for every preprocessed combination.
    let precompile_shapes =
        Self::generate_all_shapes_from_allowed_log_heights(preprocessed_heights.clone())
            .flat_map(move |preprocessed_shape| {
                precompile_only_shapes.clone().map(move |precompile_shape| {
                    preprocessed_shape
                        .clone()
                        .into_iter()
                        .chain(precompile_shape)
                        .collect::<OrderedShape>()
                })
            });

    self.partial_core_shapes
        .values()
        .flatten()
        .chain(self.partial_small_shapes.iter())
        .flat_map(move |allowed_log_heights| {
            Self::generate_all_shapes_from_allowed_log_heights({
                // Cluster heights plus the preprocessed heights, keyed by chip name.
                let mut log_heights = allowed_log_heights
                    .iter()
                    .map(|(air, heights)| (air.to_string(), heights.clone()))
                    .collect::<HashMap<_, _>>();
                log_heights.extend(preprocessed_heights.clone());
                log_heights
            })
        })
        .chain(Self::generate_all_shapes_from_allowed_log_heights(memory_heights))
        .chain(precompile_shapes)
}
357
358 pub fn maximal_core_shapes(&self, max_log_shard_size: usize) -> Vec<Shape<RiscvAirId>> {
359 let max_shard_size: usize = core::cmp::max(
360 1 << max_log_shard_size,
361 1 << self.partial_core_shapes.keys().min().unwrap(),
362 );
363
364 let log_shard_size = max_shard_size.ilog2() as usize;
365 debug_assert_eq!(1 << log_shard_size, max_shard_size);
366 let max_preprocessed = self
367 .partial_preprocessed_shapes
368 .iter()
369 .map(|(air, allowed_heights)| {
370 (air.to_string(), allowed_heights.last().unwrap().unwrap())
371 })
372 .collect::<HashMap<_, _>>();
373
374 let max_core_shapes =
375 self.partial_core_shapes[&log_shard_size].iter().map(|allowed_log_heights| {
376 max_preprocessed
377 .clone()
378 .into_iter()
379 .chain(allowed_log_heights.iter().flat_map(|(air, allowed_heights)| {
380 allowed_heights
381 .last()
382 .unwrap()
383 .map(|log_height| (air.to_string(), log_height))
384 }))
385 .map(|(air, log_height)| (RiscvAirId::from_str(&air).unwrap(), log_height))
386 .collect::<Shape<RiscvAirId>>()
387 });
388
389 max_core_shapes.collect()
390 }
391
392 pub fn maximal_core_plus_precompile_shapes(
393 &self,
394 max_log_shard_size: usize,
395 ) -> Vec<Shape<RiscvAirId>> {
396 let max_preprocessed = self
397 .partial_preprocessed_shapes
398 .iter()
399 .map(|(air, allowed_heights)| {
400 (air.to_string(), allowed_heights.last().unwrap().unwrap())
401 })
402 .collect::<HashMap<_, _>>();
403
404 let precompile_only_shapes = self.partial_precompile_shapes.iter().flat_map(
405 move |(&air, (mem_events_per_row, allowed_log_heights))| {
406 self.get_precompile_shapes(
407 air,
408 *mem_events_per_row,
409 *allowed_log_heights.last().unwrap(),
410 )
411 },
412 );
413
414 let precompile_shapes: Vec<Shape<RiscvAirId>> = precompile_only_shapes
415 .map(|x| {
416 max_preprocessed
417 .clone()
418 .into_iter()
419 .chain(x)
420 .map(|(air, log_height)| (RiscvAirId::from_str(&air).unwrap(), log_height))
421 .collect::<Shape<RiscvAirId>>()
422 })
423 .filter(|shape| shape.log2_height(&RiscvAirId::Global).unwrap() < 21)
424 .collect();
425
426 self.maximal_core_shapes(max_log_shard_size).into_iter().chain(precompile_shapes).collect()
427 }
428
429 pub fn estimate_lde_size(&self, shape: &Shape<RiscvAirId>) -> usize {
430 shape.iter().map(|(air, height)| self.costs[air] * (1 << height)).sum()
431 }
432
433 pub fn small_program_shapes(&self) -> Vec<OrderedShape> {
435 self.partial_small_shapes
436 .iter()
437 .map(|log_heights| {
438 OrderedShape::from_log2_heights(
439 &log_heights
440 .iter()
441 .filter(|(_, v)| v[0].is_some())
442 .map(|(k, v)| (k.to_string(), v.last().unwrap().unwrap()))
443 .chain(vec![
444 (MachineAir::<BabyBear>::name(&ProgramChip), 19),
445 (MachineAir::<BabyBear>::name(&ByteChip::default()), 16),
446 ])
447 .collect::<Vec<_>>(),
448 )
449 })
450 .collect()
451 }
452}
453
454impl<F: PrimeField32> Default for CoreShapeConfig<F> {
455 fn default() -> Self {
456 let maximal_shapes: BTreeMap<usize, Vec<Shape<RiscvAirId>>> =
458 serde_json::from_slice(MAXIMAL_SHAPES).unwrap();
459 let small_shapes: Vec<Shape<RiscvAirId>> = serde_json::from_slice(SMALL_SHAPES).unwrap();
460
461 let allowed_preprocessed_log2_heights = HashMap::from([
463 (RiscvAirId::Program, vec![Some(19), Some(20), Some(21), Some(22)]),
464 (RiscvAirId::Byte, vec![Some(16)]),
465 ]);
466
467 let blacklist = [
470 27, 33, 47, 68, 75, 102, 104, 114, 116, 118, 137, 138, 139, 144, 145, 153, 155, 157,
471 158, 169, 170, 171, 184, 185, 187, 195, 216, 243, 252, 275, 281, 282, 285,
472 ];
473 let mut core_allowed_log2_heights = BTreeMap::new();
474 for (log2_shard_size, maximal_shapes) in maximal_shapes {
475 let mut clusters = vec![];
476
477 for (i, maximal_shape) in maximal_shapes.iter().enumerate() {
478 if log2_shard_size == 21 && blacklist.contains(&i) {
482 continue;
483 }
484
485 let cluster = derive_cluster_from_maximal_shape(maximal_shape);
486 clusters.push(cluster);
487 }
488
489 core_allowed_log2_heights.insert(log2_shard_size, clusters);
490 }
491
492 let memory_allowed_log2_heights = HashMap::from(
494 [
495 (
496 RiscvAirId::MemoryGlobalInit,
497 vec![None, Some(10), Some(16), Some(18), Some(19), Some(20), Some(21)],
498 ),
499 (
500 RiscvAirId::MemoryGlobalFinalize,
501 vec![None, Some(10), Some(16), Some(18), Some(19), Some(20), Some(21)],
502 ),
503 (RiscvAirId::Global, vec![None, Some(11), Some(17), Some(19), Some(21), Some(22)]),
504 ]
505 .map(|(air, log_heights)| (air, log_heights)),
506 );
507
508 let mut precompile_allowed_log2_heights = HashMap::new();
509 let precompile_heights = (3..21).collect::<Vec<_>>();
510 for (air, memory_events_per_row) in
511 RiscvAir::<F>::precompile_airs_with_memory_events_per_row()
512 {
513 precompile_allowed_log2_heights
514 .insert(air, (memory_events_per_row, precompile_heights.clone()));
515 }
516
517 Self {
518 partial_preprocessed_shapes: ShapeCluster::new(allowed_preprocessed_log2_heights),
519 partial_core_shapes: core_allowed_log2_heights,
520 partial_memory_shapes: ShapeCluster::new(memory_allowed_log2_heights),
521 partial_precompile_shapes: precompile_allowed_log2_heights,
522 partial_small_shapes: small_shapes
523 .into_iter()
524 .map(|x| {
525 ShapeCluster::new(x.into_iter().map(|(k, v)| (k, vec![Some(v)])).collect())
526 })
527 .collect(),
528 costs: serde_json::from_str(include_str!("rv32im_costs.json"))
529 .expect("Failed to load rv32im_costs.json file. Verify that `git config core.symlinks` is not set to false."),
530 _data: PhantomData,
531 }
532 }
533}
534
535fn derive_cluster_from_maximal_shape(shape: &Shape<RiscvAirId>) -> ShapeCluster<RiscvAirId> {
536 let log2_gap_from_21 = 21 - shape.log2_height(&RiscvAirId::Cpu).unwrap();
538 let min_log2_height_threshold = 18 - log2_gap_from_21;
539 let log2_height_buffer = 10;
540 let heuristic = |maximal_log2_height: Option<usize>, min_offset: usize| {
541 if let Some(maximal_log2_height) = maximal_log2_height {
542 let tallest_log2_height = std::cmp::max(maximal_log2_height, min_log2_height_threshold);
543 let shortest_log2_height = tallest_log2_height.saturating_sub(min_offset);
544
545 let mut range =
546 (shortest_log2_height..=tallest_log2_height).map(Some).collect::<Vec<_>>();
547
548 if shortest_log2_height > maximal_log2_height {
549 range.insert(0, Some(shortest_log2_height));
550 }
551
552 range
553 } else {
554 vec![None, Some(log2_height_buffer)]
555 }
556 };
557
558 let mut maybe_log2_heights = HashMap::new();
559
560 let cpu_log_height = shape.log2_height(&RiscvAirId::Cpu);
561 maybe_log2_heights.insert(RiscvAirId::Cpu, heuristic(cpu_log_height, 0));
562
563 let addsub_log_height = shape.log2_height(&RiscvAirId::AddSub);
564 maybe_log2_heights.insert(RiscvAirId::AddSub, heuristic(addsub_log_height, 0));
565
566 let lt_log_height = shape.log2_height(&RiscvAirId::Lt);
567 maybe_log2_heights.insert(RiscvAirId::Lt, heuristic(lt_log_height, 0));
568
569 let memory_local_log_height = shape.log2_height(&RiscvAirId::MemoryLocal);
570 maybe_log2_heights.insert(RiscvAirId::MemoryLocal, heuristic(memory_local_log_height, 0));
571
572 let divrem_log_height = shape.log2_height(&RiscvAirId::DivRem);
573 maybe_log2_heights.insert(RiscvAirId::DivRem, heuristic(divrem_log_height, 1));
574
575 let bitwise_log_height = shape.log2_height(&RiscvAirId::Bitwise);
576 maybe_log2_heights.insert(RiscvAirId::Bitwise, heuristic(bitwise_log_height, 1));
577
578 let mul_log_height = shape.log2_height(&RiscvAirId::Mul);
579 maybe_log2_heights.insert(RiscvAirId::Mul, heuristic(mul_log_height, 1));
580
581 let shift_right_log_height = shape.log2_height(&RiscvAirId::ShiftRight);
582 maybe_log2_heights.insert(RiscvAirId::ShiftRight, heuristic(shift_right_log_height, 1));
583
584 let shift_left_log_height = shape.log2_height(&RiscvAirId::ShiftLeft);
585 maybe_log2_heights.insert(RiscvAirId::ShiftLeft, heuristic(shift_left_log_height, 1));
586
587 let memory_instrs_log_height = shape.log2_height(&RiscvAirId::MemoryInstrs);
588 maybe_log2_heights.insert(RiscvAirId::MemoryInstrs, heuristic(memory_instrs_log_height, 0));
589
590 let auipc_log_height = shape.log2_height(&RiscvAirId::Auipc);
591 maybe_log2_heights.insert(RiscvAirId::Auipc, heuristic(auipc_log_height, 0));
592
593 let branch_log_height = shape.log2_height(&RiscvAirId::Branch);
594 maybe_log2_heights.insert(RiscvAirId::Branch, heuristic(branch_log_height, 0));
595
596 let jump_log_height = shape.log2_height(&RiscvAirId::Jump);
597 maybe_log2_heights.insert(RiscvAirId::Jump, heuristic(jump_log_height, 0));
598
599 let syscall_core_log_height = shape.log2_height(&RiscvAirId::SyscallCore);
600 maybe_log2_heights.insert(RiscvAirId::SyscallCore, heuristic(syscall_core_log_height, 0));
601
602 let syscall_instrs_log_height = shape.log2_height(&RiscvAirId::SyscallInstrs);
603 maybe_log2_heights.insert(RiscvAirId::SyscallInstrs, heuristic(syscall_instrs_log_height, 0));
604
605 let global_log_height = shape.log2_height(&RiscvAirId::Global);
606 maybe_log2_heights.insert(RiscvAirId::Global, heuristic(global_log_height, 1));
607
608 assert!(maybe_log2_heights.len() >= shape.len(), "not all chips were included in the shape");
609
610 ShapeCluster::new(maybe_log2_heights)
611}
612
/// Errors returned while fixing or searching shapes with a `CoreShapeConfig`.
#[derive(Debug, Error)]
pub enum CoreShapeError {
    /// No allowed preprocessed shape fits the program's preprocessed heights.
    #[error("no preprocessed shape found")]
    PreprocessedShapeError,
    /// The program already has a preprocessed shape.
    #[error("Preprocessed shape already fixed")]
    PreprocessedShapeAlreadyFixed,
    /// No allowed shape fits the record. The payload is diagnostic data: the per-chip log2
    /// heights for packed-core shards, the record's debug statistics otherwise.
    #[error("no shape found {0:?}")]
    ShapeError(HashMap<String, usize>),
    /// The record's program has no preprocessed shape yet.
    #[error("Preprocessed shape missing")]
    PreprocessedShapeMissing,
    /// The record already has a shape.
    #[error("Shape already fixed")]
    ShapeAlreadyFixed,
    /// The record is a precompile shard, but none of the configured precompile chips occurs in
    /// it. The payload is the record's debug statistics.
    #[error("Precompile not included in allowed shapes {0:?}")]
    PrecompileNotIncluded(HashMap<String, usize>),
}
628
629pub fn create_dummy_program(shape: &Shape<RiscvAirId>) -> Program {
630 let mut program =
631 Program::new(vec![Instruction::new(Opcode::ADD, 30, 0, 0, false, false)], 1 << 5, 1 << 5);
632 program.preprocessed_shape = Some(shape.clone());
633 program
634}
635
636pub fn create_dummy_record(shape: &Shape<RiscvAirId>) -> ExecutionRecord {
637 let program = std::sync::Arc::new(create_dummy_program(shape));
638 let mut record = ExecutionRecord::new(program);
639 record.shape = Some(shape.clone());
640 record
641}
642
#[cfg(test)]
pub mod tests {
    #![allow(clippy::print_stdout)]

    use hashbrown::HashSet;
    use sp1_stark::{Dom, MachineProver, StarkGenericConfig};

    use super::*;

    /// Runs setup, trace generation, commitment and opening for a dummy program and record
    /// that both carry `shape`. Panics if opening fails.
    fn try_generate_dummy_proof<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<SC::Val>>>(
        prover: &P,
        shape: &Shape<RiscvAirId>,
    ) where
        SC::Val: PrimeField32,
        Dom<SC>: core::fmt::Debug,
    {
        let dummy_program = create_dummy_program(shape);
        let dummy_record = create_dummy_record(shape);

        // Setup: only the proving key is used below.
        let (proving_key, _) = prover.setup(&dummy_program);

        // Generate the main traces and commit to them.
        let traces = prover.generate_traces(&dummy_record);
        let committed = prover.commit(&dummy_record, traces);

        // Open the commitment with a fresh challenger.
        let mut challenger = prover.machine().config().challenger();
        prover.open(&proving_key, committed, &mut challenger).unwrap();
    }

    /// Counts the distinct shapes of the default configuration and checks the count stays
    /// below `2^24`.
    #[test]
    #[ignore]
    fn test_making_shapes() {
        use p3_baby_bear::BabyBear;

        let shape_config = CoreShapeConfig::<BabyBear>::default();
        let distinct_shapes: HashSet<_> = shape_config.all_shapes().collect();
        let num_shapes = distinct_shapes.len();
        println!("There are {num_shapes} core shapes");
        assert!(num_shapes < 1 << 24);
    }

    /// Checks that the prover can run on a dummy record with a hand-written shape.
    #[test]
    fn test_dummy_record() {
        use crate::utils::setup_logger;
        use p3_baby_bear::BabyBear;
        use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, CpuProver};

        type SC = BabyBearPoseidon2;
        type A = RiscvAir<BabyBear>;

        setup_logger();

        // Preprocessed chips first, then the core chips.
        let log_heights = [
            (RiscvAirId::Program, 10),
            (RiscvAirId::Byte, 16),
            (RiscvAirId::Cpu, 11),
            (RiscvAirId::DivRem, 11),
            (RiscvAirId::AddSub, 10),
            (RiscvAirId::Bitwise, 10),
            (RiscvAirId::Mul, 10),
            (RiscvAirId::ShiftRight, 10),
            (RiscvAirId::ShiftLeft, 10),
            (RiscvAirId::Lt, 10),
            (RiscvAirId::MemoryLocal, 10),
            (RiscvAirId::SyscallCore, 10),
            (RiscvAirId::Global, 10),
        ];
        let shape = Shape::new(log_heights.into_iter().collect::<HashMap<_, _>>());

        let prover = CpuProver::new(A::machine(SC::default()));

        try_generate_dummy_proof(&prover, &shape);
    }
}