threshold_bls/curve/
bls12381.rs

1use crate::group::{self, Element, PairingCurve as PC, Point, Scalar as Sc};
2use crate::hash::hasher::Keccak256Hasher;
3use crate::hash::try_and_increment::TryAndIncrement;
4use crate::hash::HashToCurve;
5use ark_bls12_381 as bls12_381;
6use ark_ec::{AffineCurve, PairingEngine, ProjectiveCurve};
7use ark_ff::PrimeField;
8use ark_ff::{Field, One, UniformRand, Zero};
9use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
10use rand_core::RngCore;
11use serde::{
12    de::{Error as DeserializeError, SeqAccess, Visitor},
13    ser::{Error as SerializationError, SerializeTuple},
14    Deserialize, Deserializer, Serialize, Serializer,
15};
16use std::{
17    fmt,
18    marker::PhantomData,
19    ops::{AddAssign, MulAssign, Neg, SubAssign},
20};
21
22use thiserror::Error;
23
24use super::{BLSError, CurveType};
25
26#[derive(Debug, Error)]
27pub enum BLS12Error {
28    #[error("{0}")]
29    SerializationError(#[from] ark_serialize::SerializationError),
30    #[error("{0}")]
31    BLSError(#[from] BLSError),
32}
33
34#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
35pub struct Scalar(
36    #[serde(deserialize_with = "deserialize_field")]
37    #[serde(serialize_with = "serialize_field")]
38    <bls12_381::Bls12_381 as PairingEngine>::Fr,
39);
40
41type ZG1 = <bls12_381::Bls12_381 as PairingEngine>::G1Projective;
42
43#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
44pub struct G1(
45    #[serde(deserialize_with = "deserialize_group")]
46    #[serde(serialize_with = "serialize_group")]
47    ZG1,
48);
49
50type ZG2 = <bls12_381::Bls12_381 as PairingEngine>::G2Projective;
51
52#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
53pub struct G2(
54    #[serde(deserialize_with = "deserialize_group")]
55    #[serde(serialize_with = "serialize_group")]
56    ZG2,
57);
58
59#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
60pub struct GT(
61    #[serde(deserialize_with = "deserialize_field")]
62    #[serde(serialize_with = "serialize_field")]
63    <bls12_381::Bls12_381 as PairingEngine>::Fqk,
64);
65
66impl Element for Scalar {
67    type RHS = Scalar;
68
69    fn new() -> Self {
70        Self(Zero::zero())
71    }
72
73    fn one() -> Self {
74        Self(One::one())
75    }
76
77    fn add(&mut self, s2: &Self) {
78        self.0.add_assign(s2.0);
79    }
80
81    fn mul(&mut self, mul: &Scalar) {
82        self.0.mul_assign(mul.0)
83    }
84
85    fn rand<R: rand_core::RngCore>(rng: &mut R) -> Self {
86        Self(bls12_381::Fr::rand(rng))
87    }
88}
89
90impl Sc for Scalar {
91    fn set_int(&mut self, i: u64) {
92        *self = Self(bls12_381::Fr::from(i))
93    }
94
95    fn inverse(&self) -> Option<Self> {
96        Some(Self(Field::inverse(&self.0)?))
97    }
98
99    fn negate(&mut self) {
100        *self = Self(self.0.neg())
101    }
102
103    fn sub(&mut self, other: &Self) {
104        self.0.sub_assign(other.0);
105    }
106}
107
108impl fmt::Display for Scalar {
109    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110        write!(f, "{{{:?}}}", self.0)
111    }
112}
113
114/// G1 points can be multiplied by Fr elements
115impl Element for G1 {
116    type RHS = Scalar;
117
118    fn new() -> Self {
119        Self(Zero::zero())
120    }
121
122    fn one() -> Self {
123        Self(ZG1::prime_subgroup_generator())
124    }
125
126    fn rand<R: RngCore>(rng: &mut R) -> Self {
127        Self(ZG1::rand(rng))
128    }
129
130    fn add(&mut self, s2: &Self) {
131        self.0.add_assign(s2.0);
132    }
133
134    fn mul(&mut self, mul: &Scalar) {
135        self.0.mul_assign(mul.0);
136    }
137}
138
139/// Implementation of Point using G1 from BLS12_381
140impl Point for G1 {
141    type Error = BLS12Error;
142
143    fn map(&mut self, data: &[u8]) -> Result<(), BLS12Error> {
144        let hasher = TryAndIncrement::new(&Keccak256Hasher);
145
146        let hash = hasher.hash(&[], data)?;
147
148        *self = Self(hash);
149
150        Ok(())
151    }
152}
153
154impl fmt::Display for G1 {
155    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156        write!(f, "{{{:?}}}", self.0)
157    }
158}
159
160/// G1 points can be multiplied by Fr elements
161impl Element for G2 {
162    type RHS = Scalar;
163
164    fn new() -> Self {
165        Self(Zero::zero())
166    }
167
168    fn one() -> Self {
169        Self(ZG2::prime_subgroup_generator())
170    }
171
172    fn rand<R: RngCore>(mut rng: &mut R) -> Self {
173        Self(ZG2::rand(&mut rng))
174    }
175
176    fn add(&mut self, s2: &Self) {
177        self.0.add_assign(s2.0);
178    }
179
180    fn mul(&mut self, mul: &Scalar) {
181        self.0.mul_assign(mul.0)
182    }
183}
184
185/// Implementation of Point using G2 from BLS12_381
186impl Point for G2 {
187    type Error = BLS12Error;
188
189    fn map(&mut self, data: &[u8]) -> Result<(), BLS12Error> {
190        let hasher = TryAndIncrement::new(&Keccak256Hasher);
191
192        let hash = hasher.hash(&[], data)?;
193
194        *self = Self(hash);
195
196        Ok(())
197    }
198}
199
200impl fmt::Display for G2 {
201    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
202        write!(f, "{{{:?}}}", self.0)
203    }
204}
205
206impl Element for GT {
207    type RHS = Scalar;
208
209    fn new() -> Self {
210        Self(One::one())
211    }
212    fn one() -> Self {
213        Self(One::one())
214    }
215    fn add(&mut self, s2: &Self) {
216        self.0.mul_assign(s2.0);
217    }
218    fn mul(&mut self, mul: &Scalar) {
219        let scalar = mul.0.into_repr();
220        let mut res = Self::one();
221        let mut temp = *self;
222        for b in ark_ff::BitIteratorLE::without_trailing_zeros(scalar) {
223            if b {
224                res.0.mul_assign(temp.0);
225            }
226            temp.0.square_in_place();
227        }
228        *self = res;
229    }
230    fn rand<R: RngCore>(rng: &mut R) -> Self {
231        Self(bls12_381::Fq12::rand(rng))
232    }
233}
234
235impl fmt::Display for GT {
236    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
237        write!(f, "{{{:?}}}", self.0)
238    }
239}
240
241pub type G1Curve = group::G1Curve<PairingCurve>;
242pub type G2Curve = group::G2Curve<PairingCurve>;
243
244#[derive(Clone, Debug)]
245pub struct PairingCurve;
246
247impl PC for PairingCurve {
248    type Scalar = Scalar;
249    type G1 = G1;
250    type G2 = G2;
251    type GT = GT;
252
253    fn pair(a: &Self::G1, b: &Self::G2) -> Self::GT {
254        GT(<bls12_381::Bls12_381 as PairingEngine>::pairing(a.0, b.0))
255    }
256}
257
// Serde implementations (ideally, these should be upstreamed to arkworks, formerly Zexe)
259
260fn deserialize_field<'de, D, C>(deserializer: D) -> Result<C, D::Error>
261where
262    D: Deserializer<'de>,
263    C: Field,
264{
265    struct FieldVisitor<C>(PhantomData<C>);
266
267    impl<'de, C> Visitor<'de> for FieldVisitor<C>
268    where
269        C: Field,
270    {
271        type Value = C;
272
273        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
274            formatter.write_str("a valid group element")
275        }
276
277        fn visit_seq<S>(self, mut seq: S) -> Result<C, S::Error>
278        where
279            S: SeqAccess<'de>,
280        {
281            let len = C::zero().serialized_size();
282            let bytes: Vec<u8> = (0..len)
283                .map(|_| {
284                    seq.next_element()?
285                        .ok_or_else(|| DeserializeError::custom("could not read bytes"))
286                })
287                .collect::<Result<Vec<_>, _>>()?;
288
289            let res = C::deserialize(&mut &bytes[..]).map_err(DeserializeError::custom)?;
290            Ok(res)
291        }
292    }
293
294    let visitor = FieldVisitor(PhantomData);
295    deserializer.deserialize_tuple(C::zero().serialized_size(), visitor)
296}
297
298fn serialize_field<S, C>(c: &C, s: S) -> Result<S::Ok, S::Error>
299where
300    S: Serializer,
301    C: Field,
302{
303    let len = c.serialized_size();
304    let mut bytes = Vec::with_capacity(len);
305    c.serialize(&mut bytes)
306        .map_err(SerializationError::custom)?;
307
308    let mut tup = s.serialize_tuple(len)?;
309    for byte in &bytes {
310        tup.serialize_element(byte)?;
311    }
312    tup.end()
313}
314
315fn deserialize_group<'de, D, C>(deserializer: D) -> Result<C, D::Error>
316where
317    D: Deserializer<'de>,
318    C: ProjectiveCurve,
319    C::Affine: CanonicalDeserialize + CanonicalSerialize,
320{
321    struct GroupVisitor<C>(PhantomData<C>);
322
323    impl<'de, C> Visitor<'de> for GroupVisitor<C>
324    where
325        C: ProjectiveCurve,
326        //C::Affine: CanonicalDeserialize + CanonicalSerialize,
327    {
328        type Value = C;
329
330        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
331            formatter.write_str("a valid group element")
332        }
333
334        fn visit_seq<S>(self, mut seq: S) -> Result<C, S::Error>
335        where
336            S: SeqAccess<'de>,
337        {
338            let len = C::Affine::zero().serialized_size(); //C::Affine::SERIALIZED_SIZE;
339            let bytes: Vec<u8> = (0..len)
340                .map(|_| {
341                    seq.next_element()?
342                        .ok_or_else(|| DeserializeError::custom("could not read bytes"))
343                })
344                .collect::<Result<Vec<_>, _>>()?;
345
346            let affine =
347                C::Affine::deserialize(&mut &bytes[..]).map_err(DeserializeError::custom)?;
348            Ok(affine.into_projective())
349        }
350    }
351
352    let visitor = GroupVisitor(PhantomData);
353    deserializer.deserialize_tuple(C::Affine::zero().serialized_size(), visitor)
354}
355
356fn serialize_group<S, C>(c: &C, s: S) -> Result<S::Ok, S::Error>
357where
358    S: Serializer,
359    C: ProjectiveCurve,
360    C::Affine: CanonicalSerialize,
361{
362    let affine = c.into_affine();
363    let len = affine.serialized_size();
364    let mut bytes = Vec::with_capacity(len);
365    affine
366        .serialize(&mut bytes)
367        .map_err(SerializationError::custom)?;
368
369    let mut tup = s.serialize_tuple(len)?;
370    for byte in &bytes {
371        tup.serialize_element(byte)?;
372    }
373    tup.end()
374}
375
376#[derive(Clone, Debug)]
377pub struct BLS12381Curve;
378
379impl CurveType for BLS12381Curve {
380    type G1Curve = G1Curve;
381
382    type G2Curve = G2Curve;
383
384    type PairingCurve = PairingCurve;
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use static_assertions::assert_impl_all;

    assert_impl_all!(G1: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(G2: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(GT: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(Scalar: Serialize, DeserializeOwned, Clone);

    /// Serializes a random element with bincode, checks that the encoding is exactly
    /// `size` bytes, and checks that decoding it yields the same element.
    ///
    /// Shared by the group and field tests, which previously had two identical copies.
    fn roundtrip_test<E: Element>(size: usize) {
        let rng = &mut rand::thread_rng();
        let elem = E::rand(rng);
        let ser = bincode::serialize(&elem).unwrap();
        assert_eq!(ser.len(), size);

        let de: E = bincode::deserialize(&ser).unwrap();
        assert_eq!(de, elem);
    }

    #[test]
    fn serialize_group() {
        roundtrip_test::<G1>(48);
        roundtrip_test::<G2>(96);
    }

    #[test]
    fn serialize_field() {
        roundtrip_test::<GT>(576);
        roundtrip_test::<Scalar>(32);
    }

    #[test]
    fn gt_exp() {
        // GT is written multiplicatively: raising `base` to the scalar 3 via `mul` must
        // equal combining `base` with itself three times via `add`.
        let rng = &mut rand::thread_rng();
        let base = GT::rand(rng);

        let mut three = Scalar::one();
        three.add(&Scalar::one());
        three.add(&Scalar::one());

        // `GT` is `Copy`, so plain copies replace the previous `.clone()` calls.
        let mut exp = base;
        exp.mul(&three);

        let mut res = base;
        res.add(&base);
        res.add(&base);

        assert_eq!(exp, res);
    }
}