Skip to main content

trident/field/
poseidon2.rs

//! Generic Poseidon2 hash function over any PrimeField.
//!
//! Implements the Poseidon2 permutation (Grassi et al., 2023) with a
//! configurable state width, rate, and round counts; the S-box is fixed to
//! x^7. The permutation and the sponge construction (absorb/squeeze) are
//! field-generic.
//!
//! Warriors call `poseidon2_hash::<Goldilocks>(...)` or
//! `poseidon2_hash::<BabyBear>(...)` — same code, different field.
9
10use super::PrimeField;
11
// ─── Poseidon2 Parameters ──────────────────────────────────────────
13
14/// Poseidon2 configuration for a specific field instantiation.
15pub struct Poseidon2Config<F: PrimeField> {
16    /// State width (typically 8 or 12).
17    pub width: usize,
18    /// Rate (number of input elements absorbed per permutation).
19    pub rate: usize,
20    /// Number of full rounds (split evenly: half before, half after partial).
21    pub rounds_f: usize,
22    /// Number of partial rounds.
23    pub rounds_p: usize,
24    /// Internal diagonal constants for the internal linear layer.
25    pub diag: Vec<F>,
26    /// Round constants (R_F * width + R_P elements).
27    pub round_constants: Vec<F>,
28}
29
30/// Default Poseidon2 config for Goldilocks (t=8, rate=4, RF=8, RP=22).
31pub fn goldilocks_config() -> Poseidon2Config<super::Goldilocks> {
32    use super::Goldilocks;
33
34    let width = 8;
35    let rate = 4;
36    let rounds_f = 8;
37    let rounds_p = 22;
38
39    let diag: Vec<Goldilocks> = [2u64, 3, 5, 9, 17, 33, 65, 129]
40        .iter()
41        .map(|&v| Goldilocks(v))
42        .collect();
43
44    let round_constants = generate_round_constants::<Goldilocks>(
45        width,
46        rounds_f,
47        rounds_p,
48        "Poseidon2-Goldilocks-t8-RF8-RP22",
49    );
50
51    Poseidon2Config {
52        width,
53        rate,
54        rounds_f,
55        rounds_p,
56        diag,
57        round_constants,
58    }
59}
60
61/// Generate round constants deterministically from BLAKE3.
62fn generate_round_constants<F: PrimeField>(
63    width: usize,
64    rounds_f: usize,
65    rounds_p: usize,
66    tag_prefix: &str,
67) -> Vec<F> {
68    let total_rounds = rounds_f + rounds_p;
69    let mut constants = Vec::new();
70    for r in 0..total_rounds {
71        let is_full = r < rounds_f / 2 || r >= rounds_f / 2 + rounds_p;
72        if is_full {
73            for e in 0..width {
74                let tag = format!("{}-{}-{}", tag_prefix, r, e);
75                let digest = blake3::hash(tag.as_bytes());
76                let bytes: [u8; 8] = digest.as_bytes()[..8].try_into().unwrap_or([0u8; 8]);
77                constants.push(F::from_u64(u64::from_le_bytes(bytes)));
78            }
79        } else {
80            let tag = format!("{}-{}-0", tag_prefix, r);
81            let digest = blake3::hash(tag.as_bytes());
82            let bytes: [u8; 8] = digest.as_bytes()[..8].try_into().unwrap_or([0u8; 8]);
83            constants.push(F::from_u64(u64::from_le_bytes(bytes)));
84        }
85    }
86    constants
87}
88
// ─── Cached Goldilocks Config ──────────────────────────────────────
90
91fn cached_goldilocks_config() -> &'static Poseidon2Config<super::Goldilocks> {
92    static CONFIG: std::sync::OnceLock<Poseidon2Config<super::Goldilocks>> =
93        std::sync::OnceLock::new();
94    CONFIG.get_or_init(goldilocks_config)
95}
96
// ─── Permutation ───────────────────────────────────────────────────
98
99/// Apply the Poseidon2 S-box (x^7) to a single field element.
100#[inline]
101fn sbox<F: PrimeField>(x: F) -> F {
102    let x2 = x.mul(x);
103    let x3 = x2.mul(x);
104    let x6 = x3.mul(x3);
105    x6.mul(x)
106}
107
108/// External linear layer: circ(2,1,...,1).
109/// new[i] = state[i] + sum(state).
110fn external_linear<F: PrimeField>(state: &mut [F]) {
111    let sum = state.iter().fold(F::ZERO, |a, &b| a.add(b));
112    for s in state.iter_mut() {
113        *s = s.add(sum);
114    }
115}
116
117/// Internal linear layer: diag(d_0,...,d_{w-1}) + ones_matrix.
118/// new[i] = d_i * state[i] + sum(state).
119fn internal_linear<F: PrimeField>(state: &mut [F], diag: &[F]) {
120    let sum = state.iter().fold(F::ZERO, |a, &b| a.add(b));
121    for (i, s) in state.iter_mut().enumerate() {
122        *s = diag[i].mul(*s).add(sum);
123    }
124}
125
126/// Full Poseidon2 permutation (in-place, generic over field and width).
127pub fn permutation<F: PrimeField>(state: &mut [F], config: &Poseidon2Config<F>) {
128    let mut ci = 0;
129    let width = config.width;
130
131    // First R_F/2 full rounds
132    for _ in 0..config.rounds_f / 2 {
133        for s in state[..width].iter_mut() {
134            *s = s.add(config.round_constants[ci]);
135            ci += 1;
136        }
137        for s in state[..width].iter_mut() {
138            *s = sbox(*s);
139        }
140        external_linear(&mut state[..width]);
141    }
142
143    // R_P partial rounds
144    for _ in 0..config.rounds_p {
145        state[0] = state[0].add(config.round_constants[ci]);
146        ci += 1;
147        state[0] = sbox(state[0]);
148        internal_linear(&mut state[..width], &config.diag);
149    }
150
151    // Last R_F/2 full rounds
152    for _ in 0..config.rounds_f / 2 {
153        for s in state[..width].iter_mut() {
154            *s = s.add(config.round_constants[ci]);
155            ci += 1;
156        }
157        for s in state[..width].iter_mut() {
158            *s = sbox(*s);
159        }
160        external_linear(&mut state[..width]);
161    }
162}
163
// ─── Sponge Hasher ─────────────────────────────────────────────────
165
166/// Absorb field elements, permute, squeeze — generic over PrimeField.
167fn sponge_hash<F: PrimeField>(
168    elements: &[F],
169    config: &Poseidon2Config<F>,
170    squeeze_count: usize,
171) -> Vec<F> {
172    let mut state = vec![F::ZERO; config.width];
173    let mut absorbed = 0;
174
175    for &elem in elements {
176        if absorbed == config.rate {
177            permutation(&mut state, config);
178            absorbed = 0;
179        }
180        state[absorbed] = state[absorbed].add(elem);
181        absorbed += 1;
182    }
183
184    // Squeeze
185    permutation(&mut state, config);
186    let mut out = Vec::with_capacity(squeeze_count);
187    let mut squeezed = 0;
188    loop {
189        for &elem in state[..config.rate].iter() {
190            out.push(elem);
191            squeezed += 1;
192            if squeezed == squeeze_count {
193                return out;
194            }
195        }
196        permutation(&mut state, config);
197    }
198}
199
// ─── Goldilocks Convenience Functions ──────────────────────────────
201
202/// Hash arbitrary bytes using Poseidon2 over Goldilocks, returning 32 bytes.
203///
204/// This is the drop-in replacement for `crate::package::poseidon2::hash_bytes`.
205pub fn hash_bytes_goldilocks(data: &[u8]) -> [u8; 32] {
206    use super::Goldilocks;
207
208    const BYTES_PER_ELEM: usize = 7;
209    let mut elements = Vec::with_capacity(data.len() / BYTES_PER_ELEM + 2);
210    for chunk in data.chunks(BYTES_PER_ELEM) {
211        let mut buf = [0u8; 8];
212        buf[..chunk.len()].copy_from_slice(chunk);
213        elements.push(Goldilocks::from_u64(u64::from_le_bytes(buf)));
214    }
215    // Length separator
216    elements.push(Goldilocks::from_u64(data.len() as u64));
217
218    let config = cached_goldilocks_config();
219    let result = sponge_hash(&elements, config, 4);
220
221    let mut out = [0u8; 32];
222    for (i, elem) in result.iter().enumerate() {
223        out[i * 8..i * 8 + 8].copy_from_slice(&elem.to_u64().to_le_bytes());
224    }
225    out
226}
227
228/// Hash Goldilocks field elements, returning 4 elements.
229pub fn hash_fields_goldilocks(elements: &[super::Goldilocks]) -> [super::Goldilocks; 4] {
230    let config = cached_goldilocks_config();
231    let result = sponge_hash(elements, config, 4);
232    [result[0], result[1], result[2], result[3]]
233}
234
// ─── Tests ─────────────────────────────────────────────────────────
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::Goldilocks;

    #[test]
    fn goldilocks_hash_deterministic() {
        let first = hash_bytes_goldilocks(b"hello world");
        let second = hash_bytes_goldilocks(b"hello world");
        assert_eq!(first, second);
    }

    #[test]
    fn goldilocks_hash_different_inputs() {
        let hello = hash_bytes_goldilocks(b"hello");
        let world = hash_bytes_goldilocks(b"world");
        assert_ne!(hello, world);
    }

    #[test]
    fn goldilocks_hash_fields_deterministic() {
        let elems: Vec<Goldilocks> = (1u64..=5).map(Goldilocks::from_u64).collect();
        let first = hash_fields_goldilocks(&elems);
        let second = hash_fields_goldilocks(&elems);
        assert_eq!(first, second);
    }

    #[test]
    fn goldilocks_collision_resistance() {
        // Hash the little-endian encodings of 0..20 and require pairwise distinct digests.
        let hashes: Vec<[u8; 32]> = (0u64..20)
            .map(|i| hash_bytes_goldilocks(&i.to_le_bytes()))
            .collect();
        for (i, a) in hashes.iter().enumerate() {
            for (j, b) in hashes.iter().enumerate().skip(i + 1) {
                assert_ne!(a, b, "collision between {} and {}", i, j);
            }
        }
    }

    #[test]
    fn permutation_diffusion() {
        // Flipping one input lane must change every output lane.
        let config = cached_goldilocks_config();
        let base: Vec<Goldilocks> = (100u64..108).map(Goldilocks::from_u64).collect();

        let mut untouched = base.clone();
        permutation(&mut untouched, config);

        let mut tweaked = base;
        tweaked[0] = tweaked[0].add(Goldilocks::ONE);
        permutation(&mut tweaked, config);

        for (i, (a, b)) in untouched.iter().zip(&tweaked).enumerate() {
            assert_ne!(a, b, "element {} unchanged after input tweak", i);
        }
    }
}