// sp1_recursion_circuit_v2/hash.rs

1use std::iter::zip;
2
3use itertools::Itertools;
4use p3_baby_bear::BabyBear;
5use p3_field::{AbstractField, Field};
6
7use p3_bn254_fr::Bn254Fr;
8use sp1_recursion_compiler::{
9    circuit::CircuitV2Builder,
10    ir::{Builder, Config, DslIr, Felt, Var},
11};
12use sp1_recursion_core_v2::{stark::config::BabyBearPoseidon2Outer, DIGEST_SIZE};
13use sp1_stark::baby_bear_poseidon2::BabyBearPoseidon2;
14
15use crate::{
16    challenger::{reduce_32, RATE, SPONGE_SIZE},
17    select_chain, CircuitConfig,
18};
19
20pub trait FieldHasherVariable<C: CircuitConfig> {
21    type Digest: Clone + Copy;
22
23    fn hash(builder: &mut Builder<C>, input: &[Felt<C::F>]) -> Self::Digest;
24
25    fn compress(builder: &mut Builder<C>, input: [Self::Digest; 2]) -> Self::Digest;
26
27    fn assert_digest_eq(builder: &mut Builder<C>, a: Self::Digest, b: Self::Digest);
28
29    // Encountered many issues trying to make the following two parametrically polymorphic.
30    fn select_chain_digest(
31        builder: &mut Builder<C>,
32        should_swap: C::Bit,
33        input: [Self::Digest; 2],
34    ) -> [Self::Digest; 2];
35}
36
37impl<C: CircuitConfig<F = BabyBear, Bit = Felt<BabyBear>>> FieldHasherVariable<C>
38    for BabyBearPoseidon2
39{
40    type Digest = [Felt<BabyBear>; DIGEST_SIZE];
41
42    fn hash(builder: &mut Builder<C>, input: &[Felt<<C as Config>::F>]) -> Self::Digest {
43        builder.poseidon2_hash_v2(input)
44    }
45
46    fn compress(builder: &mut Builder<C>, input: [Self::Digest; 2]) -> Self::Digest {
47        builder.poseidon2_compress_v2(input.into_iter().flatten())
48    }
49
50    fn assert_digest_eq(builder: &mut Builder<C>, a: Self::Digest, b: Self::Digest) {
51        zip(a, b).for_each(|(e1, e2)| builder.assert_felt_eq(e1, e2));
52    }
53
54    fn select_chain_digest(
55        builder: &mut Builder<C>,
56        should_swap: <C as CircuitConfig>::Bit,
57        input: [Self::Digest; 2],
58    ) -> [Self::Digest; 2] {
59        let err_msg = "select_chain's return value should have length the sum of its inputs";
60        let mut selected = select_chain(builder, should_swap, input[0], input[1]);
61        let ret = [
62            core::array::from_fn(|_| selected.next().expect(err_msg)),
63            core::array::from_fn(|_| selected.next().expect(err_msg)),
64        ];
65        assert_eq!(selected.next(), None, "{}", err_msg);
66        ret
67    }
68}
69
70pub const BN254_DIGEST_SIZE: usize = 1;
71impl<C: CircuitConfig<F = BabyBear, N = Bn254Fr, Bit = Var<Bn254Fr>>> FieldHasherVariable<C>
72    for BabyBearPoseidon2Outer
73{
74    type Digest = [Var<Bn254Fr>; BN254_DIGEST_SIZE];
75
76    fn hash(builder: &mut Builder<C>, input: &[Felt<<C as Config>::F>]) -> Self::Digest {
77        assert!(C::N::bits() == p3_bn254_fr::Bn254Fr::bits());
78        assert!(C::F::bits() == p3_baby_bear::BabyBear::bits());
79        let num_f_elms = C::N::bits() / C::F::bits();
80        let mut state: [Var<C::N>; SPONGE_SIZE] =
81            [builder.eval(C::N::zero()), builder.eval(C::N::zero()), builder.eval(C::N::zero())];
82        for block_chunk in &input.iter().chunks(RATE) {
83            for (chunk_id, chunk) in (&block_chunk.chunks(num_f_elms)).into_iter().enumerate() {
84                let chunk = chunk.collect_vec().into_iter().copied().collect::<Vec<_>>();
85                state[chunk_id] = reduce_32(builder, chunk.as_slice());
86            }
87            builder.push(DslIr::CircuitPoseidon2Permute(state))
88        }
89
90        [state[0]; BN254_DIGEST_SIZE]
91    }
92
93    fn compress(builder: &mut Builder<C>, input: [Self::Digest; 2]) -> Self::Digest {
94        let state: [Var<C::N>; SPONGE_SIZE] =
95            [builder.eval(input[0][0]), builder.eval(input[1][0]), builder.eval(C::N::zero())];
96        builder.push(DslIr::CircuitPoseidon2Permute(state));
97        [state[0]; BN254_DIGEST_SIZE]
98    }
99
100    fn assert_digest_eq(builder: &mut Builder<C>, a: Self::Digest, b: Self::Digest) {
101        zip(a, b).for_each(|(e1, e2)| builder.assert_var_eq(e1, e2));
102    }
103
104    fn select_chain_digest(
105        builder: &mut Builder<C>,
106        should_swap: <C as CircuitConfig>::Bit,
107        input: [Self::Digest; 2],
108    ) -> [Self::Digest; 2] {
109        let result0: [Var<_>; 1] = core::array::from_fn(|j| {
110            let result = builder.uninit();
111            builder.push(DslIr::CircuitSelectV(should_swap, input[1][j], input[0][j], result));
112            result
113        });
114        let result1: [Var<_>; 1] = core::array::from_fn(|j| {
115            let result = builder.uninit();
116            builder.push(DslIr::CircuitSelectV(should_swap, input[0][j], input[1][j], result));
117            result
118        });
119
120        [result0, result1]
121    }
122}
123
// TODO: implement `FieldHasherVariable` for `OuterHash` (configs with `F = BabyBear`).
// impl<C: Config<F = BabyBear>> FieldHasherVariable<C> for OuterHash {
// }