Skip to main content

sp1_recursion_circuit/machine/
compress.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4    marker::PhantomData,
5    mem::MaybeUninit,
6};
7
8use itertools::Itertools;
9
10use slop_air::Air;
11
12use slop_algebra::{AbstractField, PrimeField32};
13use slop_challenger::IopCtx;
14
15use serde::{Deserialize, Serialize};
16use sp1_core_machine::riscv::MAX_LOG_NUMBER_OF_SHARDS;
17use sp1_recursion_compiler::ir::{Builder, Felt, IrIter};
18
19use sp1_primitives::{SP1Field, SP1GlobalContext};
20use sp1_recursion_executor::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};
21
22use sp1_hypercube::{
23    air::{MachineAir, ShardRange, POSEIDON_NUM_WORDS, PV_DIGEST_NUM_WORDS},
24    MachineVerifyingKey, ShardProof, DIGEST_SIZE,
25};
26
27use crate::{
28    challenger::CanObserveVariable,
29    machine::{
30        assert_complete, assert_recursion_public_values_valid, recursion_public_values_digest,
31        root_public_values_digest,
32    },
33    shard::{MachineVerifyingKeyVariable, RecursiveShardVerifier, ShardProofVariable},
34    zerocheck::RecursiveVerifierConstraintFolder,
35    CircuitConfig, SP1FieldConfigVariable,
36};
37
38use sp1_recursion_compiler::circuit::CircuitV2Builder;
39
40use super::InnerVal;
/// A program to verify a batch of recursive proofs and aggregate their public values.
///
/// This is a zero-sized marker type: it stores only a `PhantomData` for its type parameters
/// `C`, `SC` and `A`, and carries no runtime state.
#[derive(Debug, Clone, Copy)]
pub struct SP1CompressVerifier<C, SC, A> {
    /// Ties the type parameters to the struct without storing any value of those types.
    _phantom: PhantomData<(C, SC, A)>,
}
46
/// Selects which digest is written into the aggregated public values by
/// `SP1CompressVerifier::verify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PublicValuesOutputDigest {
    /// Compute the digest with `recursion_public_values_digest`.
    Reduce,
    /// Compute the digest with `root_public_values_digest`.
    Root,
}
51
/// Witness layout for the compress stage verifier.
///
/// This is the in-circuit form of the witness: it is the `input` consumed by
/// `SP1CompressVerifier::verify`.
#[allow(clippy::type_complexity)]
pub struct SP1ShapedWitnessVariable<C: CircuitConfig, GC: SP1FieldConfigVariable<C>> {
    /// The shard proofs to verify, each paired with the verifying key it is checked against.
    pub vks_and_proofs: Vec<(MachineVerifyingKeyVariable<C, GC>, ShardProofVariable<C, GC>)>,
    /// Completeness flag. `SP1CompressVerifier::verify` copies it into the aggregated public
    /// values and passes it to `assert_complete`.
    pub is_complete: Felt<SP1Field>,
}
59
/// A machine verifying key paired with a shard proof.
pub type VkAndProof<GC, Proof> = (MachineVerifyingKey<GC>, ShardProof<GC, Proof>);
61
#[derive(Clone, Serialize, Deserialize)]
// Explicit serde bounds: the generated impls only require `ShardProof<GC, Proof>` itself to be
// (de)serializable, instead of the bounds serde would otherwise infer for `GC` and `Proof`.
#[serde(bound(serialize = "ShardProof<GC,Proof>: Serialize"))]
#[serde(bound(deserialize = "ShardProof<GC,Proof>: Deserialize<'de>"))]
/// An input layout for the shard proofs that have been normalized to a standard shape.
pub struct SP1ShapedWitnessValues<GC: IopCtx, Proof> {
    /// The verifying key / shard proof pairs, in order. `range()` reads the first and the last
    /// entry, so the order is significant.
    pub vks_and_proofs: Vec<VkAndProof<GC, Proof>>,
    /// Completeness flag accompanying the proofs.
    /// NOTE(review): presumably the value-level counterpart of
    /// `SP1ShapedWitnessVariable::is_complete` — confirm against the witness conversion code.
    pub is_complete: bool,
}
70
71impl<GC: IopCtx, Proof> SP1ShapedWitnessValues<GC, Proof> {
72    pub fn range(&self) -> ShardRange
73    where
74        GC::F: PrimeField32,
75    {
76        let start_pv: &RecursionPublicValues<GC::F> =
77            self.vks_and_proofs[0].1.public_values.as_slice().borrow();
78        let end_pv: &RecursionPublicValues<GC::F> =
79            self.vks_and_proofs[self.vks_and_proofs.len() - 1].1.public_values.as_slice().borrow();
80
81        let start = start_pv.range().start();
82        let end = end_pv.range().end();
83
84        (start..end).into()
85    }
86}
87
impl<C, SC, A> SP1CompressVerifier<C, SC, A>
where
    C: CircuitConfig<Bit = Felt<SP1Field>>,
    A: MachineAir<InnerVal> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
{
    /// Verify a batch of recursive proofs and aggregate their public values.
    ///
    /// The compression verifier can aggregate proofs of different kinds:
    /// - Core proofs: proofs which are recursive proof of a batch of SP1 shard proofs. The
    ///   implementation in this function assumes a fixed recursive verifier specified by
    ///   `recursive_vk`.
    /// - Deferred proofs: proofs which are recursive proof of a batch of deferred proofs. The
    ///   implementation in this function assumes a fixed deferred verification program specified by
    ///   `deferred_vk`.
    /// - Compress proofs: these are proofs which refer to a proof of this program. The key for it
    ///   is part of public values will be propagated across all levels of recursion and will be
    ///   checked against itself as in [sp1_prover::Prover] or as in [super::SP1RootVerifier].
    ///
    /// NOTE(review): `recursive_vk` and `deferred_vk` are not parameters of this function; the
    /// references above may be stale — confirm.
    ///
    /// # Arguments
    ///
    /// * `builder` - The circuit builder that all verification and consistency constraints are
    ///   emitted into.
    /// * `machine` - The recursive shard verifier used to verify every shard proof in `input`.
    /// * `input` - The verifying key / shard proof pairs together with the `is_complete` flag.
    /// * `vk_root` - The expected vk root. Every proof's `vk_root` public value is asserted to be
    ///   equal to it, and it is copied into the aggregated public values.
    /// * `kind` - Selects whether the output digest is computed with
    ///   `recursion_public_values_digest` (`Reduce`) or `root_public_values_digest` (`Root`).
    ///
    /// # Panics
    ///
    /// Panics (while building the circuit) if `input.vks_and_proofs` is empty.
    pub fn verify(
        builder: &mut Builder<C>,
        machine: &RecursiveShardVerifier<SP1GlobalContext, A, C>,
        input: SP1ShapedWitnessVariable<C, SP1GlobalContext>,
        vk_root: [Felt<SP1Field>; DIGEST_SIZE],
        kind: PublicValuesOutputDigest,
    ) {
        // Read input.
        let SP1ShapedWitnessVariable { vks_and_proofs, is_complete } = input;

        // Initialize the values for the aggregated public output.
        //
        // The stream holds `RECURSIVE_PROOF_NUM_PV_ELTS` zeroed placeholder elements and is viewed
        // as a `RecursionPublicValues` so that the fields can be assigned by name below.
        // NOTE(review): whether a zeroed `Felt` is a valid value depends on its definition, which
        // is not visible here — confirm.
        let mut reduce_public_values_stream: Vec<Felt<_>> = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
            .map(|_| unsafe { MaybeUninit::zeroed().assume_init() })
            .collect();
        let compress_public_values: &mut RecursionPublicValues<_> =
            reduce_public_values_stream.as_mut_slice().borrow_mut();

        // Make sure there is at least one proof.
        assert!(!vks_and_proofs.is_empty());

        // Initialize the consistency check variables.
        //
        // The zeroed / `uninit` values below are placeholders only: each of them is overwritten
        // from the first proof's public values in the `i == 0` branch of the aggregation loop
        // before it is read, and the non-empty assertion above guarantees that branch runs.
        let mut sp1_vk_digest: [Felt<_>; DIGEST_SIZE] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut pc: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut current_exit_code: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        let mut current_timestamp: [Felt<_>; 4] = array::from_fn(|_| builder.uninit());

        let mut committed_value_digest: [[Felt<_>; 4]; PV_DIGEST_NUM_WORDS] =
            array::from_fn(|_| array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }));
        let mut deferred_proofs_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut deferred_proof_index: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        let mut reconstruct_deferred_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
            core::array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut global_cumulative_sums = Vec::new();
        let mut init_addr: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut finalize_addr: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut init_page_idx: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut finalize_page_idx: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut commit_syscall: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        let mut commit_deferred_syscall: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        // These two are running sums over all proofs, so they start from a real zero value.
        let mut contains_first_shard: Felt<_> = builder.eval(SP1Field::zero());
        let mut num_included_shard: Felt<_> = builder.eval(SP1Field::zero());
        let mut proof_nonce: [Felt<_>; 4] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });

        // Verify the shard proofs.
        // Verification of proofs can be done in parallel but the aggregation/consistency checks
        // must be done sequentially.
        vks_and_proofs.iter().ir_par_map_collect::<Vec<_>, _, _>(
            builder,
            |builder, (vk, shard_proof)| {
                // Prepare a challenger.
                let mut challenger = SP1GlobalContext::challenger_variable(builder);

                // Observe the vk and start pc.
                challenger.observe(builder, vk.preprocessed_commit);
                challenger.observe_slice(builder, vk.pc_start);
                challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.x.0);
                challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.y.0);
                challenger.observe(builder, vk.untrusted_config.enable_untrusted_programs);
                // The remaining untrusted-program configuration is only observed when the
                // `mprotect` feature is enabled.
                #[cfg(feature = "mprotect")]
                {
                    challenger.observe(builder, vk.untrusted_config.enable_trap_handler);
                    challenger.observe_slice(builder, vk.untrusted_config.trap_context);
                    challenger.observe_slice(builder, vk.untrusted_config.untrusted_memory);
                }

                // Observe the padding.
                let zero: Felt<_> = builder.eval(SP1Field::zero());
                for _ in 0..6 {
                    challenger.observe(builder, zero);
                }
                // Verify the shard proof.
                machine.verify_shard(builder, vk, shard_proof, &mut challenger);
            },
        );

        // Check consistency and aggregate public values.
        for (i, (_, shard_proof)) in vks_and_proofs.into_iter().enumerate() {
            // Get the current public values.
            let current_public_values: &RecursionPublicValues<Felt<SP1Field>> =
                shard_proof.public_values.as_slice().borrow();
            // Assert that the public values are valid.
            assert_recursion_public_values_valid::<C, SP1GlobalContext>(
                builder,
                current_public_values,
            );
            // Assert that the vk root is the same as the witnessed one.
            for (expected, actual) in vk_root.iter().zip_eq(current_public_values.vk_root.iter()) {
                builder.assert_felt_eq(*expected, *actual);
            }

            // Verify that there are less than `(1 << MAX_LOG_NUMBER_OF_SHARDS)` included shards.
            C::range_check_felt(
                builder,
                current_public_values.num_included_shard,
                MAX_LOG_NUMBER_OF_SHARDS,
            );

            // Verify that `contains_first_shard` is boolean.
            builder.assert_felt_eq(
                current_public_values.contains_first_shard
                    * (current_public_values.contains_first_shard - SP1Field::one()),
                SP1Field::zero(),
            );

            // Accumulate the number of included shards.
            num_included_shard =
                builder.eval(num_included_shard + current_public_values.num_included_shard);

            // Accumulate the `contains_first_shard` flag.
            contains_first_shard =
                builder.eval(contains_first_shard + current_public_values.contains_first_shard);

            // Add the global cumulative sums to the vector.
            global_cumulative_sums.push(current_public_values.global_cumulative_sum);

            if i == 0 {
                // Initialize global and accumulated values.
                //
                // Each "previous"/"start" value of the first proof becomes the corresponding
                // value of the aggregated output, and also seeds the running variable that the
                // chained equality checks below compare against.

                // Assign the committed values and deferred proof digests.
                compress_public_values.prev_committed_value_digest =
                    current_public_values.prev_committed_value_digest;
                committed_value_digest = current_public_values.prev_committed_value_digest;

                compress_public_values.prev_deferred_proofs_digest =
                    current_public_values.prev_deferred_proofs_digest;
                deferred_proofs_digest = current_public_values.prev_deferred_proofs_digest;

                // Initialize the deferred proof index.
                compress_public_values.prev_deferred_proof =
                    current_public_values.prev_deferred_proof;
                deferred_proof_index = current_public_values.prev_deferred_proof;

                // Initialize start pc.
                compress_public_values.pc_start = current_public_values.pc_start;
                pc = current_public_values.pc_start;

                // Initialize timestamp.
                compress_public_values.initial_timestamp = current_public_values.initial_timestamp;
                current_timestamp = current_public_values.initial_timestamp;

                // Initialize the MemoryInitialize address.
                compress_public_values.previous_init_addr =
                    current_public_values.previous_init_addr;
                init_addr = current_public_values.previous_init_addr;

                // Initialize the MemoryFinalize address.
                compress_public_values.previous_finalize_addr =
                    current_public_values.previous_finalize_addr;
                finalize_addr = current_public_values.previous_finalize_addr;

                // Initialize the PageProtInit address.
                compress_public_values.previous_init_page_idx =
                    current_public_values.previous_init_page_idx;
                init_page_idx = current_public_values.previous_init_page_idx;

                // Initialize the PageProtFinalize address.
                compress_public_values.previous_finalize_page_idx =
                    current_public_values.previous_finalize_page_idx;
                finalize_page_idx = current_public_values.previous_finalize_page_idx;

                // Initialize the start of deferred digests.
                compress_public_values.start_reconstruct_deferred_digest =
                    current_public_values.start_reconstruct_deferred_digest;
                reconstruct_deferred_digest =
                    current_public_values.start_reconstruct_deferred_digest;

                // Initialize exit code.
                compress_public_values.prev_exit_code = current_public_values.prev_exit_code;
                current_exit_code = current_public_values.prev_exit_code;

                // Initialize `commit_syscall`.
                compress_public_values.prev_commit_syscall =
                    current_public_values.prev_commit_syscall;
                commit_syscall = current_public_values.prev_commit_syscall;

                // Initialize `commit_deferred_syscall`.
                compress_public_values.prev_commit_deferred_syscall =
                    current_public_values.prev_commit_deferred_syscall;
                commit_deferred_syscall = current_public_values.prev_commit_deferred_syscall;

                // Initialize the sp1_vk digest
                compress_public_values.sp1_vk_digest = current_public_values.sp1_vk_digest;
                sp1_vk_digest = current_public_values.sp1_vk_digest;

                // Initialize the proof nonce.
                compress_public_values.proof_nonce = current_public_values.proof_nonce;
                proof_nonce = current_public_values.proof_nonce;
            }

            // Assert that the current values match the accumulated values and update them.
            //
            // For the first proof these checks compare the values with themselves (they were just
            // copied above); for every later proof they chain its "previous" values to the
            // "next" values of the proof before it.

            // Assert that the `prev_committed_value_digest` is equal to current one, then update.
            for (word, current_word) in committed_value_digest
                .iter()
                .zip_eq(current_public_values.prev_committed_value_digest.iter())
            {
                for (limb, current_limb) in word.iter().zip_eq(current_word.iter()) {
                    builder.assert_felt_eq(*limb, *current_limb);
                }
            }
            committed_value_digest = current_public_values.committed_value_digest;

            // Assert that the `prev_deferred_proofs_digest` is equal to current one, then update.
            for (limb, current_limb) in deferred_proofs_digest
                .iter()
                .zip_eq(current_public_values.prev_deferred_proofs_digest.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            deferred_proofs_digest = current_public_values.deferred_proofs_digest;

            // Assert that the `prev_deferred_proof` is equal to the current one, then update.
            builder.assert_felt_eq(deferred_proof_index, current_public_values.prev_deferred_proof);
            deferred_proof_index = current_public_values.deferred_proof;

            // Assert that the start pc is equal to the current pc, then update.
            for (limb, current_limb) in pc.iter().zip_eq(current_public_values.pc_start.iter()) {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            pc = current_public_values.next_pc;

            // Verify that the timestamp is equal to the current one, then update.
            for (limb, current_limb) in
                current_timestamp.iter().zip_eq(current_public_values.initial_timestamp.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            current_timestamp = current_public_values.last_timestamp;

            // Verify that the init address is equal to the current one, then update.
            for (limb, current_limb) in
                init_addr.iter().zip_eq(current_public_values.previous_init_addr.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            init_addr = current_public_values.last_init_addr;

            // Verify that the finalize address is equal to the current one, then update.
            for (limb, current_limb) in
                finalize_addr.iter().zip_eq(current_public_values.previous_finalize_addr.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            finalize_addr = current_public_values.last_finalize_addr;

            // Verify that the init page index is equal to the current one, then update.
            for (limb, current_limb) in
                init_page_idx.iter().zip_eq(current_public_values.previous_init_page_idx.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            init_page_idx = current_public_values.last_init_page_idx;

            // Verify that the finalize page index is equal to the current one, then update.
            for (limb, current_limb) in finalize_page_idx
                .iter()
                .zip_eq(current_public_values.previous_finalize_page_idx.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            finalize_page_idx = current_public_values.last_finalize_page_idx;

            // Assert that the start deferred digest is equal to the current one, then update.
            for (digest, current_digest) in reconstruct_deferred_digest
                .iter()
                .zip_eq(current_public_values.start_reconstruct_deferred_digest.iter())
            {
                builder.assert_felt_eq(*digest, *current_digest);
            }
            reconstruct_deferred_digest = current_public_values.end_reconstruct_deferred_digest;

            // Assert that the `prev_exit_code` is equal to the current one, then update.
            builder.assert_felt_eq(current_exit_code, current_public_values.prev_exit_code);
            current_exit_code = current_public_values.exit_code;

            // Assert that the `prev_commit_syscall` is equal to the current one, then update.
            builder.assert_felt_eq(commit_syscall, current_public_values.prev_commit_syscall);
            commit_syscall = current_public_values.commit_syscall;

            // Assert that `prev_commit_deferred_syscall` is equal to the current one, then update.
            builder.assert_felt_eq(
                commit_deferred_syscall,
                current_public_values.prev_commit_deferred_syscall,
            );
            commit_deferred_syscall = current_public_values.commit_deferred_syscall;

            // Assert that the sp1_vk digest is always the same.
            for (digest, current) in
                sp1_vk_digest.iter().zip_eq(current_public_values.sp1_vk_digest)
            {
                builder.assert_felt_eq(*digest, current);
            }

            // Assert that the `proof_nonce` is always the same. Unlike the chained values above,
            // it is not updated: every proof is compared against the first proof's nonce.
            for (limb, current_limb) in
                proof_nonce.iter().zip_eq(current_public_values.proof_nonce.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
        }

        // Range check the accumulated number of included shards.
        C::range_check_felt(builder, num_included_shard, MAX_LOG_NUMBER_OF_SHARDS);

        // Check that the `contains_first_shard` flag is boolean.
        // Each individual flag was constrained to be boolean inside the loop, so requiring their
        // sum to be boolean as well means at most one input proof can have the flag set.
        builder.assert_felt_eq(
            contains_first_shard * (contains_first_shard - SP1Field::one()),
            SP1Field::zero(),
        );

        // Sum all the global cumulative sum of the proofs.
        let global_cumulative_sum = builder.sum_digest_v2(global_cumulative_sums);

        // Update the global values from the last accumulated values.
        // Set the `committed_value_digest`.
        compress_public_values.committed_value_digest = committed_value_digest;
        // Set the `deferred_proofs_digest`.
        compress_public_values.deferred_proofs_digest = deferred_proofs_digest;
        // Set next_pc to be the last pc.
        compress_public_values.next_pc = pc;
        // Set the timestamp to be the last timestamp.
        compress_public_values.last_timestamp = current_timestamp;
        // Set the MemoryInitialize address to be the last MemoryInitialize address.
        compress_public_values.last_init_addr = init_addr;
        // Set the MemoryFinalize address to be the last MemoryFinalize address.
        compress_public_values.last_finalize_addr = finalize_addr;
        // Set the PageProtInit address to be the last PageProtInit address.
        compress_public_values.last_init_page_idx = init_page_idx;
        // Set the PageProtFinalize address to be the last PageProtFinalize address.
        compress_public_values.last_finalize_page_idx = finalize_page_idx;
        // Set the end reconstruct deferred digest to be the last reconstruct deferred digest.
        compress_public_values.end_reconstruct_deferred_digest = reconstruct_deferred_digest;
        // Set the deferred proof index to be the last deferred proof index.
        compress_public_values.deferred_proof = deferred_proof_index;
        // Set sp1_vk digest to the one from the proof values.
        compress_public_values.sp1_vk_digest = sp1_vk_digest;
        // Reflect the vk root.
        compress_public_values.vk_root = vk_root;
        // Assign the cumulative sum.
        compress_public_values.global_cumulative_sum = global_cumulative_sum;
        // Assign the `contains_first_shard` flag.
        compress_public_values.contains_first_shard = contains_first_shard;
        // Assign the `num_included_shard` value.
        compress_public_values.num_included_shard = num_included_shard;
        // Assign the `is_complete` flag.
        compress_public_values.is_complete = is_complete;
        // Set the exit code.
        compress_public_values.exit_code = current_exit_code;
        // Set the `commit_syscall` flag.
        compress_public_values.commit_syscall = commit_syscall;
        // Set the `commit_deferred_syscall` flag.
        compress_public_values.commit_deferred_syscall = commit_deferred_syscall;
        // Set the proof nonce.
        compress_public_values.proof_nonce = proof_nonce;
        // Set the digest according to the previous values.
        // This is done last among the assignments, after every other field assigned in this
        // function has been written.
        compress_public_values.digest = match kind {
            PublicValuesOutputDigest::Reduce => {
                recursion_public_values_digest::<C, SP1GlobalContext>(
                    builder,
                    compress_public_values,
                )
            }
            PublicValuesOutputDigest::Root => {
                root_public_values_digest::<C, SP1GlobalContext>(builder, compress_public_values)
            }
        };

        // If the proof is complete, make completeness assertions.
        assert_complete(builder, compress_public_values, is_complete);

        // Commit the aggregated public values as the public values of this program.
        SP1GlobalContext::commit_recursion_public_values(builder, *compress_public_values);
    }
}