Skip to main content

spongefish/drivers/
ark_ff_impl.rs

//! Helpers for bridging `ark_ff` field types with `spongefish` codecs.
2use alloc::{vec, vec::Vec};
3use core::marker::PhantomData;
4
5use ark_ff::{BigInteger, Field, Fp, FpConfig, PrimeField, SmallFp, SmallFpConfig};
6
7use crate::{
8    codecs::{Decoding, Encoding},
9    error::VerificationError,
10    io::NargDeserialize,
11    VerificationResult,
12};
13
14fn parse_canonical_prime_field<F: PrimeField>(bytes: &[u8]) -> Option<F> {
15    // A canonical encoding of an element of [0, p) fits in ⌈MODULUS_BIT_SIZE/8⌉ bytes.
16    // Reject any longer input up front, before allocating.
17    if bytes.len() > (F::MODULUS_BIT_SIZE as usize).div_ceil(8) {
18        return None;
19    }
20    let bits = bytes
21        .iter()
22        .flat_map(|byte| (0..8).rev().map(move |shift| (byte >> shift) & 1 == 1))
23        .collect::<Vec<_>>();
24    let bigint = F::BigInt::from_bits_be(&bits);
25    F::from_bigint(bigint)
26}
27
28// Make arkworks field elements a valid Unit type
29impl<C: ark_ff::FpConfig<N>, const N: usize> crate::Unit for Fp<C, N> {
30    const ZERO: Self = C::ZERO;
31}
32
33// Make SmallFp field elements a valid Unit type
34impl<P: SmallFpConfig> crate::Unit for SmallFp<P> {
35    const ZERO: Self = P::ZERO;
36}
37
38/// A buffer meant to hold enough bytes for obtaining a uniformly-distributed
39/// random field element.
40/// In practice, for [`DecodingFieldBuffer`] is meant to hold `F::MODULUS_BIT_SIZE.div_ceil(8) + 32`
41/// bytes. Unfortunately Rust does not support const generic expressions,
42/// and so [`DecodingFieldBuffer`] is implemented as a vector of [`u8`] with a [`PhantomData`]
43/// marker binding it to the [`ark_ff::Field`].
44pub struct DecodingFieldBuffer<F: Field> {
45    buf: Vec<u8>,
46    _phantom: PhantomData<F>,
47}
48
49/// The function determining the size of [`DecodingFieldBuffer`]:
50pub fn decoding_field_buffer_size<F: Field>() -> usize {
51    let base_field_modulus_bytes = u64::from(F::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8));
52    // Get 32 bytes of extra randomness for every base field element in the extension
53    let length = (base_field_modulus_bytes + 32) * F::extension_degree();
54    length as usize
55}
56
/// Implements [`NargDeserialize`] for an `ark_ff` field type.
///
/// The accepted wire format is the one produced by `impl_encoding!`. It consists of the
/// `extension_degree()` base-prime-field coefficients, in the order expected by
/// `from_base_prime_field_elems`. Each coefficient is a big-endian integer occupying exactly
/// `⌈MODULUS_BIT_SIZE / 8⌉` bytes of the base prime field.
///
/// The bytes are parsed here directly rather than through arkworks' `CanonicalDeserialize`.
/// One reason is that arkworks would also try to read an 8-byte length prefix for slices,
/// vectors, and fixed-length arrays. Parsing is strict: a coefficient that is not a
/// canonical representative in `[0, p)` is rejected. On any error the input slice is left
/// untouched; on success it is advanced past exactly one element.
macro_rules! impl_deserialize {
    (impl [$($generics:tt)*] for $type:ty) => {
        impl<$($generics)*> NargDeserialize for $type {
            fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
                let extension_degree = <Self as Field>::extension_degree() as usize;
                let base_field_size =
                    (<Self as Field>::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8)) as usize;
                let total_bytes = extension_degree * base_field_size;
                // Reject short input before touching `buf`, so nothing is consumed on error.
                let encoded = buf.get(..total_bytes).ok_or(VerificationError)?;

                // `chunks_exact` yields exactly `extension_degree` coefficients; any
                // non-canonical one short-circuits the whole parse.
                let value = encoded
                    .chunks_exact(base_field_size)
                    .map(parse_canonical_prime_field::<<Self as Field>::BasePrimeField>)
                    .collect::<Option<Vec<_>>>()
                    .and_then(Self::from_base_prime_field_elems)
                    .ok_or(VerificationError)?;
                // Advance the cursor only once the whole element parsed successfully.
                *buf = &buf[total_bytes..];
                Ok(value)
            }
        }
    };
}
90
/// Implements [`Encoding<[u8]>`](Encoding) for an `ark_ff` field type.
///
/// Each of the `extension_degree()` base-prime-field coefficients, in the order returned by
/// `to_base_prime_field_elements`, is written as a big-endian integer. Each one is
/// left-padded with zeros to exactly `⌈MODULUS_BIT_SIZE / 8⌉` bytes of the base prime field.
///
/// Every element therefore has a fixed-length encoding with no length prefix, unlike
/// arkworks' `CanonicalSerialize` on slices, vectors, and fixed-length arrays. This is the
/// exact format accepted by `impl_deserialize!`. The `NargSerialize` implementation of the
/// type is inherited from this `Encoding` implementation.
macro_rules! impl_encoding {
    (impl [$($generics:tt)*] for $type:ty) => {
        impl<$($generics)*> Encoding<[u8]> for $type {
            fn encode(&self) -> impl AsRef<[u8]> {
                let base_field_size =
                    (<Self as Field>::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8)) as usize;
                let extension_degree = <Self as Field>::extension_degree() as usize;
                let mut buf = Vec::with_capacity(base_field_size * extension_degree);
                for base_element in self.to_base_prime_field_elements() {
                    let bytes = base_element.into_bigint().to_bytes_be();
                    // The backing BigInt may be wider than the field (e.g. F16 inside
                    // SmallFp's BigInt<1>). The dropped leading bytes are zero because the
                    // canonical value is below the modulus.
                    let significant = &bytes[bytes.len().saturating_sub(base_field_size)..];
                    // Left-pad when the BigInt is narrower than the field (defensive).
                    buf.extend(core::iter::repeat_n(0u8, base_field_size - significant.len()));
                    buf.extend_from_slice(significant);
                }
                buf
            }
        }
    };
}
118
/// Implements [`Decoding<[u8]>`](Decoding) for an `ark_ff` field type.
///
/// The input is a [`DecodingFieldBuffer`], split into `extension_degree()` equal chunks of
/// `⌈MODULUS_BIT_SIZE / 8⌉ + 32` bytes each. Each chunk is read as a big-endian integer and
/// reduced modulo the base-field prime. The resulting coefficients are then assembled with
/// `from_base_prime_field_elems`.
///
/// Remember that the Rust type system does not accept conflicting blanket implementations,
/// so we can't implement [`Decoding`] for both `ark_ff::Field` and `ark_ff::AdditiveGroup`:
/// the compiler will complain that a type might be implementing both in the future.
macro_rules! impl_decoding {
    (impl [$($generics:tt)*] for $type:ty) => {
        impl<$($generics)*> Decoding<[u8]> for $type {
            type Repr = DecodingFieldBuffer<Self>;

            fn decode(repr: Self::Repr) -> Self {
                debug_assert_eq!(repr.buf.len(), decoding_field_buffer_size::<Self>());
                // Bytes per coefficient: the buffer size of the (degree-1) base prime field.
                let chunk_size =
                    decoding_field_buffer_size::<<Self as Field>::BasePrimeField>();
                let coefficients = repr
                    .buf
                    .chunks_exact(chunk_size)
                    .map(<<Self as Field>::BasePrimeField as PrimeField>::from_be_bytes_mod_order);
                Self::from_base_prime_field_elems(coefficients).expect(
                    "DecodingFieldBuffer always holds exactly extension_degree() coefficient chunks",
                )
            }
        }
    };
}
142
143// Implement NargDeserialize for prime-order fields and field extensions.
144impl_deserialize!(impl [C: FpConfig<N>, const N: usize] for Fp<C, N>);
145impl_deserialize!(impl [C: ark_ff::Fp2Config] for ark_ff::Fp2<C>);
146impl_deserialize!(impl [C: ark_ff::Fp3Config] for ark_ff::Fp3<C>);
147impl_deserialize!(impl [C: ark_ff::Fp4Config] for ark_ff::Fp4<C>);
148impl_deserialize!(impl [C: ark_ff::Fp6Config] for ark_ff::Fp6<C>);
149impl_deserialize!(impl [C: ark_ff::Fp12Config] for ark_ff::Fp12<C>);
150impl_deserialize!(impl [P: SmallFpConfig] for SmallFp<P>);
151// Implement Encoding for prime-order field and field extensions.
152// The NargSerialize implementation is inherited here.
153impl_encoding!(impl [C: FpConfig<N>, const N: usize] for Fp<C, N>);
154impl_encoding!(impl [C: ark_ff::Fp2Config] for ark_ff::Fp2<C>);
155impl_encoding!(impl [C: ark_ff::Fp3Config] for ark_ff::Fp3<C>);
156impl_encoding!(impl [C: ark_ff::Fp4Config] for ark_ff::Fp4<C>);
157impl_encoding!(impl [C: ark_ff::Fp6Config] for ark_ff::Fp6<C>);
158impl_encoding!(impl [C: ark_ff::Fp12Config] for ark_ff::Fp12<C>);
159impl_encoding!(impl [P: SmallFpConfig] for SmallFp<P>);
160// Implement Decoding for prime-order fields and field extensions.
161impl_decoding!(impl [C: FpConfig<N>, const N: usize] for Fp<C, N>);
162impl_decoding!(impl [C: ark_ff::Fp2Config] for ark_ff::Fp2<C>);
163impl_decoding!(impl [C: ark_ff::Fp3Config] for ark_ff::Fp3<C>);
164impl_decoding!(impl [C: ark_ff::Fp4Config] for ark_ff::Fp4<C>);
165impl_decoding!(impl [C: ark_ff::Fp6Config] for ark_ff::Fp6<C>);
166impl_decoding!(impl [C: ark_ff::Fp12Config] for ark_ff::Fp12<C>);
167impl_decoding!(impl [P: SmallFpConfig] for SmallFp<P>);
168
169/// Number of uniformly random bits in a uniformly-distributed element in `[0, b)`
170///
171/// This function returns the maximum n for which
172/// `Uniform([b]) mod 2^n`
173/// and
174/// `Uniform([2^n])`
175/// are statistically indistinguishable.
176/// Given \(b = q 2^n + r\) the statistical distance
177/// is \(\frac{2r}{ab}(a-r)\).
178#[allow(unused)]
179fn random_bits_in_random_modp<const N: usize>(b: ark_ff::BigInt<N>) -> usize {
180    use ark_ff::{BigInt, BigInteger};
181    // XXX. is it correct to have num_bits+1 here?
182    for n in (0..=b.num_bits()).rev() {
183        // compute the remainder of b by 2^n
184        let r_bits = &b.to_bits_le()[..n as usize];
185        let r = BigInt::<N>::from_bits_le(r_bits);
186        let log2_a_minus_r = r_bits.iter().rev().skip_while(|&&bit| bit).count() as u32;
187        if b.num_bits() + n - 1 - r.num_bits() - log2_a_minus_r >= 128 {
188            return n as usize;
189        }
190    }
191    0
192}
193
194impl<F: Field> Default for DecodingFieldBuffer<F> {
195    fn default() -> Self {
196        let base_field_modulus_bytes = u64::from(F::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8));
197        // Get 32 bytes of extra randomness for every base field element in the extension
198        let len = (base_field_modulus_bytes + 32) * F::extension_degree();
199        Self {
200            buf: vec![0u8; len as usize],
201            _phantom: PhantomData,
202        }
203    }
204}
205
206impl<F: Field> AsMut<[u8]> for DecodingFieldBuffer<F> {
207    fn as_mut(&mut self) -> &mut [u8] {
208        self.buf.as_mut()
209    }
210}
211
#[cfg(test)]
mod test_ark_ff {
    use ark_ff::{BigInteger, PrimeField};

    use super::{decoding_field_buffer_size, DecodingFieldBuffer};
    use crate::{
        codecs::{Decoding, Encoding},
        io::{NargDeserialize, NargSerialize},
    };

    // ----- SmallFp test fields -----

    // Goldilocks field: p = 2^64 - 2^32 + 1
    ark_ff::define_field!(
        modulus = "18446744069414584321",
        generator = "7",
        name = Goldilocks,
    );

    // Mersenne31 field: p = 2^31 - 1
    ark_ff::define_field!(modulus = "2147483647", generator = "7", name = M31,);

    // BabyBear field: p = 15 * 2^27 + 1
    ark_ff::define_field!(modulus = "2013265921", generator = "31", name = BabyBear,);

    // KoalaBear field: p = 2^31 - 2^24 + 1
    ark_ff::define_field!(modulus = "2130706433", generator = "3", name = KoalaBear,);

    // A 16-bit test field: p = 65521 (largest 16-bit prime)
    ark_ff::define_field!(modulus = "65521", generator = "17", name = F16,);

    // ----- Encoding / serialization round-trip tests -----

    /// Encode → serialize → deserialize round-trip, testing zero, one, p-1,
    /// and a handful of interior values.
    fn roundtrip_testsuite<F>()
    where
        F: PrimeField + Encoding<[u8]> + NargSerialize + NargDeserialize,
    {
        for v in [0u64, 1, 42, 12345] {
            let original = F::from(v);
            let serialized = encode_to_vec(&original);
            let mut slice: &[u8] = &serialized;
            let deserialized = F::deserialize_from_narg(&mut slice)
                .unwrap_or_else(|_| panic!("failed to deserialize value {v}"));
            assert!(
                slice.is_empty(),
                "deserialize did not consume all bytes for value {v}"
            );
            assert_eq!(original, deserialized, "roundtrip mismatch for {v}");
        }

        // p - 1 (the largest valid element)
        let p_minus_1 = -F::ONE;
        let ser = encode_to_vec(&p_minus_1);
        let mut sl: &[u8] = &ser;
        let de = F::deserialize_from_narg(&mut sl).expect("p-1 should deserialize");
        assert!(sl.is_empty());
        assert_eq!(de, p_minus_1);
    }

    fn encode_to_vec<F: Encoding<[u8]>>(x: &F) -> alloc::vec::Vec<u8> {
        let mut dst = alloc::vec::Vec::new();
        x.serialize_into_narg(&mut dst);
        dst
    }

    /// Encoding the same value twice must produce identical bytes.
    fn deterministic_encoding_testsuite<F: ark_ff::Field + Encoding<[u8]>>() {
        for v in [0u64, 1, 42, 12345] {
            let elem = F::from(v);
            let a = encode_to_vec(&elem);
            let b = encode_to_vec(&elem);
            assert_eq!(a, b, "encoding not deterministic for {v}");
        }
    }

    /// Distinct values must encode differently.
    fn distinct_values_encode_differently<F: PrimeField + Encoding<[u8]>>() {
        let zero = encode_to_vec(&F::ZERO);
        let one = encode_to_vec(&F::ONE);
        let p_minus_1 = encode_to_vec(&(-F::ONE));

        assert_ne!(zero, one);
        assert_ne!(one, p_minus_1);
        assert_ne!(zero, p_minus_1);
    }

    /// Deserializing p (the modulus itself) must fail — the encoding
    /// is not canonical because p ≡ 0 and 0 already has its own encoding.
    fn reject_modulus<F: PrimeField + core::fmt::Debug + NargDeserialize>() {
        let modulus_bytes = F::MODULUS.to_bytes_be();
        // Keep only the trailing ⌈MODULUS_BIT_SIZE/8⌉ bytes; the backing BigInt
        // can be wider than the field (e.g. F16 inside SmallFp's BigInt<1>).
        let field_size = F::MODULUS_BIT_SIZE.div_ceil(8) as usize;
        let start = modulus_bytes.len().saturating_sub(field_size);
        let trimmed = &modulus_bytes[start..];
        let mut sl: &[u8] = trimmed;
        assert!(
            F::deserialize_from_narg(&mut sl).is_err(),
            "deserializing p should fail (modulus_bits={}, field_size={field_size}, trimmed={trimmed:?})",
            F::MODULUS_BIT_SIZE,
        );
    }

    /// A single bit-flip must either change the decoded value or cause rejection.
    fn bitflip_testsuite<F>()
    where
        F: PrimeField + Encoding<[u8]> + NargDeserialize,
    {
        let original = F::from(42u64);
        let encoded = encode_to_vec(&original);

        for byte_idx in 0..encoded.len() {
            for bit in 0..8u8 {
                let mut flipped = encoded.clone();
                flipped[byte_idx] ^= 1 << bit;
                let mut sl: &[u8] = &flipped;
                if let Ok(v) = F::deserialize_from_narg(&mut sl) {
                    assert_ne!(
                        v, original,
                        "bit-flip at byte {byte_idx} bit {bit} decoded to same value"
                    );
                } // rejection is fine
            }
        }
    }

    /// Truncated buffer must be rejected.
    fn wrong_length_testsuite<F>()
    where
        F: PrimeField + Encoding<[u8]> + NargDeserialize,
    {
        let encoded = encode_to_vec(&F::from(1u64));

        // Truncated: one byte short
        if !encoded.is_empty() {
            let short = &encoded[..encoded.len() - 1];
            let mut sl: &[u8] = short;
            assert!(
                F::deserialize_from_narg(&mut sl).is_err(),
                "truncated buffer should fail"
            );
        }
    }

    /// Deserialization consumes exactly one element and leaves any following bytes in place.
    fn trailing_bytes_testsuite<F>()
    where
        F: ark_ff::Field + Encoding<[u8]> + NargDeserialize,
    {
        let original = F::from(7u64);
        let mut bytes = encode_to_vec(&original);
        bytes.extend_from_slice(&[0xAA, 0xBB]);
        let mut sl: &[u8] = &bytes;
        let de = F::deserialize_from_narg(&mut sl).expect("valid prefix should deserialize");
        assert_eq!(de, original);
        assert_eq!(sl, &[0xAAu8, 0xBB][..]);
    }

    /// Round-trip for (possibly extension) fields. The encoding must be
    /// `extension_degree * ⌈MODULUS_BIT_SIZE/8⌉` bytes and must deserialize back exactly.
    fn extension_roundtrip_testsuite<F>()
    where
        F: ark_ff::Field + Encoding<[u8]> + NargDeserialize,
    {
        let degree = F::extension_degree() as usize;
        let base_bytes =
            <F::BasePrimeField as PrimeField>::MODULUS_BIT_SIZE.div_ceil(8) as usize;
        // An element whose base-field coefficients are all distinct and non-zero.
        let coeffs = (1..=degree as u64)
            .map(|i| F::BasePrimeField::from(i * 1_000_003))
            .collect::<alloc::vec::Vec<_>>();
        let mixed = F::from_base_prime_field_elems(coeffs)
            .expect("exactly extension_degree coefficients were supplied");

        for original in [F::ZERO, F::ONE, -F::ONE, mixed] {
            let encoded = encode_to_vec(&original);
            assert_eq!(encoded.len(), degree * base_bytes);
            let mut sl: &[u8] = &encoded;
            let decoded =
                F::deserialize_from_narg(&mut sl).expect("canonical encoding should deserialize");
            assert!(sl.is_empty());
            assert_eq!(decoded, original);
        }
    }

    /// Placing the base-field modulus in any coefficient slot of an extension
    /// element must be rejected as non-canonical.
    fn extension_reject_modulus_testsuite<F: ark_ff::Field + NargDeserialize>() {
        let degree = F::extension_degree() as usize;
        let base_bytes =
            <F::BasePrimeField as PrimeField>::MODULUS_BIT_SIZE.div_ceil(8) as usize;
        let modulus = <F::BasePrimeField as PrimeField>::MODULUS.to_bytes_be();
        let modulus = &modulus[modulus.len().saturating_sub(base_bytes)..];

        for slot in 0..degree {
            let mut bytes = alloc::vec![0u8; degree * base_bytes];
            bytes[slot * base_bytes..(slot + 1) * base_bytes].copy_from_slice(modulus);
            let mut sl: &[u8] = &bytes;
            assert!(
                F::deserialize_from_narg(&mut sl).is_err(),
                "modulus in coefficient slot {slot} should be rejected"
            );
            assert_eq!(sl.len(), bytes.len(), "rejected input must not be consumed");
        }
    }

    /// `Decoding` must use a buffer of `decoding_field_buffer_size` bytes. An all-zero
    /// buffer decodes to zero, and a 1 in the last byte of the first chunk decodes to one.
    fn decode_testsuite<F>()
    where
        F: ark_ff::Field + Decoding<[u8], Repr = DecodingFieldBuffer<F>>,
    {
        let expected_len = decoding_field_buffer_size::<F>();
        let chunk = expected_len / F::extension_degree() as usize;

        let mut repr = DecodingFieldBuffer::<F>::default();
        let bytes: &mut [u8] = repr.as_mut();
        assert_eq!(bytes.len(), expected_len);
        assert_eq!(F::decode(repr), F::ZERO);

        let mut repr = DecodingFieldBuffer::<F>::default();
        let bytes: &mut [u8] = repr.as_mut();
        bytes[chunk - 1] = 1;
        assert_eq!(F::decode(repr), F::ONE);
    }

    #[test]
    fn test_smallfp_roundtrip() {
        roundtrip_testsuite::<Goldilocks>();
        roundtrip_testsuite::<M31>();
        roundtrip_testsuite::<BabyBear>();
        roundtrip_testsuite::<KoalaBear>();
        roundtrip_testsuite::<F16>();
    }

    #[test]
    fn test_smallfp_deterministic_encoding() {
        deterministic_encoding_testsuite::<Goldilocks>();
        deterministic_encoding_testsuite::<M31>();
        deterministic_encoding_testsuite::<BabyBear>();
        deterministic_encoding_testsuite::<KoalaBear>();
        deterministic_encoding_testsuite::<F16>();
    }

    #[test]
    fn test_smallfp_distinct_values_encode_differently() {
        distinct_values_encode_differently::<Goldilocks>();
        distinct_values_encode_differently::<M31>();
        distinct_values_encode_differently::<BabyBear>();
        distinct_values_encode_differently::<KoalaBear>();
        distinct_values_encode_differently::<F16>();
    }

    #[test]
    fn test_smallfp_reject_modulus() {
        reject_modulus::<Goldilocks>();
        reject_modulus::<M31>();
        reject_modulus::<BabyBear>();
        reject_modulus::<KoalaBear>();
        // F16 modulus is 65521, which fits in 2 bytes. Encoding is 2 BE bytes.
        reject_modulus::<F16>();
    }

    #[test]
    fn test_smallfp_bitflip() {
        bitflip_testsuite::<Goldilocks>();
        bitflip_testsuite::<M31>();
        bitflip_testsuite::<BabyBear>();
        bitflip_testsuite::<KoalaBear>();
        bitflip_testsuite::<F16>();
    }

    #[test]
    fn test_smallfp_wrong_length() {
        wrong_length_testsuite::<Goldilocks>();
        wrong_length_testsuite::<M31>();
        wrong_length_testsuite::<BabyBear>();
        wrong_length_testsuite::<KoalaBear>();
        wrong_length_testsuite::<F16>();
    }

    // ----- MontFp (large field) tests -----

    #[test]
    fn test_montfp_roundtrip() {
        roundtrip_testsuite::<ark_bls12_381::Fr>();
        roundtrip_testsuite::<ark_bls12_381::Fq>();
    }

    #[test]
    fn test_montfp_reject_modulus() {
        reject_modulus::<ark_bls12_381::Fr>();
        reject_modulus::<ark_bls12_381::Fq>();
    }

    #[test]
    fn test_montfp_bitflip() {
        bitflip_testsuite::<ark_bls12_381::Fr>();
        bitflip_testsuite::<ark_bls12_381::Fq>();
    }

    #[test]
    fn test_montfp_wrong_length() {
        wrong_length_testsuite::<ark_bls12_381::Fr>();
        wrong_length_testsuite::<ark_bls12_381::Fq>();
    }

    // ----- SmallFp extension field (Fp2) -----

    pub struct GoldilocksFp2Config;
    impl ark_ff::Fp2Config for GoldilocksFp2Config {
        type Fp = Goldilocks;

        // 7 is a quadratic non-residue mod Goldilocks
        const NONRESIDUE: Self::Fp = ark_ff::SmallFp::from_raw(7);

        const FROBENIUS_COEFF_FP2_C1: &'static [Self::Fp] = &[
            // 7^(((q^0) - 1) / 2) = 1
            ark_ff::SmallFp::from_raw(1),
            // 7^(((q^1) - 1) / 2) = p - 1
            ark_ff::SmallFp::from_raw(18_446_744_069_414_584_320),
        ];
    }
    pub type GoldilocksFp2 = ark_ff::Fp2<GoldilocksFp2Config>;

    #[test]
    fn test_encoding_small_fp_goldilocks_fp2() {
        deterministic_encoding_testsuite::<GoldilocksFp2>();
        extension_roundtrip_testsuite::<GoldilocksFp2>();
        extension_reject_modulus_testsuite::<GoldilocksFp2>();
    }

    // ----- MontFp extension fields -----

    #[test]
    fn test_montfp_extension_roundtrip() {
        extension_roundtrip_testsuite::<ark_bls12_381::Fq2>();
        extension_roundtrip_testsuite::<ark_bls12_381::Fq12>();
        extension_reject_modulus_testsuite::<ark_bls12_381::Fq2>();
    }

    // ----- Stream behaviour and Decoding -----

    #[test]
    fn test_trailing_bytes_are_not_consumed() {
        trailing_bytes_testsuite::<Goldilocks>();
        trailing_bytes_testsuite::<F16>();
        trailing_bytes_testsuite::<ark_bls12_381::Fr>();
        trailing_bytes_testsuite::<GoldilocksFp2>();
    }

    #[test]
    fn test_decoding_prime_fields() {
        decode_testsuite::<Goldilocks>();
        decode_testsuite::<M31>();
        decode_testsuite::<F16>();
        decode_testsuite::<ark_bls12_381::Fr>();
    }

    #[test]
    fn test_decoding_extension_fields() {
        decode_testsuite::<GoldilocksFp2>();
        decode_testsuite::<ark_bls12_381::Fq2>();
    }

    #[test]
    fn test_prime_field_encoding_is_left_padded_big_endian() {
        let value = ark_secp256k1::Fr::from(1u64);
        let encoded = Encoding::<[u8]>::encode(&value);
        let bytes = encoded.as_ref();

        assert_eq!(bytes.len(), 32);
        assert!(bytes[..31].iter().all(|&byte| byte == 0));
        assert_eq!(bytes[31], 1);
    }

    #[test]
    fn test_prime_field_deserialize_rejects_modulus() {
        let modulus = ark_secp256k1::Fr::MODULUS.to_bytes_be();
        let mut slice = modulus.as_slice();

        assert!(ark_secp256k1::Fr::deserialize_from_narg(&mut slice).is_err());
        assert_eq!(slice, modulus.as_slice());
    }
}