sp1_recursion_circuit_v2/
hash.rs1use std::iter::zip;
2
3use itertools::Itertools;
4use p3_baby_bear::BabyBear;
5use p3_field::{AbstractField, Field};
6
7use p3_bn254_fr::Bn254Fr;
8use sp1_recursion_compiler::{
9 circuit::CircuitV2Builder,
10 ir::{Builder, Config, DslIr, Felt, Var},
11};
12use sp1_recursion_core_v2::{stark::config::BabyBearPoseidon2Outer, DIGEST_SIZE};
13use sp1_stark::baby_bear_poseidon2::BabyBearPoseidon2;
14
15use crate::{
16 challenger::{reduce_32, RATE, SPONGE_SIZE},
17 select_chain, CircuitConfig,
18};
19
20pub trait FieldHasherVariable<C: CircuitConfig> {
21 type Digest: Clone + Copy;
22
23 fn hash(builder: &mut Builder<C>, input: &[Felt<C::F>]) -> Self::Digest;
24
25 fn compress(builder: &mut Builder<C>, input: [Self::Digest; 2]) -> Self::Digest;
26
27 fn assert_digest_eq(builder: &mut Builder<C>, a: Self::Digest, b: Self::Digest);
28
29 fn select_chain_digest(
31 builder: &mut Builder<C>,
32 should_swap: C::Bit,
33 input: [Self::Digest; 2],
34 ) -> [Self::Digest; 2];
35}
36
37impl<C: CircuitConfig<F = BabyBear, Bit = Felt<BabyBear>>> FieldHasherVariable<C>
38 for BabyBearPoseidon2
39{
40 type Digest = [Felt<BabyBear>; DIGEST_SIZE];
41
42 fn hash(builder: &mut Builder<C>, input: &[Felt<<C as Config>::F>]) -> Self::Digest {
43 builder.poseidon2_hash_v2(input)
44 }
45
46 fn compress(builder: &mut Builder<C>, input: [Self::Digest; 2]) -> Self::Digest {
47 builder.poseidon2_compress_v2(input.into_iter().flatten())
48 }
49
50 fn assert_digest_eq(builder: &mut Builder<C>, a: Self::Digest, b: Self::Digest) {
51 zip(a, b).for_each(|(e1, e2)| builder.assert_felt_eq(e1, e2));
52 }
53
54 fn select_chain_digest(
55 builder: &mut Builder<C>,
56 should_swap: <C as CircuitConfig>::Bit,
57 input: [Self::Digest; 2],
58 ) -> [Self::Digest; 2] {
59 let err_msg = "select_chain's return value should have length the sum of its inputs";
60 let mut selected = select_chain(builder, should_swap, input[0], input[1]);
61 let ret = [
62 core::array::from_fn(|_| selected.next().expect(err_msg)),
63 core::array::from_fn(|_| selected.next().expect(err_msg)),
64 ];
65 assert_eq!(selected.next(), None, "{}", err_msg);
66 ret
67 }
68}
69
70pub const BN254_DIGEST_SIZE: usize = 1;
71impl<C: CircuitConfig<F = BabyBear, N = Bn254Fr, Bit = Var<Bn254Fr>>> FieldHasherVariable<C>
72 for BabyBearPoseidon2Outer
73{
74 type Digest = [Var<Bn254Fr>; BN254_DIGEST_SIZE];
75
76 fn hash(builder: &mut Builder<C>, input: &[Felt<<C as Config>::F>]) -> Self::Digest {
77 assert!(C::N::bits() == p3_bn254_fr::Bn254Fr::bits());
78 assert!(C::F::bits() == p3_baby_bear::BabyBear::bits());
79 let num_f_elms = C::N::bits() / C::F::bits();
80 let mut state: [Var<C::N>; SPONGE_SIZE] =
81 [builder.eval(C::N::zero()), builder.eval(C::N::zero()), builder.eval(C::N::zero())];
82 for block_chunk in &input.iter().chunks(RATE) {
83 for (chunk_id, chunk) in (&block_chunk.chunks(num_f_elms)).into_iter().enumerate() {
84 let chunk = chunk.collect_vec().into_iter().copied().collect::<Vec<_>>();
85 state[chunk_id] = reduce_32(builder, chunk.as_slice());
86 }
87 builder.push(DslIr::CircuitPoseidon2Permute(state))
88 }
89
90 [state[0]; BN254_DIGEST_SIZE]
91 }
92
93 fn compress(builder: &mut Builder<C>, input: [Self::Digest; 2]) -> Self::Digest {
94 let state: [Var<C::N>; SPONGE_SIZE] =
95 [builder.eval(input[0][0]), builder.eval(input[1][0]), builder.eval(C::N::zero())];
96 builder.push(DslIr::CircuitPoseidon2Permute(state));
97 [state[0]; BN254_DIGEST_SIZE]
98 }
99
100 fn assert_digest_eq(builder: &mut Builder<C>, a: Self::Digest, b: Self::Digest) {
101 zip(a, b).for_each(|(e1, e2)| builder.assert_var_eq(e1, e2));
102 }
103
104 fn select_chain_digest(
105 builder: &mut Builder<C>,
106 should_swap: <C as CircuitConfig>::Bit,
107 input: [Self::Digest; 2],
108 ) -> [Self::Digest; 2] {
109 let result0: [Var<_>; 1] = core::array::from_fn(|j| {
110 let result = builder.uninit();
111 builder.push(DslIr::CircuitSelectV(should_swap, input[1][j], input[0][j], result));
112 result
113 });
114 let result1: [Var<_>; 1] = core::array::from_fn(|j| {
115 let result = builder.uninit();
116 builder.push(DslIr::CircuitSelectV(should_swap, input[0][j], input[1][j], result));
117 result
118 });
119
120 [result0, result1]
121 }
122}
123
124