sp1_recursion_circuit/machine/compress.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4    marker::PhantomData,
5    mem::MaybeUninit,
6};
7
8use itertools::Itertools;
9
10use slop_air::Air;
11
12use slop_algebra::{AbstractField, PrimeField32};
13use slop_challenger::IopCtx;
14
15use serde::{Deserialize, Serialize};
16use sp1_core_machine::riscv::MAX_LOG_NUMBER_OF_SHARDS;
17use sp1_recursion_compiler::ir::{Builder, Felt, IrIter};
18
19use sp1_primitives::{SP1Field, SP1GlobalContext};
20use sp1_recursion_executor::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};
21
22use sp1_hypercube::{
23    air::{MachineAir, ShardRange, POSEIDON_NUM_WORDS, PV_DIGEST_NUM_WORDS},
24    MachineVerifyingKey, ShardProof, DIGEST_SIZE,
25};
26
27use crate::{
28    challenger::CanObserveVariable,
29    machine::{
30        assert_complete, assert_recursion_public_values_valid, recursion_public_values_digest,
31        root_public_values_digest,
32    },
33    shard::{MachineVerifyingKeyVariable, RecursiveShardVerifier, ShardProofVariable},
34    zerocheck::RecursiveVerifierConstraintFolder,
35    CircuitConfig, SP1FieldConfigVariable,
36};
37
38use sp1_recursion_compiler::circuit::CircuitV2Builder;
39
40use super::InnerVal;
/// A program to verify a batch of recursive proofs and aggregate their public values.
///
/// This type carries no runtime data: the `PhantomData` only threads the circuit
/// configuration `C`, the field configuration `SC`, and the machine AIR `A` through
/// to the associated [`Self::verify`] method.
#[derive(Debug, Clone, Copy)]
pub struct SP1CompressVerifier<C, SC, A> {
    // Zero-sized marker binding the generic parameters to the type.
    _phantom: PhantomData<(C, SC, A)>,
}
46
/// Selects how the digest of the aggregated public values is computed at the end of
/// [`SP1CompressVerifier::verify`].
///
/// Derives are provided so the choice can be copied, compared, and logged freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PublicValuesOutputDigest {
    /// Compute the digest with `recursion_public_values_digest` (an intermediate
    /// reduce node in the recursion tree).
    Reduce,
    /// Compute the digest with `root_public_values_digest` (the root of the
    /// recursion tree).
    Root,
}
51
/// Witness layout for the compress stage verifier.
#[allow(clippy::type_complexity)]
pub struct SP1ShapedWitnessVariable<C: CircuitConfig, GC: SP1FieldConfigVariable<C>> {
    /// The shard proofs to verify.
    pub vks_and_proofs: Vec<(MachineVerifyingKeyVariable<C, GC>, ShardProofVariable<C, GC>)>,
    /// Completeness flag as an in-circuit field element; forwarded to the
    /// completeness assertions and committed as part of the output public values.
    pub is_complete: Felt<SP1Field>,
}
59
/// A machine verifying key paired with the shard proof to be checked against it.
pub type VkAndProof<GC, Proof> = (MachineVerifyingKey<GC>, ShardProof<GC, Proof>);
61
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "ShardProof<GC,Proof>: Serialize"))]
#[serde(bound(deserialize = "ShardProof<GC,Proof>: Deserialize<'de>"))]
/// An input layout for the shard proofs that have been normalized to a standard shape.
pub struct SP1ShapedWitnessValues<GC: IopCtx, Proof> {
    /// The verifying keys and shard proofs to be verified and aggregated.
    pub vks_and_proofs: Vec<VkAndProof<GC, Proof>>,
    /// Whether this batch is claimed to constitute a complete proof.
    // NOTE(review): presumably witnessed into `SP1ShapedWitnessVariable::is_complete` — confirm
    // against the witnessing code.
    pub is_complete: bool,
}
70
71impl<GC: IopCtx, Proof> SP1ShapedWitnessValues<GC, Proof> {
72    pub fn range(&self) -> ShardRange
73    where
74        GC::F: PrimeField32,
75    {
76        let start_pv: &RecursionPublicValues<GC::F> =
77            self.vks_and_proofs[0].1.public_values.as_slice().borrow();
78        let end_pv: &RecursionPublicValues<GC::F> =
79            self.vks_and_proofs[self.vks_and_proofs.len() - 1].1.public_values.as_slice().borrow();
80
81        let start = start_pv.range().start();
82        let end = end_pv.range().end();
83
84        (start..end).into()
85    }
86}
87
impl<C, SC, A> SP1CompressVerifier<C, SC, A>
where
    C: CircuitConfig<Bit = Felt<SP1Field>>,
    A: MachineAir<InnerVal> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
{
    /// Verify a batch of recursive proofs and aggregate their public values.
    ///
    /// The compression verifier can aggregate proofs of different kinds:
    /// - Core proofs: proofs which are recursive proof of a batch of SP1 shard proofs. The
    ///   implementation in this function assumes a fixed recursive verifier specified by
    ///   `recursive_vk`.
    /// - Deferred proofs: proofs which are recursive proof of a batch of deferred proofs. The
    ///   implementation in this function assumes a fixed deferred verification program specified by
    ///   `deferred_vk`.
    /// - Compress proofs: these are proofs which refer to a proof of this program. The key for it
    ///   is part of the public values and will be propagated across all levels of recursion and
    ///   will be checked against itself as in [sp1_prover::Prover] or as in
    ///   [super::SP1RootVerifier].
    ///
    /// Parameters:
    /// - `builder`: circuit builder used to emit all constraints.
    /// - `machine`: recursive shard verifier invoked on each witnessed `(vk, proof)` pair.
    /// - `input`: the witnessed proofs plus the `is_complete` flag.
    /// - `vk_root`: expected verifying-key root; asserted equal to the `vk_root` carried in
    ///   every proof's public values.
    /// - `kind`: selects which digest function is applied to the aggregated public values.
    pub fn verify(
        builder: &mut Builder<C>,
        machine: &RecursiveShardVerifier<SP1GlobalContext, A, C>,
        input: SP1ShapedWitnessVariable<C, SP1GlobalContext>,
        vk_root: [Felt<SP1Field>; DIGEST_SIZE],
        kind: PublicValuesOutputDigest,
    ) {
        // Read input.
        let SP1ShapedWitnessVariable { vks_and_proofs, is_complete } = input;

        // Initialize the values for the aggregated public output.
        //
        // NOTE(review): the `MaybeUninit::zeroed().assume_init()` felts below are placeholders;
        // every one of them is overwritten (in the `i == 0` branch of the aggregation loop, or by
        // the assignments after the loop) before it is read. This presumably relies on `Felt`
        // being valid when zero-initialized — confirm against the `Felt` definition.
        let mut reduce_public_values_stream: Vec<Felt<_>> = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
            .map(|_| unsafe { MaybeUninit::zeroed().assume_init() })
            .collect();
        let compress_public_values: &mut RecursionPublicValues<_> =
            reduce_public_values_stream.as_mut_slice().borrow_mut();

        // Make sure there is at least one proof.
        assert!(!vks_and_proofs.is_empty());

        // Initialize the consistency check variables.
        let mut sp1_vk_digest: [Felt<_>; DIGEST_SIZE] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut pc: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut current_exit_code: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        let mut current_timestamp: [Felt<_>; 4] = array::from_fn(|_| builder.uninit());

        let mut committed_value_digest: [[Felt<_>; 4]; PV_DIGEST_NUM_WORDS] =
            array::from_fn(|_| array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }));
        let mut deferred_proofs_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut deferred_proof_index: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        let mut reconstruct_deferred_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
            core::array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut global_cumulative_sums = Vec::new();
        let mut init_addr: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut finalize_addr: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut init_page_idx: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut finalize_page_idx: [Felt<_>; 3] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
        let mut commit_syscall: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        let mut commit_deferred_syscall: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
        // These two are genuine accumulators (summed across proofs), so they start at zero
        // via circuit evaluation rather than as placeholders.
        let mut contains_first_shard: Felt<_> = builder.eval(SP1Field::zero());
        let mut num_included_shard: Felt<_> = builder.eval(SP1Field::zero());
        let mut proof_nonce: [Felt<_>; 4] =
            array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });

        // Verify the shard proofs.
        // Verification of proofs can be done in parallel but the aggregation/consistency checks
        // must be done sequentially.
        vks_and_proofs.iter().ir_par_map_collect::<Vec<_>, _, _>(
            builder,
            |builder, (vk, shard_proof)| {
                // Prepare a challenger.
                let mut challenger = SP1GlobalContext::challenger_variable(builder);

                // Observe the vk and start pc.
                challenger.observe(builder, vk.preprocessed_commit);
                challenger.observe_slice(builder, vk.pc_start);
                challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.x.0);
                challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.y.0);
                challenger.observe(builder, vk.enable_untrusted_programs);
                // Observe the padding.
                // NOTE(review): six zero observations pad the transcript — presumably to match
                // the prover-side challenger's observation count; confirm against the prover.
                let zero: Felt<_> = builder.eval(SP1Field::zero());
                for _ in 0..6 {
                    challenger.observe(builder, zero);
                }
                // Verify the shard proof.
                machine.verify_shard(builder, vk, shard_proof, &mut challenger);
            },
        );

        // Check consistency and aggregate public values.
        for (i, (_, shard_proof)) in vks_and_proofs.into_iter().enumerate() {
            // Get the current public values.
            let current_public_values: &RecursionPublicValues<Felt<SP1Field>> =
                shard_proof.public_values.as_slice().borrow();
            // Assert that the public values are valid.
            assert_recursion_public_values_valid::<C, SP1GlobalContext>(
                builder,
                current_public_values,
            );
            // Assert that the vk root is the same as the witnessed one.
            for (expected, actual) in vk_root.iter().zip_eq(current_public_values.vk_root.iter()) {
                builder.assert_felt_eq(*expected, *actual);
            }

            // Verify that there are less than `(1 << MAX_LOG_NUMBER_OF_SHARDS)` included shards.
            C::range_check_felt(
                builder,
                current_public_values.num_included_shard,
                MAX_LOG_NUMBER_OF_SHARDS,
            );

            // Verify that `contains_first_shard` is boolean (x * (x - 1) == 0).
            builder.assert_felt_eq(
                current_public_values.contains_first_shard
                    * (current_public_values.contains_first_shard - SP1Field::one()),
                SP1Field::zero(),
            );

            // Accumulate the number of included shards.
            num_included_shard =
                builder.eval(num_included_shard + current_public_values.num_included_shard);

            // Accumulate the `contains_first_shard` flag.
            contains_first_shard =
                builder.eval(contains_first_shard + current_public_values.contains_first_shard);

            // Add the global cumulative sums to the vector.
            global_cumulative_sums.push(current_public_values.global_cumulative_sum);

            if i == 0 {
                // Initialize global and accumulated values from the first proof. Every
                // placeholder declared above is assigned here before its first read.

                // Assign the committed values and deferred proof digests.
                compress_public_values.prev_committed_value_digest =
                    current_public_values.prev_committed_value_digest;
                committed_value_digest = current_public_values.prev_committed_value_digest;

                compress_public_values.prev_deferred_proofs_digest =
                    current_public_values.prev_deferred_proofs_digest;
                deferred_proofs_digest = current_public_values.prev_deferred_proofs_digest;

                // Initialize the deferred proof index.
                compress_public_values.prev_deferred_proof =
                    current_public_values.prev_deferred_proof;
                deferred_proof_index = current_public_values.prev_deferred_proof;

                // Initialize start pc.
                compress_public_values.pc_start = current_public_values.pc_start;
                pc = current_public_values.pc_start;

                // Initialize timestamp.
                compress_public_values.initial_timestamp = current_public_values.initial_timestamp;
                current_timestamp = current_public_values.initial_timestamp;

                // Initialize the MemoryInitialize address.
                compress_public_values.previous_init_addr =
                    current_public_values.previous_init_addr;
                init_addr = current_public_values.previous_init_addr;

                // Initialize the MemoryFinalize address.
                compress_public_values.previous_finalize_addr =
                    current_public_values.previous_finalize_addr;
                finalize_addr = current_public_values.previous_finalize_addr;

                // Initialize the PageProtInit address.
                compress_public_values.previous_init_page_idx =
                    current_public_values.previous_init_page_idx;
                init_page_idx = current_public_values.previous_init_page_idx;

                // Initialize the PageProtFinalize address.
                compress_public_values.previous_finalize_page_idx =
                    current_public_values.previous_finalize_page_idx;
                finalize_page_idx = current_public_values.previous_finalize_page_idx;

                // Initialize the start of deferred digests.
                compress_public_values.start_reconstruct_deferred_digest =
                    current_public_values.start_reconstruct_deferred_digest;
                reconstruct_deferred_digest =
                    current_public_values.start_reconstruct_deferred_digest;

                // Initialize exit code.
                compress_public_values.prev_exit_code = current_public_values.prev_exit_code;
                current_exit_code = current_public_values.prev_exit_code;

                // Initialize `commit_syscall`.
                compress_public_values.prev_commit_syscall =
                    current_public_values.prev_commit_syscall;
                commit_syscall = current_public_values.prev_commit_syscall;

                // Initialize `commit_deferred_syscall`.
                compress_public_values.prev_commit_deferred_syscall =
                    current_public_values.prev_commit_deferred_syscall;
                commit_deferred_syscall = current_public_values.prev_commit_deferred_syscall;

                // Initialize the sp1_vk digest
                compress_public_values.sp1_vk_digest = current_public_values.sp1_vk_digest;
                sp1_vk_digest = current_public_values.sp1_vk_digest;

                // Initialize the proof nonce.
                compress_public_values.proof_nonce = current_public_values.proof_nonce;
                proof_nonce = current_public_values.proof_nonce;
            }

            // Assert that the current values match the accumulated values and update them.
            // Each check enforces that this proof's "previous" state equals the accumulated
            // "last" state of the proofs before it, i.e. the proofs chain contiguously.

            // Assert that the `prev_committed_value_digest` is equal to current one, then update.
            for (word, current_word) in committed_value_digest
                .iter()
                .zip_eq(current_public_values.prev_committed_value_digest.iter())
            {
                for (limb, current_limb) in word.iter().zip_eq(current_word.iter()) {
                    builder.assert_felt_eq(*limb, *current_limb);
                }
            }
            committed_value_digest = current_public_values.committed_value_digest;

            // Assert that the `prev_deferred_proofs_digest` is equal to current one, then update.
            for (limb, current_limb) in deferred_proofs_digest
                .iter()
                .zip_eq(current_public_values.prev_deferred_proofs_digest.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            deferred_proofs_digest = current_public_values.deferred_proofs_digest;

            // Assert that the `prev_deferred_proof` is equal to the current one, then update.
            builder.assert_felt_eq(deferred_proof_index, current_public_values.prev_deferred_proof);
            deferred_proof_index = current_public_values.deferred_proof;

            // Assert that the start pc is equal to the current pc, then update.
            for (limb, current_limb) in pc.iter().zip_eq(current_public_values.pc_start.iter()) {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            pc = current_public_values.next_pc;

            // Verify that the timestamp is equal to the current one, then update.
            for (limb, current_limb) in
                current_timestamp.iter().zip_eq(current_public_values.initial_timestamp.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            current_timestamp = current_public_values.last_timestamp;

            // Verify that the init address is equal to the current one, then update.
            for (limb, current_limb) in
                init_addr.iter().zip_eq(current_public_values.previous_init_addr.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            init_addr = current_public_values.last_init_addr;

            // Verify that the finalize address is equal to the current one, then update.
            for (limb, current_limb) in
                finalize_addr.iter().zip_eq(current_public_values.previous_finalize_addr.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            finalize_addr = current_public_values.last_finalize_addr;

            // Verify that the init page index is equal to the current one, then update.
            for (limb, current_limb) in
                init_page_idx.iter().zip_eq(current_public_values.previous_init_page_idx.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            init_page_idx = current_public_values.last_init_page_idx;

            // Verify that the finalize page index is equal to the current one, then update.
            for (limb, current_limb) in finalize_page_idx
                .iter()
                .zip_eq(current_public_values.previous_finalize_page_idx.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
            finalize_page_idx = current_public_values.last_finalize_page_idx;

            // Assert that the start deferred digest is equal to the current one, then update.
            for (digest, current_digest) in reconstruct_deferred_digest
                .iter()
                .zip_eq(current_public_values.start_reconstruct_deferred_digest.iter())
            {
                builder.assert_felt_eq(*digest, *current_digest);
            }
            reconstruct_deferred_digest = current_public_values.end_reconstruct_deferred_digest;

            // Assert that the `prev_exit_code` is equal to the current one, then update.
            builder.assert_felt_eq(current_exit_code, current_public_values.prev_exit_code);
            current_exit_code = current_public_values.exit_code;

            // Assert that the `prev_commit_syscall` is equal to the current one, then update.
            builder.assert_felt_eq(commit_syscall, current_public_values.prev_commit_syscall);
            commit_syscall = current_public_values.commit_syscall;

            // Assert that `prev_commit_deferred_syscall` is equal to the current one, then update.
            builder.assert_felt_eq(
                commit_deferred_syscall,
                current_public_values.prev_commit_deferred_syscall,
            );
            commit_deferred_syscall = current_public_values.commit_deferred_syscall;

            // Assert that the sp1_vk digest is always the same (no update: it is invariant
            // across all aggregated proofs).
            for (digest, current) in
                sp1_vk_digest.iter().zip_eq(current_public_values.sp1_vk_digest)
            {
                builder.assert_felt_eq(*digest, current);
            }

            // Assert that the `proof_nonce` is equal to the current one (likewise invariant
            // across all aggregated proofs, so there is nothing to update).
            for (limb, current_limb) in
                proof_nonce.iter().zip_eq(current_public_values.proof_nonce.iter())
            {
                builder.assert_felt_eq(*limb, *current_limb);
            }
        }

        // Range check the accumulated number of included shards.
        C::range_check_felt(builder, num_included_shard, MAX_LOG_NUMBER_OF_SHARDS);

        // Check that the `contains_first_shard` flag is boolean. Since each per-proof flag is
        // boolean and they are summed, this also enforces that at most one proof contains the
        // first shard.
        builder.assert_felt_eq(
            contains_first_shard * (contains_first_shard - SP1Field::one()),
            SP1Field::zero(),
        );

        // Sum all the global cumulative sum of the proofs.
        let global_cumulative_sum = builder.sum_digest_v2(global_cumulative_sums);

        // Update the global values from the last accumulated values.
        // Set the `committed_value_digest`.
        compress_public_values.committed_value_digest = committed_value_digest;
        // Set the `deferred_proofs_digest`.
        compress_public_values.deferred_proofs_digest = deferred_proofs_digest;
        // Set next_pc to be the last pc.
        compress_public_values.next_pc = pc;
        // Set the timestamp to be the last timestamp.
        compress_public_values.last_timestamp = current_timestamp;
        // Set the MemoryInitialize address to be the last MemoryInitialize address.
        compress_public_values.last_init_addr = init_addr;
        // Set the MemoryFinalize address to be the last MemoryFinalize address.
        compress_public_values.last_finalize_addr = finalize_addr;
        // Set the PageProtInit address to be the last PageProtInit address.
        compress_public_values.last_init_page_idx = init_page_idx;
        // Set the PageProtFinalize address to be the last PageProtFinalize address.
        compress_public_values.last_finalize_page_idx = finalize_page_idx;
        // Set the end reconstruct deferred digest to be the last reconstructed deferred digest.
        compress_public_values.end_reconstruct_deferred_digest = reconstruct_deferred_digest;
        // Set the deferred proof index to be the last deferred proof index.
        compress_public_values.deferred_proof = deferred_proof_index;
        // Set sp1_vk digest to the one from the proof values.
        compress_public_values.sp1_vk_digest = sp1_vk_digest;
        // Reflect the vk root.
        compress_public_values.vk_root = vk_root;
        // Assign the cumulative sum.
        compress_public_values.global_cumulative_sum = global_cumulative_sum;
        // Assign the `contains_first_shard` flag.
        compress_public_values.contains_first_shard = contains_first_shard;
        // Assign the `num_included_shard` value.
        compress_public_values.num_included_shard = num_included_shard;
        // Assign the `is_complete` flag.
        compress_public_values.is_complete = is_complete;
        // Set the exit code.
        compress_public_values.exit_code = current_exit_code;
        // Set the `commit_syscall` flag.
        compress_public_values.commit_syscall = commit_syscall;
        // Set the `commit_deferred_syscall` flag.
        compress_public_values.commit_deferred_syscall = commit_deferred_syscall;
        // Set the proof nonce carried through from the aggregated proofs.
        compress_public_values.proof_nonce = proof_nonce;
        // Set the digest according to the previous values.
        compress_public_values.digest = match kind {
            PublicValuesOutputDigest::Reduce => {
                recursion_public_values_digest::<C, SP1GlobalContext>(
                    builder,
                    compress_public_values,
                )
            }
            PublicValuesOutputDigest::Root => {
                root_public_values_digest::<C, SP1GlobalContext>(builder, compress_public_values)
            }
        };

        // If the proof is complete, make completeness assertions.
        assert_complete(builder, compress_public_values, is_complete);

        SP1GlobalContext::commit_recursion_public_values(builder, *compress_public_values);
    }
477}