use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear};
use p3_commit::{ExtensionMmcs, TwoAdicMultiplicativeCoset};
use p3_field::{extension::BinomialExtensionField, AbstractField, Field, TwoAdicField};
use p3_fri::FriConfig;
use p3_merkle_tree::FieldMerkleTreeMmcs;
use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
use sp1_recursion_compiler::{
asm::AsmConfig,
ir::{Array, Builder, Config, Felt, MemVariable, Var},
};
use sp1_recursion_core::{
air::ChallengerPublicValues,
runtime::{DIGEST_SIZE, PERMUTATION_WIDTH},
};
use sp1_stark::{
air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, Dom, ShardProof, StarkGenericConfig,
StarkMachine, StarkVerifyingKey,
};
use crate::{
challenger::DuplexChallengerVariable,
fri::{types::FriConfigVariable, TwoAdicMultiplicativeCosetVariable},
stark::EMPTY,
types::{QuotientDataValues, VerifyingKeyVariable},
};
// Concrete STARK configuration from which the `F` / `EF` aliases below are derived.
type SC = BabyBearPoseidon2;
// Value (base) field of `SC`.
type F = <SC as StarkGenericConfig>::Val;
// Challenge field of `SC`.
type EF = <SC as StarkGenericConfig>::Challenge;
// Recursion DSL configuration over `F` and `EF`; used as the bound in `clone` / `clone_array`.
// Note: the generic functions below declare their own `C: Config` parameter, which shadows
// this alias inside those functions.
type C = AsmConfig<F, EF>;
// Base field used by the FRI-related aliases below.
type Val = BabyBear;
// Degree-4 binomial extension of `Val`.
type Challenge = BinomialExtensionField<Val, 4>;
// Poseidon2 permutation over `Val`, instantiated with const parameters 16 and 7
// (presumably state width and S-box degree — confirm against `p3_poseidon2`).
type Perm = Poseidon2<Val, Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, 16, 7>;
// Sponge hash built on `Perm`.
type Hash = PaddingFreeSponge<Perm, 16, 8, 8>;
// Compression function built by truncating `Perm`.
type Compress = TruncatedPermutation<Perm, 2, 8, 16>;
// Merkle-tree commitment scheme over `Val`, using `Hash` and `Compress`.
type ValMmcs =
FieldMerkleTreeMmcs<<Val as Field>::Packing, <Val as Field>::Packing, Hash, Compress, 8>;
// Commitment scheme for `Challenge` elements, layered on top of `ValMmcs`.
type ChallengeMmcs = ExtensionMmcs<Val, Challenge, ValMmcs>;
// DSL configuration and builder types used by the non-generic helpers in this file.
type RecursionConfig = AsmConfig<Val, Challenge>;
type RecursionBuilder = Builder<RecursionConfig>;
/// Builds an in-circuit [`FriConfigVariable`] from a host-side [`FriConfig`].
///
/// For every `log_n` in `0..Val::TWO_ADICITY` this stores the two-adic generator of that order
/// in `generators` and the unshifted (shift = one) two-adic coset of that size in `subgroups`.
/// The scalar parameters of `config` are evaluated into builder variables.
pub fn const_fri_config(
    builder: &mut RecursionBuilder,
    config: &FriConfig<ChallengeMmcs>,
) -> FriConfigVariable<RecursionConfig> {
    let two_adicity = Val::TWO_ADICITY;

    let mut generators = builder.dyn_array(two_adicity);
    let mut subgroups = builder.dyn_array(two_adicity);
    for log_n in 0..two_adicity {
        // Generator of the two-adic subgroup with `log_n` bits.
        builder.set(&mut generators, log_n, Val::two_adic_generator(log_n));

        // The matching subgroup, expressed as a coset with trivial shift.
        let coset = TwoAdicMultiplicativeCoset { log_n, shift: Val::one() };
        let coset_var: TwoAdicMultiplicativeCosetVariable<_> = builder.constant(coset);
        builder.set(&mut subgroups, log_n, coset_var);
    }

    // Evaluated in the same order as the struct fields are declared below.
    let log_blowup = builder.eval(BabyBear::from_canonical_usize(config.log_blowup));
    let blowup = builder.eval(BabyBear::from_canonical_usize(1 << config.log_blowup));
    let num_queries = builder.eval(BabyBear::from_canonical_usize(config.num_queries));
    let proof_of_work_bits =
        builder.eval(BabyBear::from_canonical_usize(config.proof_of_work_bits));

    FriConfigVariable { log_blowup, blowup, num_queries, proof_of_work_bits, subgroups, generators }
}
/// Returns a fresh copy of `var` by writing it into a one-element dynamic array and reading
/// it back through the builder.
pub fn clone<T: MemVariable<C>>(builder: &mut RecursionBuilder, var: &T) -> T {
    let mut scratch = builder.dyn_array(1);
    let source = var.clone();
    builder.set(&mut scratch, 0, source);
    builder.get(&scratch, 0)
}
/// Allocates a new dynamic array of the same length as `arr` and copies every element into it,
/// one at a time, inside a builder range loop.
pub fn clone_array<T: MemVariable<C>>(
    builder: &mut RecursionBuilder,
    arr: &Array<C, T>,
) -> Array<C, T> {
    let mut copy = builder.dyn_array(arr.len());
    builder.range(0, arr.len()).for_each(|idx, builder| {
        let element = builder.get(arr, idx);
        builder.set(&mut copy, idx, element);
    });
    copy
}
/// Converts a felt into a var by decomposing it into bits (`num2bits_f`) and recomposing those
/// bits as a var (`bits2num_v`).
pub fn felt2var<C: Config>(builder: &mut Builder<C>, felt: Felt<C::F>) -> Var<C::N> {
    let felt_bits = builder.num2bits_f(felt);
    builder.bits2num_v(&felt_bits)
}
/// Converts a var into a felt by decomposing it into bits (`num2bits_v`) and recomposing those
/// bits as a felt (`bits2num_f`).
pub fn var2felt<C: Config>(builder: &mut Builder<C>, var: Var<C::N>) -> Felt<C::F> {
    let var_bits = builder.num2bits_v(var);
    builder.bits2num_f(&var_bits)
}
pub fn assert_challenger_eq_pv<C: Config>(
builder: &mut Builder<C>,
var: &DuplexChallengerVariable<C>,
values: ChallengerPublicValues<Felt<C::F>>,
) {
for i in 0..PERMUTATION_WIDTH {
let element = builder.get(&var.sponge_state, i);
builder.assert_felt_eq(element, values.sponge_state[i]);
}
let num_inputs_var = felt2var(builder, values.num_inputs);
builder.assert_var_eq(var.nb_inputs, num_inputs_var);
let mut input_buffer_array: Array<_, Felt<_>> = builder.dyn_array(PERMUTATION_WIDTH);
for i in 0..PERMUTATION_WIDTH {
builder.set(&mut input_buffer_array, i, values.input_buffer[i]);
}
builder.range(0, num_inputs_var).for_each(|i, builder| {
let element = builder.get(&var.input_buffer, i);
let values_element = builder.get(&input_buffer_array, i);
builder.assert_felt_eq(element, values_element);
});
let num_outputs_var = felt2var(builder, values.num_outputs);
builder.assert_var_eq(var.nb_outputs, num_outputs_var);
let mut output_buffer_array: Array<_, Felt<_>> = builder.dyn_array(PERMUTATION_WIDTH);
for i in 0..PERMUTATION_WIDTH {
builder.set(&mut output_buffer_array, i, values.output_buffer[i]);
}
builder.range(0, num_outputs_var).for_each(|i, builder| {
let element = builder.get(&var.output_buffer, i);
let values_element = builder.get(&output_buffer_array, i);
builder.assert_felt_eq(element, values_element);
});
}
pub fn assign_challenger_from_pv<C: Config>(
builder: &mut Builder<C>,
dst: &mut DuplexChallengerVariable<C>,
values: ChallengerPublicValues<Felt<C::F>>,
) {
for i in 0..PERMUTATION_WIDTH {
builder.set(&mut dst.sponge_state, i, values.sponge_state[i]);
}
let num_inputs_var = felt2var(builder, values.num_inputs);
builder.assign(dst.nb_inputs, num_inputs_var);
for i in 0..PERMUTATION_WIDTH {
builder.set(&mut dst.input_buffer, i, values.input_buffer[i]);
}
let num_outputs_var = felt2var(builder, values.num_outputs);
builder.assign(dst.nb_outputs, num_outputs_var);
for i in 0..PERMUTATION_WIDTH {
builder.set(&mut dst.output_buffer, i, values.output_buffer[i]);
}
}
/// Reads the in-circuit challenger `var` out into a [`ChallengerPublicValues`] of felts.
///
/// Buffer elements are fetched index by index through the builder; the `nb_inputs` and
/// `nb_outputs` vars are converted to felts with [`var2felt`].
pub fn get_challenger_public_values<C: Config>(
    builder: &mut Builder<C>,
    var: &DuplexChallengerVariable<C>,
) -> ChallengerPublicValues<Felt<C::F>> {
    // Field initialisers run top to bottom, so builder operations are emitted in this order.
    ChallengerPublicValues {
        sponge_state: core::array::from_fn(|idx| builder.get(&var.sponge_state, idx)),
        num_inputs: var2felt(builder, var.nb_inputs),
        input_buffer: core::array::from_fn(|idx| builder.get(&var.input_buffer, idx)),
        num_outputs: var2felt(builder, var.nb_outputs),
        output_buffer: core::array::from_fn(|idx| builder.get(&var.output_buffer, idx)),
    }
}
/// Hashes a verifying key with `poseidon2_hash`.
///
/// The hashed preimage is laid out as:
/// - slots `0..DIGEST_SIZE`: the elements of `vk.commitment`;
/// - slot `DIGEST_SIZE`: `vk.pc_start`;
/// - then four slots per preprocessed domain, in the order `log_n`, `size`, `shift`, `g`.
///
/// The four-slot group for the `i`-th entry of `vk.prep_domains` is placed at offset
/// `DIGEST_SIZE + 1 + 4 * vk.preprocessed_sorted_idxs[i]`.
pub fn hash_vkey<C: Config>(
    builder: &mut Builder<C>,
    vk: &VerifyingKeyVariable<C>,
) -> Array<C, Felt<C::F>> {
    // Size the preimage: fixed header (commitment + pc_start) plus four felts per domain.
    let domain_slots: Var<_> = builder.eval(vk.prep_domains.len() * 4);
    let header_slots: Var<_> = builder.constant(C::N::from_canonical_usize(DIGEST_SIZE + 1));
    let preimage_len: Var<_> = builder.eval(header_slots + domain_slots);
    let mut preimage = builder.dyn_array(preimage_len);

    // Header: commitment digest followed by the starting program counter.
    builder.range(0, DIGEST_SIZE).for_each(|idx, builder| {
        let digest_element = builder.get(&vk.commitment, idx);
        builder.set(&mut preimage, idx, digest_element);
    });
    builder.set(&mut preimage, DIGEST_SIZE, vk.pc_start);

    let four: Var<_> = builder.constant(C::N::from_canonical_usize(4));
    let one: Var<_> = builder.constant(C::N::one());

    // NOTE(review): the offset uses the sorted index rather than the loop counter, so this
    // presumably relies on `preprocessed_sorted_idxs` holding values below
    // `prep_domains.len()` — confirm against the caller.
    builder.range(0, vk.prep_domains.len()).for_each(|idx, builder| {
        let sorted_idx = builder.get(&vk.preprocessed_sorted_idxs, idx);
        let prep_domain = builder.get(&vk.prep_domains, idx);

        // Four consecutive slots for this domain.
        let log_n_slot: Var<_> = builder.eval(header_slots + sorted_idx * four);
        let size_slot: Var<_> = builder.eval(log_n_slot + one);
        let shift_slot: Var<_> = builder.eval(size_slot + one);
        let g_slot: Var<_> = builder.eval(shift_slot + one);

        // `log_n` and `size` are vars and must be converted before being stored as felts.
        let log_n = var2felt(builder, prep_domain.log_n);
        let size = var2felt(builder, prep_domain.size);

        builder.set(&mut preimage, log_n_slot, log_n);
        builder.set(&mut preimage, size_slot, size);
        builder.set(&mut preimage, shift_slot, prep_domain.shift);
        builder.set(&mut preimage, g_slot, prep_domain.g);
    });

    builder.poseidon2_hash(&preimage)
}
/// Returns the sorted chip indices reported by `machine.chips_sorted_indices(proof)`, with
/// every missing index (`None`) replaced by the `EMPTY` sentinel.
pub(crate) fn get_sorted_indices<SC: StarkGenericConfig, A: MachineAir<SC::Val>>(
    machine: &StarkMachine<SC, A>,
    proof: &ShardProof<SC>,
) -> Vec<usize> {
    machine
        .chips_sorted_indices(proof)
        .into_iter()
        // `unwrap_or` replaces the hand-written `match` on `Some`/`None`.
        .map(|sorted_index| sorted_index.unwrap_or(EMPTY))
        .collect()
}
/// Collects, for every preprocessed chip of `machine`, its position in `vk.chip_ordering` and
/// the domain stored at that position in `vk.chip_information`.
///
/// Returns the two lists in matching order: `(sorted indices, domains)`.
///
/// # Panics
///
/// Panics if a preprocessed chip's name is missing from `vk.chip_ordering`, or if the looked-up
/// index is out of bounds for `vk.chip_information`.
pub(crate) fn get_preprocessed_data<SC: StarkGenericConfig, A: MachineAir<SC::Val>>(
    machine: &StarkMachine<SC, A>,
    vk: &StarkVerifyingKey<SC>,
) -> (Vec<usize>, Vec<Dom<SC>>) {
    let chips = machine.chips();
    machine
        .preprocessed_chip_ids()
        .into_iter()
        .map(|chip_idx| {
            let chip_name = chips[chip_idx].name().clone();
            let sorted_idx = vk.chip_ordering[&chip_name];
            let domain = vk.chip_information[sorted_idx].1;
            (sorted_idx, domain)
        })
        .unzip()
}
/// Computes the quotient data for each chip of the shard, in the order given by
/// `proof.chip_ordering`.
///
/// For each chip this records its `log_quotient_degree` and the corresponding
/// `quotient_size = 1 << log_quotient_degree`.
pub(crate) fn get_chip_quotient_data<SC: StarkGenericConfig, A: MachineAir<SC::Val>>(
    machine: &StarkMachine<SC, A>,
    proof: &ShardProof<SC>,
) -> Vec<QuotientDataValues> {
    let mut quotient_data = Vec::new();
    for chip in machine.shard_chips_ordered(&proof.chip_ordering) {
        let log_quotient_degree = chip.log_quotient_degree();
        quotient_data.push(QuotientDataValues {
            log_quotient_degree,
            quotient_size: 1 << log_quotient_degree,
        });
    }
    quotient_data
}