sp1_core_machine/shape/
mod.rs

1mod shapeable;
2
3pub use shapeable::*;
4
5use std::{collections::BTreeMap, marker::PhantomData, str::FromStr};
6
7use hashbrown::HashMap;
8use itertools::Itertools;
9use num::Integer;
10use p3_baby_bear::BabyBear;
11use p3_field::PrimeField32;
12use p3_util::log2_ceil_usize;
13use sp1_core_executor::{ExecutionRecord, Instruction, Opcode, Program, RiscvAirId};
14use sp1_stark::{
15    air::MachineAir,
16    shape::{OrderedShape, Shape, ShapeCluster},
17};
18use thiserror::Error;
19
20use super::riscv::riscv_chips::{ByteChip, ProgramChip, SyscallChip};
21use crate::{
22    global::GlobalChip,
23    memory::{MemoryLocalChip, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW},
24    riscv::RiscvAir,
25};
26
/// The set of maximal shapes, serialized as JSON.
///
/// These shapes define the "worst-case" shapes for typical shards that are proving `rv32im`
/// execution. We use a variant of a cartesian product of the allowed log heights to generate
/// smaller shapes from these ones.
const MAXIMAL_SHAPES: &[u8] = include_bytes!("maximal_shapes.json");

/// The set of small shapes, serialized as JSON.
///
/// These shapes are used to optimize performance for smaller programs.
const SMALL_SHAPES: &[u8] = include_bytes!("small_shapes.json");
38
39/// A configuration for what shapes are allowed to be used by the prover.
40#[derive(Debug)]
41pub struct CoreShapeConfig<F: PrimeField32> {
42    partial_preprocessed_shapes: ShapeCluster<RiscvAirId>,
43    partial_core_shapes: BTreeMap<usize, Vec<ShapeCluster<RiscvAirId>>>,
44    partial_memory_shapes: ShapeCluster<RiscvAirId>,
45    partial_precompile_shapes: HashMap<RiscvAirId, (usize, Vec<usize>)>,
46    partial_small_shapes: Vec<ShapeCluster<RiscvAirId>>,
47    costs: HashMap<RiscvAirId, usize>,
48    _data: PhantomData<F>,
49}
50
51impl<F: PrimeField32> CoreShapeConfig<F> {
52    /// Fix the preprocessed shape of the proof.
53    pub fn fix_preprocessed_shape(&self, program: &mut Program) -> Result<(), CoreShapeError> {
54        // If the preprocessed shape is already fixed, return an error.
55        if program.preprocessed_shape.is_some() {
56            return Err(CoreShapeError::PreprocessedShapeAlreadyFixed);
57        }
58
59        // Get the heights of the preprocessed chips and find a shape that fits.
60        let preprocessed_heights = RiscvAir::<F>::preprocessed_heights(program);
61        let preprocessed_shape = self
62            .partial_preprocessed_shapes
63            .find_shape(&preprocessed_heights)
64            .ok_or(CoreShapeError::PreprocessedShapeError)?;
65
66        // Set the preprocessed shape.
67        program.preprocessed_shape = Some(preprocessed_shape);
68
69        Ok(())
70    }
71
72    /// Fix the shape of the proof.
73    pub fn fix_shape(&self, record: &mut ExecutionRecord) -> Result<(), CoreShapeError> {
74        if record.program.preprocessed_shape.is_none() {
75            return Err(CoreShapeError::PreprocessedShapeMissing);
76        }
77        if record.shape.is_some() {
78            return Err(CoreShapeError::ShapeAlreadyFixed);
79        }
80
81        // Set the shape of the chips with prepcoded shapes to match the preprocessed shape from the
82        // program.
83        record.shape.clone_from(&record.program.preprocessed_shape);
84
85        let shape = self.find_shape(record)?;
86        record.shape.as_mut().unwrap().extend(shape);
87        Ok(())
88    }
89
90    /// TODO move this into the executor crate
91    pub fn find_shape<R: Shapeable>(
92        &self,
93        record: &R,
94    ) -> Result<Shape<RiscvAirId>, CoreShapeError> {
95        match record.kind() {
96            // If this is a packed "core" record where the cpu events are alongisde the memory init
97            // and finalize events, try to fix the shape using the tiny shapes.
98            ShardKind::PackedCore => {
99                // Get the heights of the core airs in the record.
100                let mut heights = record.core_heights();
101                heights.extend(record.memory_heights());
102
103                let (cluster_index, shape, _) = self
104                    .minimal_cluster_shape(self.partial_small_shapes.iter().enumerate(), &heights)
105                    .ok_or_else(|| {
106                        // No shape found, so return an error.
107                        CoreShapeError::ShapeError(
108                            heights
109                                .iter()
110                                .map(|(air, height)| (air.to_string(), log2_ceil_usize(*height)))
111                                .collect(),
112                        )
113                    })?;
114
115                let shard = record.shard();
116                tracing::debug!("Shard Lifted: Index={}, Cluster={}", shard, cluster_index);
117                for (air, height) in heights.iter() {
118                    if shape.contains(air) {
119                        tracing::debug!(
120                            "Chip {:<20}: {:<3} -> {:<3}",
121                            air,
122                            log2_ceil_usize(*height),
123                            shape.log2_height(air).unwrap(),
124                        );
125                    }
126                }
127                Ok(shape)
128            }
129            ShardKind::Core => {
130                // If this is a normal "core" record, try to fix the shape as such.
131
132                // Get the heights of the core airs in the record.
133                let heights = record.core_heights();
134
135                // Try to find the smallest shape fitting within at least one of the candidate
136                // shapes.
137                let log2_shard_size = record.log2_shard_size();
138
139                let (cluster_index, shape, _) = self
140                    .minimal_cluster_shape(
141                        self.partial_core_shapes
142                            .range(log2_shard_size..)
143                            .flat_map(|(_, clusters)| clusters.iter().enumerate()),
144                        &heights,
145                    )
146                    // No shape found, so return an error.
147                    .ok_or_else(|| CoreShapeError::ShapeError(record.debug_stats()))?;
148
149                let shard = record.shard();
150                tracing::debug!("Shard Lifted: Index={}, Cluster={}", shard, cluster_index);
151
152                for (air, height) in heights.iter() {
153                    if shape.contains(air) {
154                        tracing::debug!(
155                            "Chip {:<20}: {:<3} -> {:<3}",
156                            air,
157                            log2_ceil_usize(*height),
158                            shape.log2_height(air).unwrap(),
159                        );
160                    }
161                }
162                Ok(shape)
163            }
164            ShardKind::GlobalMemory => {
165                // If the record is a does not have the CPU chip and is a global memory
166                // init/finalize record, try to fix the shape as such.
167                let heights = record.memory_heights();
168                let shape = self
169                    .partial_memory_shapes
170                    .find_shape(&heights)
171                    .ok_or(CoreShapeError::ShapeError(record.debug_stats()))?;
172                Ok(shape)
173            }
174            ShardKind::Precompile => {
175                // Try to fix the shape as a precompile record.
176                for (&air, (memory_events_per_row, allowed_log2_heights)) in
177                    self.partial_precompile_shapes.iter()
178                {
179                    // Filter to check that the shard and shape air match.
180                    let Some((height, num_memory_local_events, num_global_events)) =
181                        record.precompile_heights().find_map(|x| (x.0 == air).then_some(x.1))
182                    else {
183                        continue;
184                    };
185                    for allowed_log2_height in allowed_log2_heights {
186                        let allowed_height = 1 << allowed_log2_height;
187                        if height <= allowed_height {
188                            for shape in self.get_precompile_shapes(
189                                air,
190                                *memory_events_per_row,
191                                *allowed_log2_height,
192                            ) {
193                                let mem_events_height = shape[2].1;
194                                let global_events_height = shape[3].1;
195                                if num_memory_local_events
196                                    .div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW) <=
197                                    (1 << mem_events_height) &&
198                                    num_global_events <= (1 << global_events_height)
199                                {
200                                    let mut actual_shape: Shape<RiscvAirId> = Shape::default();
201                                    actual_shape.extend(
202                                        shape
203                                            .iter()
204                                            .map(|x| (RiscvAirId::from_str(&x.0).unwrap(), x.1)),
205                                    );
206                                    return Ok(actual_shape);
207                                }
208                            }
209                        }
210                    }
211                    tracing::error!(
212                        "Cannot find shape for precompile {:?}, height {:?}, and mem events {:?}",
213                        air,
214                        height,
215                        num_memory_local_events
216                    );
217                    return Err(CoreShapeError::ShapeError(record.debug_stats()));
218                }
219                Err(CoreShapeError::PrecompileNotIncluded(record.debug_stats()))
220            }
221        }
222    }
223
224    /// Returns the area, cluster index, and shape of the minimal shape from candidates that fit a
225    /// given collection of heights.
226    pub fn minimal_cluster_shape<'a, N, I>(
227        &self,
228        indexed_shape_clusters: I,
229        heights: &[(RiscvAirId, usize)],
230    ) -> Option<(N, Shape<RiscvAirId>, usize)>
231    where
232        I: IntoIterator<Item = (N, &'a ShapeCluster<RiscvAirId>)>,
233    {
234        // Try to find a shape fitting within at least one of the candidate shapes.
235        indexed_shape_clusters
236            .into_iter()
237            .filter_map(|(i, cluster)| {
238                let shape = cluster.find_shape(heights)?;
239                let area = self.estimate_lde_size(&shape);
240                Some((i, shape, area))
241            })
242            .min_by_key(|x| x.2) // Find minimum by area.
243    }
244
245    // TODO: this function is atrocious, fix this
246    fn get_precompile_shapes(
247        &self,
248        air_id: RiscvAirId,
249        memory_events_per_row: usize,
250        allowed_log2_height: usize,
251    ) -> Vec<[(String, usize); 4]> {
252        // TODO: This is a temporary fix to the shape, concretely fix this
253        (1..=4 * air_id.rows_per_event())
254            .rev()
255            .map(|rows_per_event| {
256                let num_local_mem_events =
257                    ((1 << allowed_log2_height) * memory_events_per_row).div_ceil(rows_per_event);
258                [
259                    (air_id.to_string(), allowed_log2_height),
260                    (
261                        RiscvAir::<F>::SyscallPrecompile(SyscallChip::precompile()).name(),
262                        ((1 << allowed_log2_height)
263                            .div_ceil(&air_id.rows_per_event())
264                            .next_power_of_two()
265                            .ilog2() as usize)
266                            .max(4),
267                    ),
268                    (
269                        RiscvAir::<F>::MemoryLocal(MemoryLocalChip::new()).name(),
270                        (num_local_mem_events
271                            .div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
272                            .next_power_of_two()
273                            .ilog2() as usize)
274                            .max(4),
275                    ),
276                    (
277                        RiscvAir::<F>::Global(GlobalChip).name(),
278                        ((2 * num_local_mem_events +
279                            (1 << allowed_log2_height).div_ceil(&air_id.rows_per_event()))
280                        .next_power_of_two()
281                        .ilog2() as usize)
282                            .max(4),
283                    ),
284                ]
285            })
286            .filter(|shape| shape[3].1 <= 22)
287            .collect::<Vec<_>>()
288    }
289
290    fn generate_all_shapes_from_allowed_log_heights(
291        allowed_log_heights: impl IntoIterator<Item = (String, Vec<Option<usize>>)>,
292    ) -> impl Iterator<Item = OrderedShape> {
293        allowed_log_heights
294            .into_iter()
295            .map(|(name, heights)| heights.into_iter().map(move |height| (name.clone(), height)))
296            .multi_cartesian_product()
297            .map(|iter| {
298                iter.into_iter()
299                    .filter_map(|(name, maybe_height)| {
300                        maybe_height.map(|log_height| (name, log_height))
301                    })
302                    .collect::<OrderedShape>()
303            })
304    }
305
306    pub fn all_shapes(&self) -> impl Iterator<Item = OrderedShape> + '_ {
307        let preprocessed_heights = self
308            .partial_preprocessed_shapes
309            .iter()
310            .map(|(air, heights)| (air.to_string(), heights.clone()))
311            .collect::<HashMap<_, _>>();
312
313        let mut memory_heights = self
314            .partial_memory_shapes
315            .iter()
316            .map(|(air, heights)| (air.to_string(), heights.clone()))
317            .collect::<HashMap<_, _>>();
318        memory_heights.extend(preprocessed_heights.clone());
319
320        let precompile_only_shapes = self.partial_precompile_shapes.iter().flat_map(
321            move |(&air, (mem_events_per_row, allowed_log_heights))| {
322                allowed_log_heights.iter().flat_map(move |allowed_log_height| {
323                    self.get_precompile_shapes(air, *mem_events_per_row, *allowed_log_height)
324                })
325            },
326        );
327
328        let precompile_shapes =
329            Self::generate_all_shapes_from_allowed_log_heights(preprocessed_heights.clone())
330                .flat_map(move |preprocessed_shape| {
331                    precompile_only_shapes.clone().map(move |precompile_shape| {
332                        preprocessed_shape
333                            .clone()
334                            .into_iter()
335                            .chain(precompile_shape)
336                            .collect::<OrderedShape>()
337                    })
338                });
339
340        self.partial_core_shapes
341            .values()
342            .flatten()
343            .chain(self.partial_small_shapes.iter())
344            .flat_map(move |allowed_log_heights| {
345                Self::generate_all_shapes_from_allowed_log_heights({
346                    let mut log_heights = allowed_log_heights
347                        .iter()
348                        .map(|(air, heights)| (air.to_string(), heights.clone()))
349                        .collect::<HashMap<_, _>>();
350                    log_heights.extend(preprocessed_heights.clone());
351                    log_heights
352                })
353            })
354            .chain(Self::generate_all_shapes_from_allowed_log_heights(memory_heights))
355            .chain(precompile_shapes)
356    }
357
358    pub fn maximal_core_shapes(&self, max_log_shard_size: usize) -> Vec<Shape<RiscvAirId>> {
359        let max_shard_size: usize = core::cmp::max(
360            1 << max_log_shard_size,
361            1 << self.partial_core_shapes.keys().min().unwrap(),
362        );
363
364        let log_shard_size = max_shard_size.ilog2() as usize;
365        debug_assert_eq!(1 << log_shard_size, max_shard_size);
366        let max_preprocessed = self
367            .partial_preprocessed_shapes
368            .iter()
369            .map(|(air, allowed_heights)| {
370                (air.to_string(), allowed_heights.last().unwrap().unwrap())
371            })
372            .collect::<HashMap<_, _>>();
373
374        let max_core_shapes =
375            self.partial_core_shapes[&log_shard_size].iter().map(|allowed_log_heights| {
376                max_preprocessed
377                    .clone()
378                    .into_iter()
379                    .chain(allowed_log_heights.iter().flat_map(|(air, allowed_heights)| {
380                        allowed_heights
381                            .last()
382                            .unwrap()
383                            .map(|log_height| (air.to_string(), log_height))
384                    }))
385                    .map(|(air, log_height)| (RiscvAirId::from_str(&air).unwrap(), log_height))
386                    .collect::<Shape<RiscvAirId>>()
387            });
388
389        max_core_shapes.collect()
390    }
391
392    pub fn maximal_core_plus_precompile_shapes(
393        &self,
394        max_log_shard_size: usize,
395    ) -> Vec<Shape<RiscvAirId>> {
396        let max_preprocessed = self
397            .partial_preprocessed_shapes
398            .iter()
399            .map(|(air, allowed_heights)| {
400                (air.to_string(), allowed_heights.last().unwrap().unwrap())
401            })
402            .collect::<HashMap<_, _>>();
403
404        let precompile_only_shapes = self.partial_precompile_shapes.iter().flat_map(
405            move |(&air, (mem_events_per_row, allowed_log_heights))| {
406                self.get_precompile_shapes(
407                    air,
408                    *mem_events_per_row,
409                    *allowed_log_heights.last().unwrap(),
410                )
411            },
412        );
413
414        let precompile_shapes: Vec<Shape<RiscvAirId>> = precompile_only_shapes
415            .map(|x| {
416                max_preprocessed
417                    .clone()
418                    .into_iter()
419                    .chain(x)
420                    .map(|(air, log_height)| (RiscvAirId::from_str(&air).unwrap(), log_height))
421                    .collect::<Shape<RiscvAirId>>()
422            })
423            .filter(|shape| shape.log2_height(&RiscvAirId::Global).unwrap() < 21)
424            .collect();
425
426        self.maximal_core_shapes(max_log_shard_size).into_iter().chain(precompile_shapes).collect()
427    }
428
429    pub fn estimate_lde_size(&self, shape: &Shape<RiscvAirId>) -> usize {
430        shape.iter().map(|(air, height)| self.costs[air] * (1 << height)).sum()
431    }
432
433    // TODO: cleanup..
434    pub fn small_program_shapes(&self) -> Vec<OrderedShape> {
435        self.partial_small_shapes
436            .iter()
437            .map(|log_heights| {
438                OrderedShape::from_log2_heights(
439                    &log_heights
440                        .iter()
441                        .filter(|(_, v)| v[0].is_some())
442                        .map(|(k, v)| (k.to_string(), v.last().unwrap().unwrap()))
443                        .chain(vec![
444                            (MachineAir::<BabyBear>::name(&ProgramChip), 19),
445                            (MachineAir::<BabyBear>::name(&ByteChip::default()), 16),
446                        ])
447                        .collect::<Vec<_>>(),
448                )
449            })
450            .collect()
451    }
452}
453
454impl<F: PrimeField32> Default for CoreShapeConfig<F> {
455    fn default() -> Self {
456        // Load the maximal shapes.
457        let maximal_shapes: BTreeMap<usize, Vec<Shape<RiscvAirId>>> =
458            serde_json::from_slice(MAXIMAL_SHAPES).unwrap();
459        let small_shapes: Vec<Shape<RiscvAirId>> = serde_json::from_slice(SMALL_SHAPES).unwrap();
460
461        // Set the allowed preprocessed log2 heights.
462        let allowed_preprocessed_log2_heights = HashMap::from([
463            (RiscvAirId::Program, vec![Some(19), Some(20), Some(21), Some(22)]),
464            (RiscvAirId::Byte, vec![Some(16)]),
465        ]);
466
467        // Generate the clusters from the maximal shapes and register them indexed by log2 shard
468        //  size.
469        let blacklist = [
470            27, 33, 47, 68, 75, 102, 104, 114, 116, 118, 137, 138, 139, 144, 145, 153, 155, 157,
471            158, 169, 170, 171, 184, 185, 187, 195, 216, 243, 252, 275, 281, 282, 285,
472        ];
473        let mut core_allowed_log2_heights = BTreeMap::new();
474        for (log2_shard_size, maximal_shapes) in maximal_shapes {
475            let mut clusters = vec![];
476
477            for (i, maximal_shape) in maximal_shapes.iter().enumerate() {
478                // WARNING: This must be tuned carefully.
479                //
480                // This is current hardcoded, but in the future it should be computed dynamically.
481                if log2_shard_size == 21 && blacklist.contains(&i) {
482                    continue;
483                }
484
485                let cluster = derive_cluster_from_maximal_shape(maximal_shape);
486                clusters.push(cluster);
487            }
488
489            core_allowed_log2_heights.insert(log2_shard_size, clusters);
490        }
491
492        // Set the memory init and finalize heights.
493        let memory_allowed_log2_heights = HashMap::from(
494            [
495                (
496                    RiscvAirId::MemoryGlobalInit,
497                    vec![None, Some(10), Some(16), Some(18), Some(19), Some(20), Some(21)],
498                ),
499                (
500                    RiscvAirId::MemoryGlobalFinalize,
501                    vec![None, Some(10), Some(16), Some(18), Some(19), Some(20), Some(21)],
502                ),
503                (RiscvAirId::Global, vec![None, Some(11), Some(17), Some(19), Some(21), Some(22)]),
504            ]
505            .map(|(air, log_heights)| (air, log_heights)),
506        );
507
508        let mut precompile_allowed_log2_heights = HashMap::new();
509        let precompile_heights = (3..21).collect::<Vec<_>>();
510        for (air, memory_events_per_row) in
511            RiscvAir::<F>::precompile_airs_with_memory_events_per_row()
512        {
513            precompile_allowed_log2_heights
514                .insert(air, (memory_events_per_row, precompile_heights.clone()));
515        }
516
517        Self {
518            partial_preprocessed_shapes: ShapeCluster::new(allowed_preprocessed_log2_heights),
519            partial_core_shapes: core_allowed_log2_heights,
520            partial_memory_shapes: ShapeCluster::new(memory_allowed_log2_heights),
521            partial_precompile_shapes: precompile_allowed_log2_heights,
522            partial_small_shapes: small_shapes
523                .into_iter()
524                .map(|x| {
525                    ShapeCluster::new(x.into_iter().map(|(k, v)| (k, vec![Some(v)])).collect())
526                })
527                .collect(),
528            costs: serde_json::from_str(include_str!("rv32im_costs.json"))
529                .expect("Failed to load rv32im_costs.json file. Verify that `git config core.symlinks` is not set to false."),
530            _data: PhantomData,
531        }
532    }
533}
534
535fn derive_cluster_from_maximal_shape(shape: &Shape<RiscvAirId>) -> ShapeCluster<RiscvAirId> {
536    // We first define a heuristic to derive the log heights from the maximal shape.
537    let log2_gap_from_21 = 21 - shape.log2_height(&RiscvAirId::Cpu).unwrap();
538    let min_log2_height_threshold = 18 - log2_gap_from_21;
539    let log2_height_buffer = 10;
540    let heuristic = |maximal_log2_height: Option<usize>, min_offset: usize| {
541        if let Some(maximal_log2_height) = maximal_log2_height {
542            let tallest_log2_height = std::cmp::max(maximal_log2_height, min_log2_height_threshold);
543            let shortest_log2_height = tallest_log2_height.saturating_sub(min_offset);
544
545            let mut range =
546                (shortest_log2_height..=tallest_log2_height).map(Some).collect::<Vec<_>>();
547
548            if shortest_log2_height > maximal_log2_height {
549                range.insert(0, Some(shortest_log2_height));
550            }
551
552            range
553        } else {
554            vec![None, Some(log2_height_buffer)]
555        }
556    };
557
558    let mut maybe_log2_heights = HashMap::new();
559
560    let cpu_log_height = shape.log2_height(&RiscvAirId::Cpu);
561    maybe_log2_heights.insert(RiscvAirId::Cpu, heuristic(cpu_log_height, 0));
562
563    let addsub_log_height = shape.log2_height(&RiscvAirId::AddSub);
564    maybe_log2_heights.insert(RiscvAirId::AddSub, heuristic(addsub_log_height, 0));
565
566    let lt_log_height = shape.log2_height(&RiscvAirId::Lt);
567    maybe_log2_heights.insert(RiscvAirId::Lt, heuristic(lt_log_height, 0));
568
569    let memory_local_log_height = shape.log2_height(&RiscvAirId::MemoryLocal);
570    maybe_log2_heights.insert(RiscvAirId::MemoryLocal, heuristic(memory_local_log_height, 0));
571
572    let divrem_log_height = shape.log2_height(&RiscvAirId::DivRem);
573    maybe_log2_heights.insert(RiscvAirId::DivRem, heuristic(divrem_log_height, 1));
574
575    let bitwise_log_height = shape.log2_height(&RiscvAirId::Bitwise);
576    maybe_log2_heights.insert(RiscvAirId::Bitwise, heuristic(bitwise_log_height, 1));
577
578    let mul_log_height = shape.log2_height(&RiscvAirId::Mul);
579    maybe_log2_heights.insert(RiscvAirId::Mul, heuristic(mul_log_height, 1));
580
581    let shift_right_log_height = shape.log2_height(&RiscvAirId::ShiftRight);
582    maybe_log2_heights.insert(RiscvAirId::ShiftRight, heuristic(shift_right_log_height, 1));
583
584    let shift_left_log_height = shape.log2_height(&RiscvAirId::ShiftLeft);
585    maybe_log2_heights.insert(RiscvAirId::ShiftLeft, heuristic(shift_left_log_height, 1));
586
587    let memory_instrs_log_height = shape.log2_height(&RiscvAirId::MemoryInstrs);
588    maybe_log2_heights.insert(RiscvAirId::MemoryInstrs, heuristic(memory_instrs_log_height, 0));
589
590    let auipc_log_height = shape.log2_height(&RiscvAirId::Auipc);
591    maybe_log2_heights.insert(RiscvAirId::Auipc, heuristic(auipc_log_height, 0));
592
593    let branch_log_height = shape.log2_height(&RiscvAirId::Branch);
594    maybe_log2_heights.insert(RiscvAirId::Branch, heuristic(branch_log_height, 0));
595
596    let jump_log_height = shape.log2_height(&RiscvAirId::Jump);
597    maybe_log2_heights.insert(RiscvAirId::Jump, heuristic(jump_log_height, 0));
598
599    let syscall_core_log_height = shape.log2_height(&RiscvAirId::SyscallCore);
600    maybe_log2_heights.insert(RiscvAirId::SyscallCore, heuristic(syscall_core_log_height, 0));
601
602    let syscall_instrs_log_height = shape.log2_height(&RiscvAirId::SyscallInstrs);
603    maybe_log2_heights.insert(RiscvAirId::SyscallInstrs, heuristic(syscall_instrs_log_height, 0));
604
605    let global_log_height = shape.log2_height(&RiscvAirId::Global);
606    maybe_log2_heights.insert(RiscvAirId::Global, heuristic(global_log_height, 1));
607
608    assert!(maybe_log2_heights.len() >= shape.len(), "not all chips were included in the shape");
609
610    ShapeCluster::new(maybe_log2_heights)
611}
612
/// Errors that can occur while fixing a program's preprocessed shape or a record's shape.
#[derive(Debug, Error)]
pub enum CoreShapeError {
    /// No allowed preprocessed shape fits the program's preprocessed heights.
    #[error("no preprocessed shape found")]
    PreprocessedShapeError,
    /// The program's preprocessed shape was already set.
    #[error("Preprocessed shape already fixed")]
    PreprocessedShapeAlreadyFixed,
    /// No allowed shape fits the record. The map describes the record (per-air log2 heights or
    /// record statistics, depending on the shard kind).
    #[error("no shape found {0:?}")]
    ShapeError(HashMap<String, usize>),
    /// The record's program has no preprocessed shape yet.
    #[error("Preprocessed shape missing")]
    PreprocessedShapeMissing,
    /// The record's shape was already set.
    #[error("Shape already fixed")]
    ShapeAlreadyFixed,
    /// A precompile record's air has no configured shapes. The map holds record statistics.
    #[error("Precompile not included in allowed shapes {0:?}")]
    PrecompileNotIncluded(HashMap<String, usize>),
}
628
629pub fn create_dummy_program(shape: &Shape<RiscvAirId>) -> Program {
630    let mut program =
631        Program::new(vec![Instruction::new(Opcode::ADD, 30, 0, 0, false, false)], 1 << 5, 1 << 5);
632    program.preprocessed_shape = Some(shape.clone());
633    program
634}
635
636pub fn create_dummy_record(shape: &Shape<RiscvAirId>) -> ExecutionRecord {
637    let program = std::sync::Arc::new(create_dummy_program(shape));
638    let mut record = ExecutionRecord::new(program);
639    record.shape = Some(shape.clone());
640    record
641}
642
#[cfg(test)]
pub mod tests {
    #![allow(clippy::print_stdout)]

    use hashbrown::HashSet;
    use sp1_stark::{Dom, MachineProver, StarkGenericConfig};

    use super::*;

    /// Runs setup, trace generation, commitment, and opening for the dummy program and record
    /// built from `shape`, panicking if opening fails.
    fn try_generate_dummy_proof<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<SC::Val>>>(
        prover: &P,
        shape: &Shape<RiscvAirId>,
    ) where
        SC::Val: PrimeField32,
        Dom<SC>: core::fmt::Debug,
    {
        let dummy_program = create_dummy_program(shape);
        let dummy_record = create_dummy_record(shape);

        let (proving_key, _) = prover.setup(&dummy_program);
        let traces = prover.generate_traces(&dummy_record);
        let committed = prover.commit(&dummy_record, traces);

        let mut challenger = prover.machine().config().challenger();
        prover.open(&proving_key, committed, &mut challenger).unwrap();
    }

    #[test]
    #[ignore]
    fn test_making_shapes() {
        use p3_baby_bear::BabyBear;

        let config = CoreShapeConfig::<BabyBear>::default();
        let unique_shapes: HashSet<_> = config.all_shapes().collect();
        let num_shapes = unique_shapes.len();
        println!("There are {num_shapes} core shapes");
        assert!(num_shapes < 1 << 24);
    }

    #[test]
    fn test_dummy_record() {
        use crate::utils::setup_logger;
        use p3_baby_bear::BabyBear;
        use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, CpuProver};

        type SC = BabyBearPoseidon2;
        type A = RiscvAir<BabyBear>;

        setup_logger();

        let preprocessed_log_heights = [(RiscvAirId::Program, 10), (RiscvAirId::Byte, 16)];
        let core_log_heights = [
            (RiscvAirId::Cpu, 11),
            (RiscvAirId::DivRem, 11),
            (RiscvAirId::AddSub, 10),
            (RiscvAirId::Bitwise, 10),
            (RiscvAirId::Mul, 10),
            (RiscvAirId::ShiftRight, 10),
            (RiscvAirId::ShiftLeft, 10),
            (RiscvAirId::Lt, 10),
            (RiscvAirId::MemoryLocal, 10),
            (RiscvAirId::SyscallCore, 10),
            (RiscvAirId::Global, 10),
        ];

        let shape = Shape::new(
            preprocessed_log_heights.into_iter().chain(core_log_heights).collect::<HashMap<_, _>>(),
        );

        // Exercise preprocessed trace generation, commitment, and opening on a CPU prover.
        let prover = CpuProver::new(A::machine(SC::default()));
        try_generate_dummy_proof(&prover, &shape);
    }
}