use std::iter::repeat;
use p3_field::{AbstractExtensionField, AbstractField};
use sp1_recursion_core::air::RecursionPublicValues;
use crate::prelude::*;
use sp1_recursion_core_v2::{chips::poseidon2_skinny::WIDTH, D, DIGEST_SIZE, HASH_RATE};
/// Helpers for building v2 recursion circuits on top of the DSL builder.
///
/// In the `Builder` implementation, most methods allocate fresh variables and
/// record `DslIr` operations. Some also add equality constraints tying the
/// hinted/op outputs back to their inputs.
pub trait CircuitV2Builder<C: Config> {
    // ----------------------------------------------------------------------
    // Bit manipulation
    // ----------------------------------------------------------------------

    /// Recombines a little-endian sequence of bit felts into one felt, with
    /// bit `i` weighted by `2^i`.
    fn bits2num_v2_f(
        &mut self,
        bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
    ) -> Felt<C::F>;

    /// Hints `num_bits` little-endian bits of `num`. Constrains each to be
    /// boolean and their weighted sum to equal `num`.
    fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>>;

    /// Records a `CircuitV2ExpReverseBits` operation over `input` and the
    /// exponent bits `power_bits`, returning the operation's output variable.
    fn exp_reverse_bits_v2(
        &mut self,
        input: Felt<C::F>,
        power_bits: Vec<Felt<C::F>>,
    ) -> Felt<C::F>;

    // ----------------------------------------------------------------------
    // Poseidon2
    // ----------------------------------------------------------------------

    /// Applies one Poseidon2 permutation to a full `WIDTH`-element state.
    fn poseidon2_permute_v2(&mut self, state: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH];

    /// Sponge-hashes an arbitrary-length slice of felts into a
    /// `DIGEST_SIZE`-element digest.
    fn poseidon2_hash_v2(&mut self, array: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE];

    /// Compresses up to `WIDTH` felts into a digest with a single permutation.
    fn poseidon2_compress_v2(
        &mut self,
        input: impl IntoIterator<Item = Felt<C::F>>,
    ) -> [Felt<C::F>; DIGEST_SIZE];

    // ----------------------------------------------------------------------
    // FRI and field conversion
    // ----------------------------------------------------------------------

    /// Records a FRI fold operation and returns freshly allocated outputs.
    fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C>;

    /// Splits an extension-field element into its `D` base-field coefficients
    /// and constrains that they recombine to `ext`.
    fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D];

    // ----------------------------------------------------------------------
    // Public values and instrumentation
    // ----------------------------------------------------------------------

    /// Records a commitment of the recursion public values.
    fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>);

    /// Opens a named cycle-tracker span.
    fn cycle_tracker_v2_enter(&mut self, name: String);

    /// Closes the most recently opened cycle-tracker span.
    fn cycle_tracker_v2_exit(&mut self);

    // ----------------------------------------------------------------------
    // Hints
    // ----------------------------------------------------------------------

    /// Reads a single hinted base-field element.
    fn hint_felt_v2(&mut self) -> Felt<C::F>;

    /// Reads `len` hinted base-field elements.
    fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>>;

    /// Reads a single hinted extension-field element.
    fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF>;

    /// Reads `len` hinted extension-field elements.
    fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>>;
}
impl<C: Config> CircuitV2Builder<C> for Builder<C> {
    /// Recombines little-endian `bits` into `sum_i bits[i] * 2^i`.
    ///
    /// The weight `2^i` is kept as a running field element, doubled after each
    /// bit. The alternative, `from_wrapped_u32(1 << i)`, shifts a `u32`. That
    /// overflows once `i >= 32`: it panics in debug builds and produces a
    /// wrong (masked) weight in release builds.
    fn bits2num_v2_f(
        &mut self,
        bits: impl IntoIterator<Item = Felt<<C as Config>::F>>,
    ) -> Felt<<C as Config>::F> {
        let mut num: Felt<_> = self.eval(C::F::zero());
        let mut power = C::F::one();
        for bit in bits {
            num = self.eval(num + bit * power);
            power = power + power;
        }
        num
    }

    /// Hints `num_bits` little-endian bits of `num` and constrains them.
    ///
    /// Each hinted bit `b` is constrained by `b * (b - 1) == 0` (boolean). The
    /// weighted sum `sum_i b_i * 2^i` is constrained to equal `num`. If
    /// `2^num_bits` exceeds the field modulus, more than one bit pattern can
    /// satisfy these constraints for the same `num`.
    fn num2bits_v2_f(&mut self, num: Felt<C::F>, num_bits: usize) -> Vec<Felt<C::F>> {
        let output: Vec<Felt<C::F>> = (0..num_bits).map(|_| self.uninit()).collect();
        self.push(DslIr::CircuitV2HintBitsF(output.clone(), num));

        // Running weight 2^i. This avoids the `u32` shift overflow described
        // on `bits2num_v2_f`.
        let mut power = C::F::one();
        let x: SymbolicFelt<_> = output
            .iter()
            .map(|&bit| {
                self.assert_felt_eq(bit * (bit - C::F::one()), C::F::zero());
                let term = bit * power;
                power = power + power;
                term
            })
            .sum();
        self.assert_felt_eq(x, num);
        output
    }

    /// Records a `CircuitV2ExpReverseBits` operation on `input` with exponent
    /// bits `power_bits` and returns its (freshly allocated) output variable.
    fn exp_reverse_bits_v2(
        &mut self,
        input: Felt<C::F>,
        power_bits: Vec<Felt<C::F>>,
    ) -> Felt<C::F> {
        let output: Felt<_> = self.uninit();
        self.operations.push(DslIr::CircuitV2ExpReverseBits(output, input, power_bits));
        output
    }

    /// Records one Poseidon2 permutation of `array` and returns the freshly
    /// allocated output state.
    fn poseidon2_permute_v2(&mut self, array: [Felt<C::F>; WIDTH]) -> [Felt<C::F>; WIDTH] {
        let output: [Felt<C::F>; WIDTH] = core::array::from_fn(|_| self.uninit());
        self.operations.push(DslIr::CircuitV2Poseidon2PermuteBabyBear(Box::new((output, array))));
        output
    }

    /// Hashes `input` with a padding-free, overwrite-mode sponge built on
    /// `poseidon2_permute_v2`.
    ///
    /// - The state starts as `WIDTH` zero felts.
    /// - Each `HASH_RATE`-sized chunk overwrites the leading state elements;
    ///   a shorter final chunk overwrites only its own length.
    /// - Each chunk is followed by one permutation.
    /// - The digest is the first `DIGEST_SIZE` state elements.
    /// - An empty `input` performs no permutation and yields the zero state
    ///   prefix.
    fn poseidon2_hash_v2(&mut self, input: &[Felt<C::F>]) -> [Felt<C::F>; DIGEST_SIZE] {
        let mut state = core::array::from_fn(|_| self.eval(C::F::zero()));
        for input_chunk in input.chunks(HASH_RATE) {
            state[..input_chunk.len()].copy_from_slice(input_chunk);
            state = self.poseidon2_permute_v2(state);
        }
        let state: [Felt<C::F>; DIGEST_SIZE] = state[..DIGEST_SIZE].try_into().unwrap();
        state
    }

    /// Compresses `input` into a digest with a single permutation.
    ///
    /// `input` is right-padded to `WIDTH` elements with a felt evaluated from
    /// `C::F::default()`. Elements beyond the first `WIDTH` are never read. The
    /// digest is the first `DIGEST_SIZE` elements of the permuted state.
    fn poseidon2_compress_v2(
        &mut self,
        input: impl IntoIterator<Item = Felt<C::F>>,
    ) -> [Felt<C::F>; DIGEST_SIZE] {
        let mut pre_iter = input.into_iter().chain(repeat(self.eval(C::F::default())));
        let pre = core::array::from_fn(move |_| pre_iter.next().unwrap());
        let post = self.poseidon2_permute_v2(pre);
        let post: [Felt<C::F>; DIGEST_SIZE] = post[..DIGEST_SIZE].try_into().unwrap();
        post
    }

    /// Records a `CircuitV2FriFold` operation.
    ///
    /// Allocates one output variable per element of `input.alpha_pow_input`
    /// and `input.ro_input`.
    fn fri_fold_v2(&mut self, input: CircuitV2FriFoldInput<C>) -> CircuitV2FriFoldOutput<C> {
        let mut uninit_vec = |len: usize| (0..len).map(|_| self.uninit()).collect();
        let output = CircuitV2FriFoldOutput {
            alpha_pow_output: uninit_vec(input.alpha_pow_input.len()),
            ro_output: uninit_vec(input.ro_input.len()),
        };
        self.operations.push(DslIr::CircuitV2FriFold(Box::new((output.clone(), input))));
        output
    }

    /// Splits `ext` into its `D` base-field coefficients.
    ///
    /// The coefficients are produced by a `CircuitExt2Felt` operation. They are
    /// then constrained so that `sum_i felts[i] * X^i == ext`. The loop runs
    /// over every one of the `D` returned felts, rather than a hard-coded 4,
    /// so no coefficient is left unconstrained.
    fn ext2felt_v2(&mut self, ext: Ext<C::F, C::EF>) -> [Felt<C::F>; D] {
        let felts: [Felt<C::F>; D] = core::array::from_fn(|_| self.uninit());
        self.operations.push(DslIr::CircuitExt2Felt(felts, ext));
        let mut reconstructed_ext: Ext<C::F, C::EF> = self.constant(C::EF::zero());
        for (i, &felt) in felts.iter().enumerate() {
            let monomial: Ext<C::F, C::EF> = self.constant(C::EF::monomial(i));
            reconstructed_ext = self.eval(reconstructed_ext + monomial * felt);
        }
        self.assert_ext_eq(reconstructed_ext, ext);
        felts
    }

    /// Records a `CircuitV2CommitPublicValues` operation for `public_values`.
    fn commit_public_values_v2(&mut self, public_values: RecursionPublicValues<Felt<C::F>>) {
        self.operations.push(DslIr::CircuitV2CommitPublicValues(Box::new(public_values)));
    }

    /// Records a `CycleTrackerV2Enter` operation labelled `name`.
    fn cycle_tracker_v2_enter(&mut self, name: String) {
        self.operations.push(DslIr::CycleTrackerV2Enter(name));
    }

    /// Records a `CycleTrackerV2Exit` operation.
    fn cycle_tracker_v2_exit(&mut self) {
        self.operations.push(DslIr::CycleTrackerV2Exit);
    }

    /// Hints a single felt (delegates to `hint_felts_v2(1)`).
    fn hint_felt_v2(&mut self) -> Felt<C::F> {
        self.hint_felts_v2(1)[0]
    }

    /// Hints a single ext (delegates to `hint_exts_v2(1)`).
    fn hint_ext_v2(&mut self) -> Ext<C::F, C::EF> {
        self.hint_exts_v2(1)[0]
    }

    /// Allocates `len` felts and records a `CircuitV2HintFelts` operation
    /// filling them.
    fn hint_felts_v2(&mut self, len: usize) -> Vec<Felt<C::F>> {
        let arr: Vec<Felt<C::F>> = (0..len).map(|_| self.uninit()).collect();
        self.operations.push(DslIr::CircuitV2HintFelts(arr.clone()));
        arr
    }

    /// Allocates `len` exts and records a `CircuitV2HintExts` operation
    /// filling them.
    fn hint_exts_v2(&mut self, len: usize) -> Vec<Ext<C::F, C::EF>> {
        let arr: Vec<Ext<C::F, C::EF>> = (0..len).map(|_| self.uninit()).collect();
        self.operations.push(DslIr::CircuitV2HintExts(arr.clone()));
        arr
    }
}