use std::array;
use std::borrow::{Borrow, BorrowMut};
use std::marker::PhantomData;
use crate::machine::utils::assert_complete;
use itertools::{izip, Itertools};
use p3_air::Air;
use p3_baby_bear::BabyBear;
use p3_commit::TwoAdicMultiplicativeCoset;
use p3_field::{AbstractField, PrimeField32, TwoAdicField};
use serde::{Deserialize, Serialize};
use sp1_core::air::{MachineAir, WORD_SIZE};
use sp1_core::air::{Word, POSEIDON_NUM_WORDS, PV_DIGEST_NUM_WORDS};
use sp1_core::stark::StarkMachine;
use sp1_core::stark::{Com, ShardProof, StarkGenericConfig, StarkVerifyingKey};
use sp1_core::utils::BabyBearPoseidon2;
use sp1_primitives::types::RecursionProgramType;
use sp1_recursion_compiler::config::InnerConfig;
use sp1_recursion_compiler::ir::{Array, Builder, Config, Felt, Var};
use sp1_recursion_compiler::prelude::DslVariable;
use sp1_recursion_core::air::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};
use sp1_recursion_core::runtime::{RecursionProgram, D, DIGEST_SIZE};
use sp1_recursion_compiler::prelude::*;
use crate::challenger::{CanObserveVariable, DuplexChallengerVariable};
use crate::fri::TwoAdicFriPcsVariable;
use crate::hints::Hintable;
use crate::stark::{RecursiveVerifierConstraintFolder, StarkVerifier};
use crate::types::ShardProofVariable;
use crate::types::VerifyingKeyVariable;
use crate::utils::{
assert_challenger_eq_pv, assign_challenger_from_pv, const_fri_config, felt2var,
get_challenger_public_values, hash_vkey,
};
use super::utils::{commit_public_values, proof_data_from_vk, verify_public_values_hash};
/// Builder/verifier for the SP1 "compress" (reduce) stage of recursion: it
/// folds a batch of recursive shard proofs into a single proof.
///
/// Carries no runtime state; the generics pin the DSL config `C`, the STARK
/// config `SC`, and the machine's AIR set `A` for the associated functions.
#[derive(Debug, Clone, Copy)]
pub struct SP1CompressVerifier<C: Config, SC: StarkGenericConfig, A> {
    // Zero-sized marker tying the unused generic parameters to the type.
    _phantom: PhantomData<(C, SC, A)>,
}
/// Identifies which verifying key a child proof in the compress tree was
/// produced against. The explicit discriminants are the numeric values
/// carried in the `kinds` witness array and matched in-circuit
/// (via `from_canonical_u32` in `SP1CompressVerifier::verify`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ReduceProgramType {
    /// Child proof of the core recursive program (verified with `recursive_vk`).
    Core = 0,
    /// Child proof of the deferred-proofs program (verified with `deferred_vk`).
    Deferred = 1,
    /// Child proof of a previous compress layer (verified with the witnessed
    /// `compress_vk` — self-recursion).
    Reduce = 2,
}
/// Host-side witness layout for one invocation of the compress program.
/// Serialized into the hint stream and read back in-circuit as
/// `SP1CompressMemoryLayoutVariable`.
pub struct SP1CompressMemoryLayout<'a, SC: StarkGenericConfig, A: MachineAir<SC::Val>> {
    /// Verifying key of the compress program itself; used to verify child
    /// proofs of kind `ReduceProgramType::Reduce`.
    pub compress_vk: &'a StarkVerifyingKey<SC>,
    /// The recursion machine the child proofs are verified against.
    pub recursive_machine: &'a StarkMachine<SC, A>,
    /// The child shard proofs to fold together.
    pub shard_proofs: Vec<ShardProof<SC>>,
    /// Whether this reduction completes the full proof; drives the
    /// `is_complete` public value and the completeness assertions in-circuit.
    pub is_complete: bool,
    /// One entry per proof in `shard_proofs`, selecting which vk that proof
    /// is verified with. Lengths are asserted equal in-circuit.
    pub kinds: Vec<ReduceProgramType>,
}
/// In-circuit counterpart of `SP1CompressMemoryLayout`, populated from the
/// witness (hint) stream.
///
/// NOTE(review): the `DslVariable` derive presumably makes field order
/// significant for witness reading — confirm against the host-side
/// `Hintable` write order before reordering fields.
#[derive(DslVariable, Clone)]
pub struct SP1CompressMemoryLayoutVariable<C: Config> {
    /// Verifying key of the compress program (for `Reduce`-kind children).
    pub compress_vk: VerifyingKeyVariable<C>,
    /// Child proofs to verify and fold.
    pub shard_proofs: Array<C, ShardProofVariable<C>>,
    /// Per-proof `ReduceProgramType` discriminant as a circuit variable.
    pub kinds: Array<C, Var<C::N>>,
    /// 0/1 flag: is this the final reduction? Asserted boolean in `verify`.
    pub is_complete: Var<C::N>,
}
impl<A> SP1CompressVerifier<InnerConfig, BabyBearPoseidon2, A>
where
    A: MachineAir<BabyBear> + for<'a> Air<RecursiveVerifierConstraintFolder<'a, InnerConfig>>,
{
    /// Builds the compress recursion program over BabyBear / Poseidon2.
    ///
    /// The emitted program witnesses an `SP1CompressMemoryLayout`, verifies
    /// every child proof against the vk selected by its kind (`recursive_vk`,
    /// `deferred_vk`, or the witnessed compress vk), folds the public values,
    /// and halts. The statement order below defines the program, so it must
    /// not be reordered.
    pub fn build(
        machine: &StarkMachine<BabyBearPoseidon2, A>,
        recursive_vk: &StarkVerifyingKey<BabyBearPoseidon2>,
        deferred_vk: &StarkVerifyingKey<BabyBearPoseidon2>,
    ) -> RecursionProgram<BabyBear> {
        let mut builder = Builder::<InnerConfig>::new(RecursionProgramType::Compress);

        // Allocate the input uninitialized, then fill it from the hint
        // stream; the read order is fixed by the layout's Hintable impl.
        let input: SP1CompressMemoryLayoutVariable<_> = builder.uninit();
        SP1CompressMemoryLayout::<BabyBearPoseidon2, A>::witness(&input, &mut builder);

        // PCS variable configured with the same FRI parameters as the host
        // machine, so in-circuit verification matches the prover's config.
        let pcs = TwoAdicFriPcsVariable {
            config: const_fri_config(&mut builder, machine.config().pcs().fri_config()),
        };
        SP1CompressVerifier::verify(
            &mut builder,
            &pcs,
            machine,
            input,
            recursive_vk,
            deferred_vk,
        );

        builder.halt();
        builder.compile_program()
    }
}
impl<C: Config, SC, A> SP1CompressVerifier<C, SC, A>
where
    C::F: PrimeField32 + TwoAdicField,
    SC: StarkGenericConfig<
        Val = C::F,
        Challenge = C::EF,
        Domain = TwoAdicMultiplicativeCoset<C::F>,
    >,
    A: MachineAir<C::F> + for<'a> Air<RecursiveVerifierConstraintFolder<'a, C>>,
    Com<SC>: Into<[SC::Val; DIGEST_SIZE]>,
{
    /// Verifies a batch of child recursion proofs and folds their public
    /// values into a single `RecursionPublicValues` commitment.
    ///
    /// For each proof `i`, the vk is chosen by `kinds[i]`:
    /// `Core` → `recursive_vk`, `Deferred` → `deferred_vk`,
    /// `Reduce` → the witnessed `compress_vk`; any other kind raises a
    /// circuit error. Proofs must chain: each proof's start state (pc, shard,
    /// execution shard, challengers, address bits, deferred digest, vk
    /// digest, committed digests) is asserted equal to the running state left
    /// by the previous proof, and the first proof (i == 0) initializes both
    /// the running state and the output's `start_*` fields. After the loop,
    /// the final running state becomes the output's `next_*` / `end_*`
    /// fields, which are committed as this program's public values.
    pub fn verify(
        builder: &mut Builder<C>,
        pcs: &TwoAdicFriPcsVariable<C>,
        machine: &StarkMachine<SC, A>,
        input: SP1CompressMemoryLayoutVariable<C>,
        recursive_vk: &StarkVerifyingKey<SC>,
        deferred_vk: &StarkVerifyingKey<SC>,
    ) {
        let SP1CompressMemoryLayoutVariable {
            compress_vk,
            shard_proofs,
            kinds,
            is_complete,
        } = input;

        // Backing storage for this program's own public values; the felt
        // slice is reinterpreted as a RecursionPublicValues struct (the slice
        // length RECURSIVE_PROOF_NUM_PV_ELTS must match the struct layout).
        let mut reduce_public_values_stream: Vec<Felt<_>> = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
            .map(|_| builder.uninit())
            .collect();
        let reduce_public_values: &mut RecursionPublicValues<_> =
            reduce_public_values_stream.as_mut_slice().borrow_mut();

        // Publish the digest of the witnessed compress vk so an outer layer
        // can check it verified `Reduce` children against the expected key.
        let compress_vk_digest = hash_vkey(builder, &compress_vk);
        reduce_public_values.compress_vk_digest =
            array::from_fn(|i| builder.get(&compress_vk_digest, i));

        // At least one proof, and exactly one kind per proof.
        builder.assert_usize_ne(shard_proofs.len(), 0);
        builder.assert_usize_eq(shard_proofs.len(), kinds.len());

        // Running state, initialized from the first proof inside the loop
        // (i == 0 branch) and advanced after each proof is verified.
        let sp1_vk_digest: [Felt<_>; DIGEST_SIZE] = array::from_fn(|_| builder.uninit());
        let pc: Felt<_> = builder.uninit();
        let shard: Felt<_> = builder.uninit();
        let execution_shard: Felt<_> = builder.uninit();
        let mut initial_reconstruct_challenger = DuplexChallengerVariable::new(builder);
        let mut reconstruct_challenger = DuplexChallengerVariable::new(builder);
        let mut leaf_challenger = DuplexChallengerVariable::new(builder);
        let committed_value_digest: [Word<Felt<_>>; PV_DIGEST_NUM_WORDS] =
            array::from_fn(|_| Word(array::from_fn(|_| builder.uninit())));
        let deferred_proofs_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
            array::from_fn(|_| builder.uninit());
        let reconstruct_deferred_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
            core::array::from_fn(|_| builder.uninit());
        // Accumulated permutation-argument sum over all child proofs;
        // starts at zero (unlike the uninit state above).
        let cumulative_sum: [Felt<_>; D] = core::array::from_fn(|_| builder.eval(C::F::zero()));
        let init_addr_bits: [Felt<_>; 32] = core::array::from_fn(|_| builder.uninit());
        let finalize_addr_bits: [Felt<_>; 32] = core::array::from_fn(|_| builder.uninit());

        // Circuit-constant vk data for the two fixed child programs, and the
        // field-element encodings of the three kind discriminants.
        let recursive_vk_variable = proof_data_from_vk(builder, recursive_vk, machine);
        let deferred_vk_variable = proof_data_from_vk(builder, deferred_vk, machine);
        let core_kind = C::N::from_canonical_u32(ReduceProgramType::Core as u32);
        let deferred_kind = C::N::from_canonical_u32(ReduceProgramType::Deferred as u32);
        let reduce_kind = C::N::from_canonical_u32(ReduceProgramType::Reduce as u32);

        builder.range(0, shard_proofs.len()).for_each(|i, builder| {
            let proof = builder.get(&shard_proofs, i);
            let kind = builder.get(&kinds, i);

            // Select the vk for this proof from its kind; an unrecognized
            // kind makes the program error out.
            let vk: VerifyingKeyVariable<_> = builder.uninit();
            builder.if_eq(kind, core_kind).then_or_else(
                |builder| {
                    builder.assign(vk.clone(), recursive_vk_variable.clone());
                },
                |builder| {
                    builder.if_eq(kind, deferred_kind).then_or_else(
                        |builder| {
                            builder.assign(vk.clone(), deferred_vk_variable.clone());
                        },
                        |builder| {
                            builder.if_eq(kind, reduce_kind).then_or_else(
                                |builder| {
                                    builder.assign(vk.clone(), compress_vk.clone());
                                },
                                |builder| {
                                    builder.error();
                                },
                            );
                        },
                    );
                },
            );

            // Fresh Fiat-Shamir challenger per proof, fed in the same order
            // the prover used: vk commitment, pc_start, main trace
            // commitment, then every public value element.
            let mut challenger = DuplexChallengerVariable::new(builder);
            challenger.observe(builder, vk.commitment.clone());
            challenger.observe(builder, vk.pc_start);
            challenger.observe(builder, proof.commitment.main_commit.clone());
            for j in 0..machine.num_pv_elts() {
                let element = builder.get(&proof.public_values, j);
                challenger.observe(builder, element);
            }

            // Full STARK verification of the child proof.
            StarkVerifier::<C, SC>::verify_shard(
                builder,
                &vk,
                pcs,
                machine,
                &mut challenger,
                &proof,
                true,
            );

            // Reinterpret the proof's public values as RecursionPublicValues
            // and check their internal hash.
            let current_public_values_elements = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
                .map(|i| builder.get(&proof.public_values, i))
                .collect::<Vec<Felt<_>>>();
            let current_public_values: &RecursionPublicValues<Felt<C::F>> =
                current_public_values_elements.as_slice().borrow();
            verify_public_values_hash(builder, current_public_values);

            // First proof: seed the running state and this program's
            // `start_*` public values from the child's start state.
            builder.if_eq(i, C::N::zero()).then(|builder| {
                // Deferred digest: seed both the running value and the
                // published start value.
                for (digest, current_digest, global_digest) in izip!(
                    reconstruct_deferred_digest.iter(),
                    current_public_values
                        .start_reconstruct_deferred_digest
                        .iter(),
                    reduce_public_values
                        .start_reconstruct_deferred_digest
                        .iter()
                ) {
                    builder.assign(*digest, *current_digest);
                    builder.assign(*global_digest, *current_digest);
                }

                // The sp1 vk digest all children must agree on.
                for (digest, first_digest) in sp1_vk_digest
                    .iter()
                    .zip(current_public_values.sp1_vk_digest)
                {
                    builder.assign(*digest, first_digest);
                }

                // pc / shard / execution shard: running value and start PV.
                builder.assign(
                    reduce_public_values.start_pc,
                    current_public_values.start_pc,
                );
                builder.assign(pc, current_public_values.start_pc);
                builder.assign(shard, current_public_values.start_shard);
                builder.assign(
                    reduce_public_values.start_shard,
                    current_public_values.start_shard,
                );
                builder.assign(execution_shard, current_public_values.start_execution_shard);
                builder.assign(
                    reduce_public_values.start_execution_shard,
                    current_public_values.start_execution_shard,
                );

                // Memory init/finalize address bits: running + start PV.
                for (bit, (first_bit, current_bit)) in init_addr_bits.iter().zip(
                    reduce_public_values
                        .previous_init_addr_bits
                        .iter()
                        .zip(current_public_values.previous_init_addr_bits.iter()),
                ) {
                    builder.assign(*bit, *current_bit);
                    builder.assign(*first_bit, *current_bit);
                }
                for (bit, (first_bit, current_bit)) in finalize_addr_bits.iter().zip(
                    reduce_public_values
                        .previous_finalize_addr_bits
                        .iter()
                        .zip(current_public_values.previous_finalize_addr_bits.iter()),
                ) {
                    builder.assign(*bit, *current_bit);
                    builder.assign(*first_bit, *current_bit);
                }

                // Challenger states: the shared leaf challenger and the
                // reconstruct challenger's initial/running copies.
                assign_challenger_from_pv(
                    builder,
                    &mut leaf_challenger,
                    current_public_values.leaf_challenger,
                );
                assign_challenger_from_pv(
                    builder,
                    &mut initial_reconstruct_challenger,
                    current_public_values.start_reconstruct_challenger,
                );
                assign_challenger_from_pv(
                    builder,
                    &mut reconstruct_challenger,
                    current_public_values.start_reconstruct_challenger,
                );

                // Seed the running committed-value / deferred-proofs digests.
                for (word, current_word) in committed_value_digest
                    .iter()
                    .zip_eq(current_public_values.committed_value_digest.iter())
                {
                    for (byte, current_byte) in word.0.iter().zip_eq(current_word.0.iter()) {
                        builder.assign(*byte, *current_byte);
                    }
                }
                for (digest, current_digest) in deferred_proofs_digest
                    .iter()
                    .zip_eq(current_public_values.deferred_proofs_digest.iter())
                {
                    builder.assign(*digest, *current_digest);
                }
            });

            // Chaining checks (every iteration; trivially true for i == 0
            // since the running state was just assigned): the child's start
            // state must equal the running state.
            for (digest, current_digest) in reconstruct_deferred_digest.iter().zip_eq(
                current_public_values
                    .start_reconstruct_deferred_digest
                    .iter(),
            ) {
                builder.assert_felt_eq(*digest, *current_digest);
            }
            for (digest, current) in sp1_vk_digest
                .iter()
                .zip(current_public_values.sp1_vk_digest)
            {
                builder.assert_felt_eq(*digest, current);
            }
            builder.assert_felt_eq(pc, current_public_values.start_pc);
            builder.assert_felt_eq(shard, current_public_values.start_shard);
            builder.assert_felt_eq(execution_shard, current_public_values.start_execution_shard);
            for (bit, current_bit) in init_addr_bits
                .iter()
                .zip(current_public_values.previous_init_addr_bits.iter())
            {
                builder.assert_felt_eq(*bit, *current_bit);
            }
            for (bit, current_bit) in finalize_addr_bits
                .iter()
                .zip(current_public_values.previous_finalize_addr_bits.iter())
            {
                builder.assert_felt_eq(*bit, *current_bit);
            }
            assert_challenger_eq_pv(
                builder,
                &leaf_challenger,
                current_public_values.leaf_challenger,
            );
            assert_challenger_eq_pv(
                builder,
                &reconstruct_challenger,
                current_public_values.start_reconstruct_challenger,
            );

            {
                // Committed-value digest: the running digest may transition
                // from all-zero (not yet set) to a concrete value, but once
                // nonzero it must match every subsequent proof's digest.
                // `is_zero` is 1 iff every byte of the running digest is 0.
                let is_zero: Var<_> = builder.eval(C::N::one());
                #[allow(clippy::needless_range_loop)]
                for i in 0..committed_value_digest.len() {
                    for j in 0..WORD_SIZE {
                        let d = felt2var(builder, committed_value_digest[i][j]);
                        builder.if_ne(d, C::N::zero()).then(|builder| {
                            builder.assign(is_zero, C::N::zero());
                        });
                    }
                }
                // Running digest already set (nonzero) → child must carry
                // the same digest.
                builder.if_eq(is_zero, C::N::zero()).then(|builder| {
                    #[allow(clippy::needless_range_loop)]
                    for i in 0..committed_value_digest.len() {
                        for j in 0..WORD_SIZE {
                            builder.assert_felt_eq(
                                committed_value_digest[i][j],
                                current_public_values.committed_value_digest[i][j],
                            );
                        }
                    }
                });
                // Unconditionally adopt the child's digest as the new
                // running value.
                #[allow(clippy::needless_range_loop)]
                for i in 0..committed_value_digest.len() {
                    for j in 0..WORD_SIZE {
                        builder.assign(
                            committed_value_digest[i][j],
                            current_public_values.committed_value_digest[i][j],
                        );
                    }
                }

                // Same zero-or-equal rule for the deferred-proofs digest.
                let is_zero: Var<_> = builder.eval(C::N::one());
                #[allow(clippy::needless_range_loop)]
                for i in 0..deferred_proofs_digest.len() {
                    let d = felt2var(builder, deferred_proofs_digest[i]);
                    builder.if_ne(d, C::N::zero()).then(|builder| {
                        builder.assign(is_zero, C::N::zero());
                    });
                }
                builder.if_eq(is_zero, C::N::zero()).then(|builder| {
                    #[allow(clippy::needless_range_loop)]
                    for i in 0..deferred_proofs_digest.len() {
                        builder.assert_felt_eq(
                            deferred_proofs_digest[i],
                            current_public_values.deferred_proofs_digest[i],
                        );
                    }
                });
                #[allow(clippy::needless_range_loop)]
                for i in 0..deferred_proofs_digest.len() {
                    builder.assign(
                        deferred_proofs_digest[i],
                        current_public_values.deferred_proofs_digest[i],
                    );
                }
            }

            // Advance the running state to the child's end state.
            for (digest, current_digest) in reconstruct_deferred_digest
                .iter()
                .zip_eq(current_public_values.end_reconstruct_deferred_digest.iter())
            {
                builder.assign(*digest, *current_digest);
            }
            builder.assign(pc, current_public_values.next_pc);
            builder.assign(shard, current_public_values.next_shard);
            builder.assign(execution_shard, current_public_values.next_execution_shard);
            for (bit, next_bit) in init_addr_bits
                .iter()
                .zip(current_public_values.last_init_addr_bits.iter())
            {
                builder.assign(*bit, *next_bit);
            }
            for (bit, next_bit) in finalize_addr_bits
                .iter()
                .zip(current_public_values.last_finalize_addr_bits.iter())
            {
                builder.assign(*bit, *next_bit);
            }
            assign_challenger_from_pv(
                builder,
                &mut reconstruct_challenger,
                current_public_values.end_reconstruct_challenger,
            );

            // Fold the child's cumulative sum into the running total.
            for (sum_element, current_sum_element) in cumulative_sum
                .iter()
                .zip_eq(current_public_values.cumulative_sum.iter())
            {
                builder.assign(*sum_element, *sum_element + *current_sum_element);
            }
        });

        // Publish the final running state as this program's public values.
        reduce_public_values.sp1_vk_digest = sp1_vk_digest;
        reduce_public_values.next_pc = pc;
        reduce_public_values.next_shard = shard;
        reduce_public_values.next_execution_shard = execution_shard;
        reduce_public_values.last_init_addr_bits = init_addr_bits;
        reduce_public_values.last_finalize_addr_bits = finalize_addr_bits;
        let values = get_challenger_public_values(builder, &leaf_challenger);
        reduce_public_values.leaf_challenger = values;
        let values = get_challenger_public_values(builder, &initial_reconstruct_challenger);
        reduce_public_values.start_reconstruct_challenger = values;
        let values = get_challenger_public_values(builder, &reconstruct_challenger);
        reduce_public_values.end_reconstruct_challenger = values;
        reduce_public_values.end_reconstruct_deferred_digest = reconstruct_deferred_digest;
        reduce_public_values.deferred_proofs_digest = deferred_proofs_digest;
        reduce_public_values.committed_value_digest = committed_value_digest;
        reduce_public_values.cumulative_sum = cumulative_sum;

        // `is_complete` must be boolean: if 1, run the completeness checks;
        // otherwise it must be exactly 0.
        builder.if_eq(is_complete, C::N::one()).then_or_else(
            |builder| {
                builder.assign(reduce_public_values.is_complete, C::F::one());
                assert_complete(builder, reduce_public_values, &reconstruct_challenger)
            },
            |builder| {
                builder.assert_var_eq(is_complete, C::N::zero());
                builder.assign(reduce_public_values.is_complete, C::F::zero());
            },
        );

        // Commit the folded public values as this proof's output.
        commit_public_values(builder, reduce_public_values);
    }
}