1use crate::group::{self, Element, PairingCurve as PC, Point, Scalar as Sc};
2use crate::hash::hasher::Keccak256Hasher;
3use crate::hash::try_and_increment::TryAndIncrement;
4use crate::hash::HashToCurve;
5use ark_bls12_381 as bls12_381;
6use ark_ec::{AffineCurve, PairingEngine, ProjectiveCurve};
7use ark_ff::PrimeField;
8use ark_ff::{Field, One, UniformRand, Zero};
9use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
10use rand_core::RngCore;
11use serde::{
12 de::{Error as DeserializeError, SeqAccess, Visitor},
13 ser::{Error as SerializationError, SerializeTuple},
14 Deserialize, Deserializer, Serialize, Serializer,
15};
16use std::{
17 fmt,
18 marker::PhantomData,
19 ops::{AddAssign, MulAssign, Neg, SubAssign},
20};
21
22use thiserror::Error;
23
24use super::{BLSError, CurveType};
25
26#[derive(Debug, Error)]
27pub enum BLS12Error {
28 #[error("{0}")]
29 SerializationError(#[from] ark_serialize::SerializationError),
30 #[error("{0}")]
31 BLSError(#[from] BLSError),
32}
33
34#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
35pub struct Scalar(
36 #[serde(deserialize_with = "deserialize_field")]
37 #[serde(serialize_with = "serialize_field")]
38 <bls12_381::Bls12_381 as PairingEngine>::Fr,
39);
40
41type ZG1 = <bls12_381::Bls12_381 as PairingEngine>::G1Projective;
42
43#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
44pub struct G1(
45 #[serde(deserialize_with = "deserialize_group")]
46 #[serde(serialize_with = "serialize_group")]
47 ZG1,
48);
49
50type ZG2 = <bls12_381::Bls12_381 as PairingEngine>::G2Projective;
51
52#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
53pub struct G2(
54 #[serde(deserialize_with = "deserialize_group")]
55 #[serde(serialize_with = "serialize_group")]
56 ZG2,
57);
58
59#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
60pub struct GT(
61 #[serde(deserialize_with = "deserialize_field")]
62 #[serde(serialize_with = "serialize_field")]
63 <bls12_381::Bls12_381 as PairingEngine>::Fqk,
64);
65
66impl Element for Scalar {
67 type RHS = Scalar;
68
69 fn new() -> Self {
70 Self(Zero::zero())
71 }
72
73 fn one() -> Self {
74 Self(One::one())
75 }
76
77 fn add(&mut self, s2: &Self) {
78 self.0.add_assign(s2.0);
79 }
80
81 fn mul(&mut self, mul: &Scalar) {
82 self.0.mul_assign(mul.0)
83 }
84
85 fn rand<R: rand_core::RngCore>(rng: &mut R) -> Self {
86 Self(bls12_381::Fr::rand(rng))
87 }
88}
89
90impl Sc for Scalar {
91 fn set_int(&mut self, i: u64) {
92 *self = Self(bls12_381::Fr::from(i))
93 }
94
95 fn inverse(&self) -> Option<Self> {
96 Some(Self(Field::inverse(&self.0)?))
97 }
98
99 fn negate(&mut self) {
100 *self = Self(self.0.neg())
101 }
102
103 fn sub(&mut self, other: &Self) {
104 self.0.sub_assign(other.0);
105 }
106}
107
108impl fmt::Display for Scalar {
109 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110 write!(f, "{{{:?}}}", self.0)
111 }
112}
113
114impl Element for G1 {
116 type RHS = Scalar;
117
118 fn new() -> Self {
119 Self(Zero::zero())
120 }
121
122 fn one() -> Self {
123 Self(ZG1::prime_subgroup_generator())
124 }
125
126 fn rand<R: RngCore>(rng: &mut R) -> Self {
127 Self(ZG1::rand(rng))
128 }
129
130 fn add(&mut self, s2: &Self) {
131 self.0.add_assign(s2.0);
132 }
133
134 fn mul(&mut self, mul: &Scalar) {
135 self.0.mul_assign(mul.0);
136 }
137}
138
139impl Point for G1 {
141 type Error = BLS12Error;
142
143 fn map(&mut self, data: &[u8]) -> Result<(), BLS12Error> {
144 let hasher = TryAndIncrement::new(&Keccak256Hasher);
145
146 let hash = hasher.hash(&[], data)?;
147
148 *self = Self(hash);
149
150 Ok(())
151 }
152}
153
154impl fmt::Display for G1 {
155 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156 write!(f, "{{{:?}}}", self.0)
157 }
158}
159
160impl Element for G2 {
162 type RHS = Scalar;
163
164 fn new() -> Self {
165 Self(Zero::zero())
166 }
167
168 fn one() -> Self {
169 Self(ZG2::prime_subgroup_generator())
170 }
171
172 fn rand<R: RngCore>(mut rng: &mut R) -> Self {
173 Self(ZG2::rand(&mut rng))
174 }
175
176 fn add(&mut self, s2: &Self) {
177 self.0.add_assign(s2.0);
178 }
179
180 fn mul(&mut self, mul: &Scalar) {
181 self.0.mul_assign(mul.0)
182 }
183}
184
185impl Point for G2 {
187 type Error = BLS12Error;
188
189 fn map(&mut self, data: &[u8]) -> Result<(), BLS12Error> {
190 let hasher = TryAndIncrement::new(&Keccak256Hasher);
191
192 let hash = hasher.hash(&[], data)?;
193
194 *self = Self(hash);
195
196 Ok(())
197 }
198}
199
200impl fmt::Display for G2 {
201 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
202 write!(f, "{{{:?}}}", self.0)
203 }
204}
205
206impl Element for GT {
207 type RHS = Scalar;
208
209 fn new() -> Self {
210 Self(One::one())
211 }
212 fn one() -> Self {
213 Self(One::one())
214 }
215 fn add(&mut self, s2: &Self) {
216 self.0.mul_assign(s2.0);
217 }
218 fn mul(&mut self, mul: &Scalar) {
219 let scalar = mul.0.into_repr();
220 let mut res = Self::one();
221 let mut temp = *self;
222 for b in ark_ff::BitIteratorLE::without_trailing_zeros(scalar) {
223 if b {
224 res.0.mul_assign(temp.0);
225 }
226 temp.0.square_in_place();
227 }
228 *self = res;
229 }
230 fn rand<R: RngCore>(rng: &mut R) -> Self {
231 Self(bls12_381::Fq12::rand(rng))
232 }
233}
234
235impl fmt::Display for GT {
236 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
237 write!(f, "{{{:?}}}", self.0)
238 }
239}
240
241pub type G1Curve = group::G1Curve<PairingCurve>;
242pub type G2Curve = group::G2Curve<PairingCurve>;
243
244#[derive(Clone, Debug)]
245pub struct PairingCurve;
246
247impl PC for PairingCurve {
248 type Scalar = Scalar;
249 type G1 = G1;
250 type G2 = G2;
251 type GT = GT;
252
253 fn pair(a: &Self::G1, b: &Self::G2) -> Self::GT {
254 GT(<bls12_381::Bls12_381 as PairingEngine>::pairing(a.0, b.0))
255 }
256}
257
258fn deserialize_field<'de, D, C>(deserializer: D) -> Result<C, D::Error>
261where
262 D: Deserializer<'de>,
263 C: Field,
264{
265 struct FieldVisitor<C>(PhantomData<C>);
266
267 impl<'de, C> Visitor<'de> for FieldVisitor<C>
268 where
269 C: Field,
270 {
271 type Value = C;
272
273 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
274 formatter.write_str("a valid group element")
275 }
276
277 fn visit_seq<S>(self, mut seq: S) -> Result<C, S::Error>
278 where
279 S: SeqAccess<'de>,
280 {
281 let len = C::zero().serialized_size();
282 let bytes: Vec<u8> = (0..len)
283 .map(|_| {
284 seq.next_element()?
285 .ok_or_else(|| DeserializeError::custom("could not read bytes"))
286 })
287 .collect::<Result<Vec<_>, _>>()?;
288
289 let res = C::deserialize(&mut &bytes[..]).map_err(DeserializeError::custom)?;
290 Ok(res)
291 }
292 }
293
294 let visitor = FieldVisitor(PhantomData);
295 deserializer.deserialize_tuple(C::zero().serialized_size(), visitor)
296}
297
298fn serialize_field<S, C>(c: &C, s: S) -> Result<S::Ok, S::Error>
299where
300 S: Serializer,
301 C: Field,
302{
303 let len = c.serialized_size();
304 let mut bytes = Vec::with_capacity(len);
305 c.serialize(&mut bytes)
306 .map_err(SerializationError::custom)?;
307
308 let mut tup = s.serialize_tuple(len)?;
309 for byte in &bytes {
310 tup.serialize_element(byte)?;
311 }
312 tup.end()
313}
314
315fn deserialize_group<'de, D, C>(deserializer: D) -> Result<C, D::Error>
316where
317 D: Deserializer<'de>,
318 C: ProjectiveCurve,
319 C::Affine: CanonicalDeserialize + CanonicalSerialize,
320{
321 struct GroupVisitor<C>(PhantomData<C>);
322
323 impl<'de, C> Visitor<'de> for GroupVisitor<C>
324 where
325 C: ProjectiveCurve,
326 {
328 type Value = C;
329
330 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
331 formatter.write_str("a valid group element")
332 }
333
334 fn visit_seq<S>(self, mut seq: S) -> Result<C, S::Error>
335 where
336 S: SeqAccess<'de>,
337 {
338 let len = C::Affine::zero().serialized_size(); let bytes: Vec<u8> = (0..len)
340 .map(|_| {
341 seq.next_element()?
342 .ok_or_else(|| DeserializeError::custom("could not read bytes"))
343 })
344 .collect::<Result<Vec<_>, _>>()?;
345
346 let affine =
347 C::Affine::deserialize(&mut &bytes[..]).map_err(DeserializeError::custom)?;
348 Ok(affine.into_projective())
349 }
350 }
351
352 let visitor = GroupVisitor(PhantomData);
353 deserializer.deserialize_tuple(C::Affine::zero().serialized_size(), visitor)
354}
355
356fn serialize_group<S, C>(c: &C, s: S) -> Result<S::Ok, S::Error>
357where
358 S: Serializer,
359 C: ProjectiveCurve,
360 C::Affine: CanonicalSerialize,
361{
362 let affine = c.into_affine();
363 let len = affine.serialized_size();
364 let mut bytes = Vec::with_capacity(len);
365 affine
366 .serialize(&mut bytes)
367 .map_err(SerializationError::custom)?;
368
369 let mut tup = s.serialize_tuple(len)?;
370 for byte in &bytes {
371 tup.serialize_element(byte)?;
372 }
373 tup.end()
374}
375
376#[derive(Clone, Debug)]
377pub struct BLS12381Curve;
378
379impl CurveType for BLS12381Curve {
380 type G1Curve = G1Curve;
381
382 type G2Curve = G2Curve;
383
384 type PairingCurve = PairingCurve;
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use static_assertions::assert_impl_all;

    assert_impl_all!(G1: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(G2: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(GT: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(Scalar: Serialize, DeserializeOwned, Clone);

    /// Serializes a random element with bincode, checks the encoded length
    /// against `size`, and checks that decoding gives back the same element.
    ///
    /// Group and field elements share this single helper. Previously there
    /// were two byte-identical copies.
    fn roundtrip_test<E: Element>(size: usize) {
        let rng = &mut rand::thread_rng();
        let elem = E::rand(rng);
        let ser = bincode::serialize(&elem).unwrap();
        assert_eq!(ser.len(), size);

        let de: E = bincode::deserialize(&ser).unwrap();
        assert_eq!(de, elem);
    }

    #[test]
    fn serialize_group() {
        roundtrip_test::<G1>(48);
        roundtrip_test::<G2>(96);
    }

    #[test]
    fn serialize_field() {
        roundtrip_test::<GT>(576);
        roundtrip_test::<Scalar>(32);
    }

    #[test]
    fn gt_exp() {
        let rng = &mut rand::thread_rng();
        let base = GT::rand(rng);

        // sc = 1 + 1 + 1 = 3
        let mut sc = Scalar::one();
        sc.add(&Scalar::one());
        sc.add(&Scalar::one());

        // base^3 computed by exponentiation...
        // (GT is `Copy`, so plain copies replace the former `.clone()` calls.)
        let mut exp = base;
        exp.mul(&sc);

        // ...must equal base * base * base computed with the group operation.
        let mut res = base;
        res.add(&base);
        res.add(&base);

        assert_eq!(exp, res);
    }

    #[test]
    fn gt_exp_zero_is_identity() {
        let rng = &mut rand::thread_rng();
        let mut value = GT::rand(rng);

        // Raising any element to the power 0 must give the identity.
        value.mul(&Scalar::new());

        assert_eq!(value, GT::one());
    }
}