Skip to main content

sp1_prover/
shapes.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    num::NonZero,
5    sync::{
6        atomic::{AtomicUsize, Ordering},
7        Arc, Mutex,
8    },
9};
10
11use hashbrown::HashSet;
12use lru::LruCache;
13use serde::{Deserialize, Serialize};
14use slop_air::BaseAir;
15use slop_algebra::AbstractField;
16use slop_basefold::FriConfig;
17use sp1_core_executor::MAX_PROGRAM_SIZE;
18use sp1_core_machine::{
19    bytes::columns::NUM_BYTE_PREPROCESSED_COLS, program::NUM_PROGRAM_PREPROCESSED_COLS,
20    range::columns::NUM_RANGE_PREPROCESSED_COLS, riscv::RiscvAir,
21};
22use sp1_hypercube::{
23    air::MachineAir,
24    log2_ceil_usize,
25    prover::{CoreProofShape, DefaultTraceGenerator, ProverSemaphore, TraceGenerator},
26    Chip, HashableKey, Machine, MachineShape, SP1PcsProofInner, SP1VerifyingKey,
27};
28use sp1_primitives::{
29    fri_params::{core_fri_config, CORE_LOG_BLOWUP},
30    SP1Field, SP1GlobalContext,
31};
32use sp1_prover_types::ArtifactClient;
33use sp1_recursion_circuit::{
34    dummy::{dummy_shard_proof, dummy_vk},
35    machine::{
36        SP1CompressWithVKeyWitnessValues, SP1MerkleProofWitnessValues, SP1NormalizeWitnessValues,
37        SP1ShapedWitnessValues,
38    },
39};
40use sp1_recursion_compiler::config::InnerConfig;
41use sp1_recursion_executor::{
42    shape::RecursionShape, RecursionAirEventCount, RecursionProgram, DIGEST_SIZE,
43};
44use sp1_recursion_machine::chips::{
45    alu_base::BaseAluChip,
46    alu_ext::ExtAluChip,
47    mem::{MemoryConstChip, MemoryVarChip},
48    poseidon2_helper::{
49        convert::ConvertChip, linear::Poseidon2LinearLayerChip, sbox::Poseidon2SBoxChip,
50    },
51    poseidon2_wide::Poseidon2WideChip,
52    prefix_sum_checks::PrefixSumChecksChip,
53    public_values::PublicValuesChip,
54    select::SelectChip,
55};
56use sp1_verifier::compressed::RECURSION_MAX_LOG_ROW_COUNT;
57use thiserror::Error;
58use tokio::task::JoinSet;
59
60use crate::{
61    components::{SP1ProverComponents, CORE_LOG_STACKING_HEIGHT},
62    recursion::{
63        compose_program_from_input, deferred_program_from_input, dummy_compose_input,
64        dummy_deferred_input, normalize_program_from_input, recursive_verifier,
65        shrink_program_from_input,
66    },
67    worker::{AirProverWorker, RecursionVkWorker},
68    CompressAir, CORE_MAX_LOG_ROW_COUNT,
69};
70
/// The compose arity used by default. It is the only arity for which
/// [`SP1RecursionProofShape::compress_proof_shape_from_arity`] returns a shape.
pub const DEFAULT_ARITY: usize = 4;
72
/// The shape of the "normalize" program, which proves the correct execution for the verifier of a
/// single core shard proof.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct SP1NormalizeInputShape {
    /// The shapes of the core shard proofs; `dummy_input` builds one dummy proof per entry.
    pub proof_shapes: Vec<CoreProofShape<SP1Field, RiscvAir<SP1Field>>>,
    /// The maximum log row count handed to `dummy_shard_proof`.
    pub max_log_row_count: usize,
    /// The log blowup factor (callers in this file set it to `CORE_LOG_BLOWUP`).
    pub log_blowup: usize,
    /// Log2 of the stacking height; proof areas are shifted right by this amount to obtain the
    /// multiples passed to `dummy_shard_proof`.
    pub log_stacking_height: usize,
}
82
/// Identifies one recursion program to be compiled; `create_all_input_shapes` enumerates these
/// and `build_vk_map` compiles one program per value.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub enum SP1RecursionProgramShape {
    // The program that verifies a core shard proof.
    Normalize(CoreProofShape<SP1Field, RiscvAir<SP1Field>>),
    // Compose(arity) is the program that verifies a batch of Normalize proofs of size arity.
    Compose(usize),
    // The deferred proof program.
    Deferred,
    // The shrink program that verifies the root of the recursion tree.
    Shrink,
}
94
/// The executor's `ELEMENT_THRESHOLD` plus `2^CORE_LOG_STACKING_HEIGHT` additional elements.
/// Used in this file as the upper bound on the total area when enumerating normalize shapes.
const PADDED_ELEMENT_THRESHOLD: u64 =
    sp1_core_executor::ELEMENT_THRESHOLD + (1 << CORE_LOG_STACKING_HEIGHT);
97
/// Error type wrapping I/O and bincode failures; both variants convert from their source error
/// via `From`.
// NOTE(review): presumably returned by vk-building code outside this view — nothing in the
// visible part of the file constructs it.
#[derive(Debug, Error)]
pub enum VkBuildError {
    /// An underlying `std::io::Error`.
    #[error("IO error: {0}")]
    IO(#[from] std::io::Error),
    /// An underlying `bincode::Error`.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),
}
105
106impl SP1NormalizeInputShape {
107    pub fn dummy_input(
108        &self,
109        vk: SP1VerifyingKey,
110    ) -> SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner> {
111        let shard_proofs = self
112            .proof_shapes
113            .iter()
114            .map(|core_shape| {
115                dummy_shard_proof(
116                    core_shape.shard_chips.clone(),
117                    self.max_log_row_count,
118                    core_fri_config(),
119                    self.log_stacking_height,
120                    &[
121                        core_shape.preprocessed_area >> self.log_stacking_height,
122                        core_shape.main_area >> self.log_stacking_height,
123                    ],
124                    &[core_shape.preprocessed_padding_cols, core_shape.main_padding_cols],
125                )
126            })
127            .collect::<Vec<_>>();
128
129        SP1NormalizeWitnessValues {
130            vk: vk.vk,
131            shard_proofs,
132            is_complete: false,
133            vk_root: [SP1Field::zero(); DIGEST_SIZE],
134            reconstruct_deferred_digest: [SP1Field::zero(); 8],
135            num_deferred_proofs: SP1Field::zero(),
136        }
137    }
138}
139
/// An LRU cache of compiled normalize programs keyed by their input shape, with lookup
/// statistics.
pub struct SP1NormalizeCache {
    /// The mutex-protected LRU map from input shape to compiled program.
    lru: Arc<Mutex<LruCache<SP1NormalizeInputShape, Arc<RecursionProgram<SP1Field>>>>>,
    /// Number of `get` calls made so far.
    total_calls: AtomicUsize,
    /// Number of `get` calls that found an entry.
    hits: AtomicUsize,
}
145
146impl SP1NormalizeCache {
147    pub fn new(size: usize) -> Self {
148        let size = NonZero::new(size).expect("size must be non-zero");
149        let lru = LruCache::new(size);
150        let lru = Arc::new(Mutex::new(lru));
151        Self { lru, total_calls: AtomicUsize::new(0), hits: AtomicUsize::new(0) }
152    }
153
154    pub fn get(&self, shape: &SP1NormalizeInputShape) -> Option<Arc<RecursionProgram<SP1Field>>> {
155        self.total_calls.fetch_add(1, Ordering::Relaxed);
156        if let Some(program) = self.lru.lock().unwrap().get(shape).cloned() {
157            self.hits.fetch_add(1, Ordering::Relaxed);
158            Some(program)
159        } else {
160            None
161        }
162    }
163
164    pub fn push(&self, shape: SP1NormalizeInputShape, program: Arc<RecursionProgram<SP1Field>>) {
165        self.lru.lock().unwrap().push(shape, program);
166    }
167
168    pub fn stats(&self) -> (usize, usize, f64) {
169        (
170            self.total_calls.load(Ordering::Relaxed),
171            self.hits.load(Ordering::Relaxed),
172            self.hits.load(Ordering::Relaxed) as f64
173                / self.total_calls.load(Ordering::Relaxed) as f64,
174        )
175    }
176}
177
/// The shape of a recursion (compress) proof: a thin, serializable wrapper around a
/// [`RecursionShape`], which this file queries for per-chip heights.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SP1RecursionProofShape {
    /// The wrapped recursion shape.
    pub shape: RecursionShape<SP1Field>,
}
182
183impl Default for SP1RecursionProofShape {
184    fn default() -> Self {
185        Self::compress_proof_shape_from_arity(DEFAULT_ARITY).unwrap()
186    }
187}
188
189impl SP1RecursionProofShape {
190    pub fn compress_proof_shape_from_arity(arity: usize) -> Option<Self> {
191        match arity {
192            DEFAULT_ARITY => {
193                let file = include_bytes!("../compress_shape.json");
194                serde_json::from_slice(file).ok().or_else(|| {
195                    tracing::warn!("Failed to load compress_shape.json, using default shape.");
196                    // This is not a well-tuned shape, but is likely to be big enough even if
197                    // relatively substantial changes are made to the verifier.
198                    Some(SP1RecursionProofShape {
199                        shape: [
200                            (
201                                CompressAir::<SP1Field>::MemoryConst(MemoryConstChip::default()),
202                                600_000usize.next_multiple_of(32),
203                            ),
204                            (
205                                CompressAir::<SP1Field>::MemoryVar(MemoryVarChip::default()),
206                                500_000usize.next_multiple_of(32),
207                            ),
208                            (
209                                CompressAir::<SP1Field>::BaseAlu(BaseAluChip),
210                                500_000usize.next_multiple_of(32),
211                            ),
212                            (
213                                CompressAir::<SP1Field>::ExtAlu(ExtAluChip),
214                                850_000usize.next_multiple_of(32),
215                            ),
216                            (
217                                CompressAir::<SP1Field>::Poseidon2Wide(Poseidon2WideChip),
218                                150_448usize.next_multiple_of(32),
219                            ),
220                            (
221                                CompressAir::<SP1Field>::PrefixSumChecks(PrefixSumChecksChip),
222                                275_440usize.next_multiple_of(32),
223                            ),
224                            (
225                                CompressAir::<SP1Field>::Select(SelectChip),
226                                800_000usize.next_multiple_of(32),
227                            ),
228                            (CompressAir::<SP1Field>::PublicValues(PublicValuesChip), 16usize),
229                        ]
230                        .into_iter()
231                        .collect(),
232                    })
233                })
234            }
235            _ => None,
236        }
237    }
238
239    pub fn dummy_input(
240        &self,
241        arity: usize,
242        height: usize,
243        chips: BTreeSet<Chip<SP1Field, CompressAir<SP1Field>>>,
244        max_log_row_count: usize,
245        fri_config: FriConfig<SP1Field>,
246        log_stacking_height: usize,
247    ) -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
248        let dummy_vk = dummy_vk();
249
250        let preprocessed_multiple = chips
251            .iter()
252            .map(|chip| self.shape.height(chip).unwrap() * chip.preprocessed_width())
253            .sum::<usize>()
254            .div_ceil(1 << log_stacking_height);
255
256        let main_multiple = chips
257            .iter()
258            .map(|chip| self.shape.height(chip).unwrap() * chip.width())
259            .sum::<usize>()
260            .div_ceil(1 << log_stacking_height);
261
262        let preprocessed_padding_cols = ((preprocessed_multiple * (1 << log_stacking_height))
263            - chips
264                .iter()
265                .map(|chip| self.shape.height(chip).unwrap() * chip.preprocessed_width())
266                .sum::<usize>())
267        .div_ceil(1 << max_log_row_count)
268        .max(1);
269
270        let main_padding_cols = ((main_multiple * (1 << log_stacking_height))
271            - chips
272                .iter()
273                .map(|chip| self.shape.height(chip).unwrap() * chip.width())
274                .sum::<usize>())
275        .div_ceil(1 << max_log_row_count)
276        .max(1);
277
278        let dummy_proof = dummy_shard_proof(
279            chips,
280            max_log_row_count,
281            fri_config,
282            log_stacking_height,
283            &[preprocessed_multiple, main_multiple],
284            &[preprocessed_padding_cols, main_padding_cols],
285        );
286
287        let vks_and_proofs =
288            (0..arity).map(|_| (dummy_vk.clone(), dummy_proof.clone())).collect::<Vec<_>>();
289
290        SP1CompressWithVKeyWitnessValues {
291            compress_val: SP1ShapedWitnessValues { vks_and_proofs, is_complete: false },
292            merkle_val: SP1MerkleProofWitnessValues::dummy(arity, height),
293        }
294    }
295
296    pub async fn check_compatibility(
297        &self,
298        program: Arc<RecursionProgram<SP1Field>>,
299        machine: Machine<SP1Field, CompressAir<SP1Field>>,
300    ) -> bool {
301        // Generate the preprocessed traces to get the heights.
302        let trace_generator = DefaultTraceGenerator::new(machine);
303        let setup_permits = ProverSemaphore::new(1);
304        let preprocessed_traces = trace_generator
305            .generate_preprocessed_traces(program, RECURSION_MAX_LOG_ROW_COUNT, setup_permits)
306            .await;
307
308        let mut is_compatible = true;
309        for (chip, trace) in preprocessed_traces.preprocessed_traces.into_iter() {
310            let real_height = trace.num_real_entries();
311            let expected_height = self.shape.height_of_name(&chip).unwrap();
312            if real_height > expected_height {
313                tracing::warn!(
314                    "program is incompatible with shape: {} > {} for chip {}",
315                    real_height,
316                    expected_height,
317                    chip
318                );
319                is_compatible = false;
320            }
321        }
322        is_compatible
323    }
324
325    #[allow(dead_code)]
326    async fn max_arity<C: SP1ProverComponents>(
327        &self,
328        vk_verification: bool,
329        height: usize,
330    ) -> usize {
331        let mut arity = 0;
332        let compress_verifier = C::compress_verifier();
333        let recursive_compress_verifier =
334            recursive_verifier::<_, _, InnerConfig>(compress_verifier.shard_verifier());
335        for possible_arity in 1.. {
336            let input = dummy_compose_input(&compress_verifier, self, possible_arity, height);
337            let program =
338                compose_program_from_input(&recursive_compress_verifier, vk_verification, &input);
339            let program = Arc::new(program);
340            let is_compatible =
341                self.check_compatibility(program, compress_verifier.machine().clone()).await;
342            if !is_compatible {
343                break;
344            }
345            arity = possible_arity;
346        }
347        arity
348    }
349}
350
351pub async fn build_vk_map<A: ArtifactClient, C: SP1ProverComponents + 'static>(
352    dummy: bool,
353    num_compiler_workers: usize,
354    num_setup_workers: usize,
355    indices: Option<Vec<usize>>,
356    max_arity: usize,
357    merkle_tree_height: usize,
358    vk_worker: Arc<RecursionVkWorker<C>>,
359) -> (BTreeSet<[SP1Field; DIGEST_SIZE]>, Vec<usize>) {
360    let recursion_permits = vk_worker.recursion_permits.clone();
361    let recursion_prover = vk_worker.recursion_prover.clone();
362    let shrink_prover = vk_worker.shrink_prover.clone();
363    if dummy {
364        let dummy_set = dummy_vk_map::<C>().into_keys().collect();
365        return (dummy_set, vec![]);
366    }
367
368    // Setup the channels.
369    let (vk_tx, mut vk_rx) =
370        tokio::sync::mpsc::unbounded_channel::<(usize, [SP1Field; DIGEST_SIZE])>();
371    let (shape_tx, shape_rx) =
372        tokio::sync::mpsc::channel::<(usize, SP1RecursionProgramShape)>(num_compiler_workers);
373    let (program_tx, program_rx) = tokio::sync::mpsc::channel(num_setup_workers);
374    let (panic_tx, mut panic_rx) = tokio::sync::mpsc::unbounded_channel();
375
376    // Setup the mutexes.
377    let shape_rx = Arc::new(tokio::sync::Mutex::new(shape_rx));
378    let program_rx = Arc::new(tokio::sync::Mutex::new(program_rx));
379
380    // Generate all the possible shape inputs we encounter in recursion. This may span normalize,
381    // compose (of any arity), deferred, shrink, etc.
382    let all_shapes = create_all_input_shapes(C::core_verifier().machine().shape(), max_arity);
383
384    let num_shapes = all_shapes.len();
385
386    let height = if indices.is_some() { merkle_tree_height } else { log2_ceil_usize(num_shapes) };
387
388    let indices_set = indices.map(|indices| indices.into_iter().collect::<HashSet<_>>());
389
390    let vk_map_size = indices_set.as_ref().map(|indices| indices.len()).unwrap_or(num_shapes);
391
392    let mut set = JoinSet::new();
393
394    // Initialize compiler workers.
395    for _ in 0..num_compiler_workers {
396        let program_tx = program_tx.clone();
397        let shape_rx = shape_rx.clone();
398        let panic_tx = panic_tx.clone();
399        set.spawn(async move {
400            while let Some((i, shape)) = shape_rx.lock().await.recv().await {
401                // eprintln!("shape: {:?}", shape);
402                let compress_verifier = C::compress_verifier();
403                let recursive_compress_verifier =
404                    recursive_verifier::<_, _, InnerConfig>(compress_verifier.shard_verifier());
405                // Spawn on another thread to handle panics.
406                let program_thread = tokio::spawn(async move {
407                    let reduce_shape =
408                        SP1RecursionProofShape::compress_proof_shape_from_arity(max_arity);
409                    match shape {
410                        SP1RecursionProgramShape::Normalize(shape_clone) => {
411                            let normalize_shape = SP1NormalizeInputShape {
412                                proof_shapes: vec![shape_clone],
413                                max_log_row_count: CORE_MAX_LOG_ROW_COUNT,
414                                log_blowup: CORE_LOG_BLOWUP,
415                                log_stacking_height: CORE_LOG_STACKING_HEIGHT as usize,
416                            };
417                            let dummy_vk = dummy_vk();
418                            let core_verifier = C::core_verifier();
419                            let recursive_core_verifier = recursive_verifier::<_, _, InnerConfig>(
420                                core_verifier.shard_verifier(),
421                            );
422                            let witness =
423                                normalize_shape.dummy_input(SP1VerifyingKey { vk: dummy_vk });
424                            let mut program =
425                                normalize_program_from_input(&recursive_core_verifier, &witness);
426                            program.shape =
427                                Some(reduce_shape.clone().expect("max arity not supported").shape);
428                            (Arc::new(program), false)
429                        }
430                        SP1RecursionProgramShape::Compose(arity) => {
431                            let dummy_input = dummy_compose_input(
432                                &compress_verifier,
433                                &SP1RecursionProofShape::compress_proof_shape_from_arity(max_arity)
434                                    .expect("max arity not supported"),
435                                arity,
436                                height,
437                            );
438
439                            let mut program = compose_program_from_input(
440                                &recursive_compress_verifier,
441                                true,
442                                &dummy_input,
443                            );
444                            program.shape =
445                                Some(reduce_shape.clone().expect("max arity not supported").shape);
446                            (Arc::new(program), false)
447                        }
448                        SP1RecursionProgramShape::Deferred => {
449                            let dummy_input = dummy_deferred_input(
450                                &C::compress_verifier(),
451                                &reduce_shape.clone().expect("max arity not supported"),
452                                height,
453                            );
454                            let mut program = deferred_program_from_input(
455                                &recursive_compress_verifier,
456                                true,
457                                &dummy_input,
458                            );
459
460                            program.shape =
461                                Some(reduce_shape.clone().expect("max arity not supported").shape);
462
463                            (Arc::new(program), false)
464                        }
465                        SP1RecursionProgramShape::Shrink => {
466                            let dummy_input = dummy_compose_input(
467                                &C::compress_verifier(),
468                                &reduce_shape.clone().expect("max arity not supported"),
469                                1,
470                                height,
471                            );
472                            let program = shrink_program_from_input(
473                                &recursive_compress_verifier,
474                                true,
475                                &dummy_input,
476                            );
477
478                            (Arc::new(program), true)
479                        }
480                    }
481                });
482                match program_thread.await {
483                    Ok((program, is_shrink)) => {
484                        program_tx.send((i, program, is_shrink)).await.unwrap()
485                    }
486                    Err(e) => {
487                        tracing::warn!(
488                            "Program generation failed for shape {}, with error: {:?}",
489                            i,
490                            e
491                        );
492                        panic_tx.send(i).unwrap();
493                    }
494                }
495            }
496        });
497    }
498
499    let recursion_prover = recursion_prover.clone();
500    // Initialize setup workers.
501    for _ in 0..num_setup_workers {
502        let vk_tx = vk_tx.clone();
503        let program_rx = program_rx.clone();
504        let prover = recursion_prover.clone();
505        let recursion_permits = recursion_permits.clone();
506        let shrink_prover = shrink_prover.clone();
507        set.spawn(async move {
508            let mut done = 0;
509            while let Some((i, program, is_shrink)) = program_rx.lock().await.recv().await {
510                let prover = prover.clone();
511                let shrink_prover = shrink_prover.clone();
512                let recursion_permits = recursion_permits.clone();
513                let vk_thread = tokio::spawn(async move {
514                    if is_shrink {
515                        shrink_prover.setup(program).await
516                    } else {
517                        prover.setup(program, recursion_permits.clone()).await.1
518                    }
519                });
520                let vk = vk_thread.await.unwrap();
521                done += 1;
522
523                let vk_digest = vk.hash_koalabear();
524
525                tracing::info!(
526                    "program {} = {:?}, {}% done",
527                    i,
528                    vk_digest,
529                    done * 100 / vk_map_size
530                );
531                vk_tx.send((i, vk_digest)).unwrap();
532            }
533        });
534    }
535
536    // Generate shapes and send them to the compiler workers.
537    let subset_shapes = all_shapes
538        .into_iter()
539        .enumerate()
540        .filter(|(i, _)| indices_set.as_ref().map(|set| set.contains(i)).unwrap_or(true))
541        .collect::<Vec<_>>();
542
543    for (i, shape) in subset_shapes.iter() {
544        shape_tx.send((*i, shape.clone())).await.unwrap();
545    }
546
547    drop(shape_tx);
548    drop(program_tx);
549    drop(vk_tx);
550    drop(panic_tx);
551
552    set.join_all().await;
553
554    let mut vk_map = BTreeMap::new();
555    while let Some((i, vk)) = vk_rx.recv().await {
556        vk_map.insert(i, vk);
557    }
558
559    let mut panic_indices = vec![];
560    while let Some(i) = panic_rx.recv().await {
561        panic_indices.push(i);
562    }
563    for (i, shape) in subset_shapes {
564        if panic_indices.contains(&i) {
565            tracing::info!("panic shape {}: {:?}", i, shape);
566        }
567    }
568
569    // Build vk_set in lexicographic order.
570    let vk_set: BTreeSet<[SP1Field; DIGEST_SIZE]> = vk_map.into_values().collect();
571
572    (vk_set, panic_indices)
573}
574
575fn max_main_multiple_for_preprocessed_multiple(preprocessed_multiple: usize) -> usize {
576    (PADDED_ELEMENT_THRESHOLD - preprocessed_multiple as u64 * (1 << CORE_LOG_STACKING_HEIGHT))
577        .div_ceil(1 << CORE_LOG_STACKING_HEIGHT as u64) as usize
578}
579
580pub fn create_all_input_shapes(
581    core_shape: &MachineShape<SP1Field, RiscvAir<SP1Field>>,
582    max_arity: usize,
583) -> Vec<SP1RecursionProgramShape> {
584    let (max_preprocessed_multiple, _, capacity) = normalize_program_parameter_space();
585    let max_num_padding_cols =
586        ((1 << CORE_LOG_STACKING_HEIGHT) as usize).div_ceil(1 << CORE_MAX_LOG_ROW_COUNT);
587
588    let mut result: Vec<SP1RecursionProgramShape> = Vec::with_capacity(capacity);
589    for preprocessed_multiple in 1..=max_preprocessed_multiple {
590        for main_multiple in 1..=max_main_multiple_for_preprocessed_multiple(preprocessed_multiple)
591        {
592            for main_padding_cols in 1..=max_num_padding_cols {
593                for preprocessed_padding_cols in 1..=max_num_padding_cols {
594                    for cluster in &core_shape.chip_clusters {
595                        result.push(SP1RecursionProgramShape::Normalize(CoreProofShape {
596                            shard_chips: cluster.clone(),
597                            preprocessed_area: preprocessed_multiple << CORE_LOG_STACKING_HEIGHT,
598                            main_area: main_multiple << CORE_LOG_STACKING_HEIGHT,
599                            preprocessed_padding_cols,
600                            main_padding_cols,
601                        }));
602                    }
603                }
604            }
605        }
606    }
607
608    // Add the compose shapes for each arity.
609    for arity in 1..=max_arity {
610        result.push(SP1RecursionProgramShape::Compose(arity));
611    }
612
613    // Add the deferred shape.
614    result.push(SP1RecursionProgramShape::Deferred);
615    // Add the shrink shape.
616    result.push(SP1RecursionProgramShape::Shrink);
617    result
618}
619
620pub fn normalize_program_parameter_space() -> (usize, usize, usize) {
621    let max_preprocessed_multiple = (MAX_PROGRAM_SIZE * NUM_PROGRAM_PREPROCESSED_COLS
622        + (1 << 17) * NUM_RANGE_PREPROCESSED_COLS
623        + (1 << 16) * NUM_BYTE_PREPROCESSED_COLS)
624        .div_ceil(1 << CORE_LOG_STACKING_HEIGHT);
625    let max_main_multiple =
626        (PADDED_ELEMENT_THRESHOLD).div_ceil(1 << CORE_LOG_STACKING_HEIGHT) as usize;
627
628    let num_shapes = (0..=max_preprocessed_multiple)
629        .map(max_main_multiple_for_preprocessed_multiple)
630        .sum::<usize>();
631
632    (max_preprocessed_multiple, max_main_multiple, num_shapes)
633}
634
635pub fn dummy_vk_map<C: SP1ProverComponents>() -> BTreeMap<[SP1Field; DIGEST_SIZE], usize> {
636    create_all_input_shapes(C::core_verifier().machine().shape(), DEFAULT_ARITY)
637        .iter()
638        .enumerate()
639        .map(|(i, _)| ([SP1Field::from_canonical_usize(i); DIGEST_SIZE], i))
640        .collect()
641}
642
643pub fn max_count(a: RecursionAirEventCount, b: RecursionAirEventCount) -> RecursionAirEventCount {
644    use std::cmp::max;
645    RecursionAirEventCount {
646        mem_const_events: max(a.mem_const_events, b.mem_const_events),
647        mem_var_events: max(a.mem_var_events, b.mem_var_events),
648        base_alu_events: max(a.base_alu_events, b.base_alu_events),
649        ext_alu_events: max(a.ext_alu_events, b.ext_alu_events),
650        ext_felt_conversion_events: max(a.ext_felt_conversion_events, b.ext_felt_conversion_events),
651        poseidon2_wide_events: max(a.poseidon2_wide_events, b.poseidon2_wide_events),
652        poseidon2_linear_layer_events: max(
653            a.poseidon2_linear_layer_events,
654            b.poseidon2_linear_layer_events,
655        ),
656        poseidon2_sbox_events: max(a.poseidon2_sbox_events, b.poseidon2_sbox_events),
657        select_events: max(a.select_events, b.select_events),
658        prefix_sum_checks_events: max(a.prefix_sum_checks_events, b.prefix_sum_checks_events),
659        commit_pv_hash_events: max(a.commit_pv_hash_events, b.commit_pv_hash_events),
660    }
661}
662
663pub fn create_test_shape(
664    cluster: &BTreeSet<Chip<SP1Field, RiscvAir<SP1Field>>>,
665) -> SP1NormalizeInputShape {
666    let preprocessed_multiple = (MAX_PROGRAM_SIZE * NUM_PROGRAM_PREPROCESSED_COLS
667        + (1 << 17) * NUM_RANGE_PREPROCESSED_COLS
668        + (1 << 16) * NUM_BYTE_PREPROCESSED_COLS)
669        .div_ceil(1 << CORE_LOG_STACKING_HEIGHT);
670    let main_multiple = (PADDED_ELEMENT_THRESHOLD).div_ceil(1 << CORE_LOG_STACKING_HEIGHT) as usize;
671    let num_padding_cols =
672        ((1 << CORE_LOG_STACKING_HEIGHT) as usize).div_ceil(1 << CORE_MAX_LOG_ROW_COUNT);
673    SP1NormalizeInputShape {
674        proof_shapes: vec![CoreProofShape {
675            shard_chips: cluster.clone(),
676            preprocessed_area: preprocessed_multiple << CORE_LOG_STACKING_HEIGHT,
677            main_area: main_multiple << CORE_LOG_STACKING_HEIGHT,
678            preprocessed_padding_cols: num_padding_cols,
679            main_padding_cols: num_padding_cols,
680        }],
681        max_log_row_count: CORE_MAX_LOG_ROW_COUNT,
682        log_stacking_height: CORE_LOG_STACKING_HEIGHT as usize,
683        log_blowup: CORE_LOG_BLOWUP,
684    }
685}
686
687pub fn build_recursion_count_from_shape(
688    shape: &RecursionShape<SP1Field>,
689) -> RecursionAirEventCount {
690    RecursionAirEventCount {
691        mem_const_events: shape
692            .height(&CompressAir::<SP1Field>::MemoryConst(MemoryConstChip::default()))
693            .unwrap(),
694        mem_var_events: shape
695            .height(&CompressAir::<SP1Field>::MemoryVar(MemoryVarChip::<SP1Field, 2>::default()))
696            .unwrap()
697            * 2,
698        base_alu_events: shape.height(&CompressAir::<SP1Field>::BaseAlu(BaseAluChip)).unwrap(),
699        ext_alu_events: shape.height(&CompressAir::<SP1Field>::ExtAlu(ExtAluChip)).unwrap(),
700        ext_felt_conversion_events: shape
701            .height(&CompressAir::<SP1Field>::ExtFeltConvert(ConvertChip))
702            .unwrap_or(0),
703        poseidon2_wide_events: shape
704            .height(&CompressAir::<SP1Field>::Poseidon2Wide(Poseidon2WideChip))
705            .unwrap_or(0),
706        poseidon2_linear_layer_events: shape
707            .height(&CompressAir::<SP1Field>::Poseidon2LinearLayer(Poseidon2LinearLayerChip))
708            .unwrap_or(0),
709        poseidon2_sbox_events: shape
710            .height(&CompressAir::<SP1Field>::Poseidon2SBox(Poseidon2SBoxChip))
711            .unwrap_or(0),
712        select_events: shape.height(&CompressAir::<SP1Field>::Select(SelectChip)).unwrap(),
713        prefix_sum_checks_events: shape
714            .height(&CompressAir::<SP1Field>::PrefixSumChecks(PrefixSumChecksChip))
715            .unwrap_or(0),
716        commit_pv_hash_events: shape
717            .height(&CompressAir::<SP1Field>::PublicValues(PublicValuesChip))
718            .unwrap(),
719    }
720}
721
722pub fn build_shape_from_recursion_air_event_count(
723    event_count: &RecursionAirEventCount,
724) -> SP1RecursionProofShape {
725    let &RecursionAirEventCount {
726        mem_const_events,
727        mem_var_events,
728        base_alu_events,
729        ext_alu_events,
730        poseidon2_wide_events,
731        select_events,
732        prefix_sum_checks_events,
733        commit_pv_hash_events,
734        ..
735    } = event_count;
736    let chips_and_heights = vec![
737        (CompressAir::<SP1Field>::MemoryConst(MemoryConstChip::default()), mem_const_events),
738        (
739            CompressAir::<SP1Field>::MemoryVar(MemoryVarChip::<SP1Field, 2>::default()),
740            mem_var_events / 2,
741        ),
742        (CompressAir::<SP1Field>::BaseAlu(BaseAluChip), base_alu_events),
743        (CompressAir::<SP1Field>::ExtAlu(ExtAluChip), ext_alu_events),
744        (CompressAir::<SP1Field>::Poseidon2Wide(Poseidon2WideChip), poseidon2_wide_events),
745        (CompressAir::<SP1Field>::Select(SelectChip), select_events),
746        (CompressAir::<SP1Field>::PrefixSumChecks(PrefixSumChecksChip), prefix_sum_checks_events),
747        (CompressAir::<SP1Field>::PublicValues(PublicValuesChip), commit_pv_hash_events),
748    ];
749    SP1RecursionProofShape { shape: chips_and_heights.into_iter().collect() }
750}
751
752#[cfg(test)]
753mod tests {
754    use anyhow::Context;
755
756    use crate::{
757        recursion::{
758            compose_program_from_input, deferred_program_from_input, dummy_compose_input,
759            dummy_deferred_input, normalize_program_from_input, recursive_verifier,
760        },
761        worker::SP1LightNode,
762        CpuSP1ProverComponents,
763    };
764    #[cfg(feature = "experimental")]
765    use sp1_core_executor::SP1Context;
766
767    use sp1_core_machine::utils::setup_logger;
768    use sp1_recursion_compiler::config::InnerConfig;
769    use sp1_recursion_executor::RecursionAirEventCount;
770
771    use super::*;
772
773    #[tokio::test]
774    #[ignore = "should be invoked specifically"]
775    async fn test_max_arity() {
776        setup_logger();
777        let client = SP1LightNode::new().await;
778
779        let vk_verification = client.inner().vk_verification();
780        let allowed_vk_height = client.inner().allowed_vk_height();
781
782        let reduce_shape = SP1RecursionProofShape::compress_proof_shape_from_arity(DEFAULT_ARITY)
783            .expect("default arity shape should be valid");
784
785        let arity = reduce_shape
786            .max_arity::<CpuSP1ProverComponents>(vk_verification, allowed_vk_height)
787            .await;
788
789        tracing::info!("arity: {}", arity);
790    }
791
792    #[derive(Debug, Error)]
793    enum ShapeError {
794        #[error("Expected arity to be {DEFAULT_ARITY}, found {_0}")]
795        WrongArity(usize),
796        #[error(
797            "Expected the arity {DEFAULT_ARITY} shape to be large enough
798                to accommodate all core shard proof shapes."
799        )]
800        CoreShapesTooLarge,
801        #[error("Expected height of chip {_0} to be a multiple of 32")]
802        ChipHeightNotMultipleOf32(String),
803        #[error("Expected the shape to be minimal")]
804        ShapeNotMinimal,
805        #[error("Public values chip height is not 16")]
806        PublicValuesChipHeightNot16,
807    }
808
809    #[tokio::test]
810    async fn test_core_shape_fit() -> Result<(), anyhow::Error> {
811        setup_logger();
812        let elf = test_artifacts::FIBONACCI_ELF;
813        let client = SP1LightNode::new().await;
814        // let prover = SP1ProverBuilder::new().without_recursion_vks().build().await;
815        let vk = client.setup(&elf).await?;
816
817        let context =
818            "Shape is not valid. To fix: From sp1-wip directory, run `cargo test --release -p sp1-prover --features experimental -- test_find_recursion_shape --include-ignored`";
819
820        let machine = RiscvAir::<SP1Field>::machine();
821        let chip_clusters = &machine.shape().chip_clusters;
822        let mut max_cluster_count = RecursionAirEventCount::default();
823
824        let core_verifier = CpuSP1ProverComponents::core_verifier();
825        let recursive_core_verifier =
826            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(core_verifier.shard_verifier());
827
828        for cluster in chip_clusters {
829            let shape = create_test_shape(cluster);
830            let program = normalize_program_from_input(
831                &recursive_core_verifier,
832                &shape.dummy_input(vk.clone()),
833            );
834            max_cluster_count = max_count(max_cluster_count, program.event_counts);
835        }
836
837        let vk_verification = client.inner().vk_verification();
838        let allowed_vk_height = client.inner().allowed_vk_height();
839
840        tracing::info!("max_cluster_count: {:?}", max_cluster_count);
841
842        let reduce_shape =
843            SP1RecursionProofShape::compress_proof_shape_from_arity(DEFAULT_ARITY).unwrap();
844        let arity = reduce_shape
845            .max_arity::<CpuSP1ProverComponents>(vk_verification, allowed_vk_height)
846            .await;
847        if arity != DEFAULT_ARITY {
848            return Err(ShapeError::WrongArity(arity)).context(context);
849        }
850
851        // Check that the deferred program fits within the reduce shape.
852        {
853            let compress_verifier = CpuSP1ProverComponents::compress_verifier();
854            let recursive_compress_verifier = recursive_verifier::<SP1GlobalContext, _, InnerConfig>(
855                compress_verifier.shard_verifier(),
856            );
857            let deferred_input =
858                dummy_deferred_input(&compress_verifier, &reduce_shape, allowed_vk_height);
859            let deferred_program = deferred_program_from_input(
860                &recursive_compress_verifier,
861                vk_verification,
862                &deferred_input,
863            );
864            let deferred_count = deferred_program.event_counts;
865            tracing::info!("deferred_count: {:?}", deferred_count);
866            max_cluster_count = max_count(max_cluster_count, deferred_count);
867        }
868
869        let arity_4_count = build_recursion_count_from_shape(&reduce_shape.shape);
870        let combined_count = max_count(max_cluster_count, arity_4_count);
871
872        let max_cluster_shape = build_shape_from_recursion_air_event_count(&max_cluster_count);
873        if combined_count != arity_4_count {
874            return Err(ShapeError::CoreShapesTooLarge).context(context);
875        }
876
877        for (chip, height) in (&reduce_shape.shape).into_iter() {
878            if chip != "PublicValues" {
879                if !height.is_multiple_of(32) {
880                    return Err(ShapeError::ChipHeightNotMultipleOf32(chip.clone()))
881                        .context(context);
882                }
883                let mut new_reduce_shape = reduce_shape.clone();
884
885                new_reduce_shape.shape.insert_with_name(chip, height - 32);
886
887                if !(new_reduce_shape
888                    .max_arity::<CpuSP1ProverComponents>(vk_verification, allowed_vk_height)
889                    .await
890                    < DEFAULT_ARITY
891                    || new_reduce_shape.shape.height_of_name(chip).unwrap()
892                        < max_cluster_shape
893                            .shape
894                            .height_of_name(chip)
895                            .unwrap()
896                            .next_multiple_of(32))
897                {
898                    return Err(ShapeError::ShapeNotMinimal).context(context);
899                }
900            } else if *height != 16 {
901                return Err(ShapeError::PublicValuesChipHeightNot16).context(context);
902            }
903        }
904        Ok(())
905    }
906
907    #[cfg(feature = "experimental")]
908    use serial_test::serial;
909
910    #[tokio::test]
911    #[serial]
912    #[cfg(feature = "experimental")]
913    async fn test_build_vk_map() {
914        use std::fs::File;
915
916        use either::Either;
917
918        use sp1_core_machine::io::SP1Stdin;
919        use sp1_prover_types::network_base_types::ProofMode;
920        use sp1_verifier::SP1Proof;
921
922        use crate::worker::{cpu_worker_builder, SP1LocalNodeBuilder};
923
924        setup_logger();
925
926        // Use a temporary directory for the vk_map file to avoid conflicts
927        let temp_dir = std::env::temp_dir();
928        let vk_map_path = temp_dir.join("vk_map.bin");
929
930        // Clean up any existing file from previous runs
931        let _ = std::fs::remove_file(&vk_map_path);
932
933        let node = SP1LocalNodeBuilder::from_worker_client_builder(cpu_worker_builder())
934            .build()
935            .await
936            .unwrap();
937
938        let elf = test_artifacts::FIBONACCI_ELF;
939
940        // Make a proof to get proof shapes to populate the vk map with.
941        let vk = node.setup(&elf).await.expect("Failed to setup");
942
943        let proof = node
944            .prove_with_mode(&elf, SP1Stdin::default(), SP1Context::default(), ProofMode::Core)
945            .await
946            .expect("Failed to prove");
947
948        // Create all circuit shapes.
949        let shapes = create_all_input_shapes(
950            CpuSP1ProverComponents::core_verifier().shard_verifier().machine().shape(),
951            DEFAULT_ARITY,
952        );
953
954        // Determine the indices in `shapes` of the shapes appear in the proof.
955        let mut shape_indices = vec![];
956
957        let core_proof = match proof.proof {
958            SP1Proof::Core(proof) => proof,
959            _ => panic!("Expected core proof"),
960        };
961
962        for proof in &core_proof {
963            let shape = SP1RecursionProgramShape::Normalize(
964                CpuSP1ProverComponents::core_verifier().shape_from_proof(proof),
965            );
966
967            shape_indices.push(shapes.iter().position(|s| s == &shape).unwrap());
968        }
969
970        let shape_indices =
971            shape_indices.into_iter().chain(shapes.len() - 12..shapes.len()).collect::<Vec<_>>();
972
973        let result = node.build_vks(Some(Either::Left(shape_indices)), 4).await.unwrap();
974
975        let vk_map_path = temp_dir.join("vk_map.bin");
976
977        // Create the file to store the vk map.
978        let mut file = File::create(vk_map_path.clone()).unwrap();
979
980        bincode::serialize_into(&mut file, &result.vk_map).unwrap();
981
982        // Build a new prover that performs the vk verification check using the built vk map.
983        let node = SP1LocalNodeBuilder::from_worker_client_builder(
984            cpu_worker_builder().with_vk_map_path(vk_map_path.to_str().unwrap().to_string()),
985        )
986        .build()
987        .await
988        .unwrap();
989
990        tracing::info!("Rebuilt prover with vk map.");
991
992        // Make a proof with the vks checked.
993        let proof = node
994            .prove_with_mode(
995                &elf,
996                SP1Stdin::default(),
997                SP1Context::default(),
998                ProofMode::Compressed,
999            )
1000            .await
1001            .expect("Failed to prove");
1002
1003        node.verify(&vk, &proof.proof).unwrap();
1004
1005        std::fs::remove_file(vk_map_path).unwrap();
1006    }
1007
1008    #[tokio::test]
1009    #[ignore = "should be invoked for shape tuning"]
1010    async fn test_find_recursion_shape() {
1011        setup_logger();
1012        let elf = test_artifacts::FIBONACCI_ELF;
1013        let client = SP1LightNode::new().await;
1014        let vk = client.setup(&elf).await.unwrap();
1015
1016        let machine = RiscvAir::<SP1Field>::machine();
1017        let chip_clusters = &machine.shape().chip_clusters;
1018        let allowed_vk_height = client.inner().allowed_vk_height();
1019        let vk_verification = client.inner().vk_verification();
1020
1021        let verifier = CpuSP1ProverComponents::compress_verifier();
1022        let dummy_input =
1023            |current_shape: &SP1RecursionProofShape| -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
1024                dummy_compose_input(&verifier, current_shape, DEFAULT_ARITY, allowed_vk_height)
1025            };
1026        let core_verifier = CpuSP1ProverComponents::core_verifier();
1027        let recursive_core_verifier =
1028            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(core_verifier.shard_verifier());
1029
1030        let recursive_compress_verifier =
1031            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(verifier.shard_verifier());
1032        let compose_program =
1033            |input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>| -> Arc<RecursionProgram<SP1Field>> {
1034                Arc::new(compose_program_from_input(
1035                    &recursive_compress_verifier,
1036                    vk_verification,
1037                    input,
1038                ))
1039            };
1040
1041        // Find the recursion proof shape that fits the normalize programs verifying all core
1042        // shards.
1043        let mut max_cluster_count = RecursionAirEventCount::default();
1044
1045        for cluster in chip_clusters {
1046            let shape = create_test_shape(cluster);
1047            let program = normalize_program_from_input(
1048                &recursive_core_verifier,
1049                &shape.dummy_input(vk.clone()),
1050            );
1051            max_cluster_count = max_count(max_cluster_count, program.event_counts);
1052        }
1053
1054        // Iterate on this shape until the compose program verifying DEFAULT_ARITY proofs of shape
1055        // `current_shape` can be proved using `current_shape`.
1056        let mut current_shape = build_shape_from_recursion_air_event_count(&max_cluster_count);
1057        let trace_generator =
1058            DefaultTraceGenerator::new(CompressAir::<SP1Field>::compress_machine());
1059        loop {
1060            // Create DEFAULT_ARITY dummy proofs of shape `current_shape`
1061            let input = dummy_input(&current_shape);
1062            // Compile the program that verifies those `DEFAULT_ARITY` proofs.
1063            let program = compose_program(&input);
1064            let setup_permits = ProverSemaphore::new(1);
1065            // The preprocessed traces contain the information of the minimum required table heights
1066            // to prove the compose program.
1067            let preprocessed_traces = trace_generator
1068                .generate_preprocessed_traces(program, RECURSION_MAX_LOG_ROW_COUNT, setup_permits)
1069                .await;
1070
1071            // Check if the `current_shape` heights are insufficient.
1072            let updated_key_values = preprocessed_traces
1073                .preprocessed_traces
1074                .into_iter()
1075                .filter_map(|(chip, trace)| {
1076                    let real_height = trace.num_real_entries();
1077                    let expected_height = current_shape.shape.height_of_name(&chip).unwrap();
1078
1079                    if real_height > expected_height {
1080                        tracing::warn!(
1081                            "Insufficient height for chip {}: expected {}, got {}",
1082                            chip,
1083                            expected_height,
1084                            real_height
1085                        );
1086                        Some((chip, real_height))
1087                    } else {
1088                        None
1089                    }
1090                })
1091                .collect::<Vec<_>>();
1092
1093            // If no need to update the chip heights, `current_shape` is good enough.
1094            if updated_key_values.is_empty() {
1095                break;
1096            }
1097            // Otherwise, update the heights in `current_shape` and repeat the loop.
1098            for (chip, real_height) in updated_key_values {
1099                current_shape.shape.insert_with_name(&chip, real_height);
1100            }
1101        }
1102
1103        // Write the shape to a file.
1104        let shape = SP1RecursionProofShape {
1105            shape: RecursionShape::new(
1106                current_shape
1107                    .shape
1108                    .into_iter()
1109                    .map(|(chip, height)| {
1110                        let new_height = if chip == "PublicValues" {
1111                            height
1112                        } else {
1113                            height.next_multiple_of(32)
1114                        };
1115                        (chip, new_height)
1116                    })
1117                    .collect(),
1118            ),
1119        };
1120
1121        let mut file = std::fs::File::create("compress_shape.json").unwrap();
1122        serde_json::to_writer_pretty(&mut file, &shape).unwrap();
1123    }
1124}