// threshold_bls/curve/bn254.rs

1use crate::group::{self, Element, PairingCurve as PC, Point, Scalar as Sc};
2use crate::hash::hasher::Keccak256Hasher;
3use crate::hash::try_and_increment::TryAndIncrement;
4use crate::hash::HashToCurve;
5use crate::serialize::ContractSerialize;
6use ark_bn254 as bn254;
7use ark_ec::{PairingEngine, ProjectiveCurve};
8use ark_ff::PrimeField;
9use ark_ff::{Field, One, UniformRand, Zero};
10use rand_core::RngCore;
11use serde::{
12    de::{Error as DeserializeError, SeqAccess, Visitor},
13    ser::{Error as SerializationError, SerializeTuple},
14    Deserialize, Deserializer, Serialize, Serializer,
15};
16use std::{
17    fmt,
18    marker::PhantomData,
19    ops::{AddAssign, MulAssign, Neg, SubAssign},
20};
21
22use thiserror::Error;
23
24use super::{BLSError, CurveType};
25
26#[derive(Debug, Error)]
27pub enum BNError {
28    #[error("{0}")]
29    SerializationError(#[from] ark_serialize::SerializationError),
30    #[error("{0}")]
31    BLSError(#[from] BLSError),
32}
33
34#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
35pub struct Scalar(
36    #[serde(deserialize_with = "deserialize_field")]
37    #[serde(serialize_with = "serialize_field")]
38    <bn254::Bn254 as PairingEngine>::Fr,
39);
40
41type ZG1 = <bn254::Bn254 as PairingEngine>::G1Projective;
42
43#[derive(Debug, Clone, Copy, Eq, PartialEq)]
44pub struct G1(pub(crate) ZG1);
45
46type ZG2 = <bn254::Bn254 as PairingEngine>::G2Projective;
47
48#[derive(Debug, Clone, Copy, Eq, PartialEq)]
49pub struct G2(pub(crate) ZG2);
50
51impl Serialize for G1 {
52    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
53    where
54        S: Serializer,
55    {
56        let bytes = self
57            .serialize_to_contract_form()
58            .map_err(SerializationError::custom)?;
59
60        let mut tup = serializer.serialize_tuple(32)?;
61        for byte in &bytes {
62            tup.serialize_element(byte)?;
63        }
64        tup.end()
65    }
66}
67impl<'de> Deserialize<'de> for G1 {
68    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
69    where
70        D: Deserializer<'de>,
71    {
72        struct G1Visitor;
73
74        impl<'de> Visitor<'de> for G1Visitor {
75            type Value = G1;
76
77            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
78                formatter.write_str("a valid group element")
79            }
80
81            fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
82            where
83                S: SeqAccess<'de>,
84            {
85                let bytes: Vec<u8> = (0..32)
86                    .map(|_| {
87                        seq.next_element()?
88                            .ok_or_else(|| DeserializeError::custom("could not read bytes"))
89                    })
90                    .collect::<Result<Vec<_>, _>>()?;
91
92                let ele =
93                    G1::deserialize_from_contract_form(&bytes).map_err(DeserializeError::custom)?;
94                Ok(ele)
95            }
96        }
97
98        deserializer.deserialize_tuple(32, G1Visitor)
99    }
100}
101
102impl Serialize for G2 {
103    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104    where
105        S: Serializer,
106    {
107        let bytes = self
108            .serialize_to_contract_form()
109            .map_err(SerializationError::custom)?;
110
111        let mut tup = serializer.serialize_tuple(128)?;
112        for byte in &bytes {
113            tup.serialize_element(byte)?;
114        }
115        tup.end()
116    }
117}
118impl<'de> Deserialize<'de> for G2 {
119    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120    where
121        D: Deserializer<'de>,
122    {
123        struct G2Visitor;
124
125        impl<'de> Visitor<'de> for G2Visitor {
126            type Value = G2;
127
128            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
129                formatter.write_str("a valid group element")
130            }
131
132            fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
133            where
134                S: SeqAccess<'de>,
135            {
136                let bytes: Vec<u8> = (0..128)
137                    .map(|_| {
138                        seq.next_element()?
139                            .ok_or_else(|| DeserializeError::custom("could not read bytes"))
140                    })
141                    .collect::<Result<Vec<_>, _>>()?;
142
143                let ele =
144                    G2::deserialize_from_contract_form(&bytes).map_err(DeserializeError::custom)?;
145                Ok(ele)
146            }
147        }
148
149        deserializer.deserialize_tuple(128, G2Visitor)
150    }
151}
152
153#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
154pub struct GT(
155    #[serde(deserialize_with = "deserialize_field")]
156    #[serde(serialize_with = "serialize_field")]
157    <bn254::Bn254 as PairingEngine>::Fqk,
158);
159
160impl Element for Scalar {
161    type RHS = Scalar;
162
163    fn new() -> Self {
164        Self(Zero::zero())
165    }
166
167    fn one() -> Self {
168        Self(One::one())
169    }
170
171    fn add(&mut self, s2: &Self) {
172        self.0.add_assign(s2.0);
173    }
174
175    fn mul(&mut self, mul: &Scalar) {
176        self.0.mul_assign(mul.0)
177    }
178
179    fn rand<R: rand_core::RngCore>(rng: &mut R) -> Self {
180        Self(bn254::Fr::rand(rng))
181    }
182}
183
184impl Sc for Scalar {
185    fn set_int(&mut self, i: u64) {
186        *self = Self(bn254::Fr::from(i))
187    }
188
189    fn inverse(&self) -> Option<Self> {
190        Some(Self(Field::inverse(&self.0)?))
191    }
192
193    fn negate(&mut self) {
194        *self = Self(self.0.neg())
195    }
196
197    fn sub(&mut self, other: &Self) {
198        self.0.sub_assign(other.0);
199    }
200}
201
202impl fmt::Display for Scalar {
203    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204        write!(f, "{{{:?}}}", self.0)
205    }
206}
207
208/// G1 points can be multiplied by Fr elements
209impl Element for G1 {
210    type RHS = Scalar;
211
212    fn new() -> Self {
213        Self(Zero::zero())
214    }
215
216    fn one() -> Self {
217        Self(ZG1::prime_subgroup_generator())
218    }
219
220    fn rand<R: RngCore>(rng: &mut R) -> Self {
221        Self(ZG1::rand(rng))
222    }
223
224    fn add(&mut self, s2: &Self) {
225        self.0.add_assign(s2.0);
226    }
227
228    fn mul(&mut self, mul: &Scalar) {
229        self.0.mul_assign(mul.0);
230    }
231}
232
233/// Implementation of Point using G1 from BN254
234impl Point for G1 {
235    type Error = BNError;
236
237    fn map(&mut self, data: &[u8]) -> Result<(), BNError> {
238        let hasher = TryAndIncrement::new(&Keccak256Hasher);
239
240        let hash = hasher.hash(&[], data)?;
241
242        *self = Self(hash);
243
244        Ok(())
245    }
246}
247
248impl fmt::Display for G1 {
249    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
250        write!(f, "{{{:?}}}", self.0)
251    }
252}
253
254/// G1 points can be multiplied by Fr elements
255impl Element for G2 {
256    type RHS = Scalar;
257
258    fn new() -> Self {
259        Self(Zero::zero())
260    }
261
262    fn one() -> Self {
263        Self(ZG2::prime_subgroup_generator())
264    }
265
266    fn rand<R: RngCore>(mut rng: &mut R) -> Self {
267        Self(ZG2::rand(&mut rng))
268    }
269
270    fn add(&mut self, s2: &Self) {
271        self.0.add_assign(s2.0);
272    }
273
274    fn mul(&mut self, mul: &Scalar) {
275        self.0.mul_assign(mul.0)
276    }
277}
278
279/// Implementation of Point using G2 from BN254
280impl Point for G2 {
281    type Error = BNError;
282
283    fn map(&mut self, data: &[u8]) -> Result<(), BNError> {
284        let hasher = TryAndIncrement::new(&Keccak256Hasher);
285
286        let hash = hasher.hash(&[], data)?;
287
288        *self = Self(hash);
289
290        Ok(())
291    }
292}
293
294impl fmt::Display for G2 {
295    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
296        write!(f, "{{{:?}}}", self.0)
297    }
298}
299
300impl Element for GT {
301    type RHS = Scalar;
302
303    fn new() -> Self {
304        Self(One::one())
305    }
306    fn one() -> Self {
307        Self(One::one())
308    }
309    fn add(&mut self, s2: &Self) {
310        self.0.mul_assign(s2.0);
311    }
312    fn mul(&mut self, mul: &Scalar) {
313        let scalar = mul.0.into_repr();
314        let mut res = Self::one();
315        let mut temp = *self;
316        for b in ark_ff::BitIteratorLE::without_trailing_zeros(scalar) {
317            if b {
318                res.0.mul_assign(temp.0);
319            }
320            temp.0.square_in_place();
321        }
322        *self = res;
323    }
324    fn rand<R: RngCore>(rng: &mut R) -> Self {
325        Self(bn254::Fq12::rand(rng))
326    }
327}
328
329impl fmt::Display for GT {
330    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
331        write!(f, "{{{:?}}}", self.0)
332    }
333}
334
335pub type G1Curve = group::G1Curve<PairingCurve>;
336pub type G2Curve = group::G2Curve<PairingCurve>;
337
338#[derive(Clone, Debug, Serialize)]
339pub struct PairingCurve;
340
341impl PC for PairingCurve {
342    type Scalar = Scalar;
343    type G1 = G1;
344    type G2 = G2;
345    type GT = GT;
346
347    fn pair(a: &Self::G1, b: &Self::G2) -> Self::GT {
348        GT(<bn254::Bn254 as PairingEngine>::pairing(a.0, b.0))
349    }
350}
351
// Serde implementations for arkworks field elements (ideally, these should be upstreamed to arkworks, formerly Zexe)
353
354fn deserialize_field<'de, D, C>(deserializer: D) -> Result<C, D::Error>
355where
356    D: Deserializer<'de>,
357    C: Field,
358{
359    struct FieldVisitor<C>(PhantomData<C>);
360
361    impl<'de, C> Visitor<'de> for FieldVisitor<C>
362    where
363        C: Field,
364    {
365        type Value = C;
366
367        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
368            formatter.write_str("a valid group element")
369        }
370
371        fn visit_seq<S>(self, mut seq: S) -> Result<C, S::Error>
372        where
373            S: SeqAccess<'de>,
374        {
375            let len = C::zero().serialized_size();
376            let bytes: Vec<u8> = (0..len)
377                .map(|_| {
378                    seq.next_element()?
379                        .ok_or_else(|| DeserializeError::custom("could not read bytes"))
380                })
381                .collect::<Result<Vec<_>, _>>()?;
382
383            let res = C::deserialize(&mut &bytes[..]).map_err(DeserializeError::custom)?;
384            Ok(res)
385        }
386    }
387
388    let visitor = FieldVisitor(PhantomData);
389    deserializer.deserialize_tuple(C::zero().serialized_size(), visitor)
390}
391
392fn serialize_field<S, C>(c: &C, s: S) -> Result<S::Ok, S::Error>
393where
394    S: Serializer,
395    C: Field,
396{
397    let len = c.serialized_size();
398    let mut bytes = Vec::with_capacity(len);
399    c.serialize(&mut bytes)
400        .map_err(SerializationError::custom)?;
401
402    let mut tup = s.serialize_tuple(len)?;
403    for byte in &bytes {
404        tup.serialize_element(byte)?;
405    }
406    tup.end()
407}
408
409#[derive(Clone, Debug)]
410pub struct BN254Curve;
411
412impl CurveType for BN254Curve {
413    type G1Curve = G1Curve;
414
415    type G2Curve = G2Curve;
416
417    type PairingCurve = PairingCurve;
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use static_assertions::assert_impl_all;

    assert_impl_all!(G1: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(G2: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(GT: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(Scalar: Serialize, DeserializeOwned, Clone);

    #[test]
    fn serialize_group() {
        for _ in 0..10 {
            roundtrip_test::<G1>(32);
            roundtrip_test::<G2>(128);
        }
    }

    #[test]
    fn serialize_field() {
        roundtrip_test::<GT>(384);
        roundtrip_test::<Scalar>(32);
    }

    /// Checks that `E` rejects empty input, that a random element serializes to
    /// exactly `size` bytes with bincode, and that it deserializes back unchanged.
    fn roundtrip_test<E: Element>(size: usize) {
        let empty = bincode::deserialize::<E>(&[]);
        assert!(empty.is_err());

        let rng = &mut rand::thread_rng();
        let element = E::rand(rng);
        let ser = bincode::serialize(&element).unwrap();
        assert_eq!(ser.len(), size);

        let de: E = bincode::deserialize(&ser).unwrap();
        assert_eq!(de, element);
    }

    #[test]
    fn gt_exp() {
        let rng = &mut rand::thread_rng();
        let base = GT::rand(rng);

        let mut sc = Scalar::one();
        sc.add(&Scalar::one());
        sc.add(&Scalar::one());

        // `GT` is `Copy`, so plain assignment duplicates it.
        let mut exp = base;
        exp.mul(&sc);

        let mut res = base;
        res.add(&base);
        res.add(&base);

        assert_eq!(exp, res);
    }

    #[test]
    fn gt_exp_zero() {
        let rng = &mut rand::thread_rng();
        let mut x = GT::rand(rng);
        x.mul(&Scalar::new());
        assert_eq!(x, GT::one());
    }
}