1use crate::group::{self, Element, PairingCurve as PC, Point, Scalar as Sc};
2use crate::hash::hasher::Keccak256Hasher;
3use crate::hash::try_and_increment::TryAndIncrement;
4use crate::hash::HashToCurve;
5use crate::serialize::ContractSerialize;
6use ark_bn254 as bn254;
7use ark_ec::{PairingEngine, ProjectiveCurve};
8use ark_ff::PrimeField;
9use ark_ff::{Field, One, UniformRand, Zero};
10use rand_core::RngCore;
11use serde::{
12 de::{Error as DeserializeError, SeqAccess, Visitor},
13 ser::{Error as SerializationError, SerializeTuple},
14 Deserialize, Deserializer, Serialize, Serializer,
15};
16use std::{
17 fmt,
18 marker::PhantomData,
19 ops::{AddAssign, MulAssign, Neg, SubAssign},
20};
21
22use thiserror::Error;
23
24use super::{BLSError, CurveType};
25
26#[derive(Debug, Error)]
27pub enum BNError {
28 #[error("{0}")]
29 SerializationError(#[from] ark_serialize::SerializationError),
30 #[error("{0}")]
31 BLSError(#[from] BLSError),
32}
33
34#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
35pub struct Scalar(
36 #[serde(deserialize_with = "deserialize_field")]
37 #[serde(serialize_with = "serialize_field")]
38 <bn254::Bn254 as PairingEngine>::Fr,
39);
40
41type ZG1 = <bn254::Bn254 as PairingEngine>::G1Projective;
42
43#[derive(Debug, Clone, Copy, Eq, PartialEq)]
44pub struct G1(pub(crate) ZG1);
45
46type ZG2 = <bn254::Bn254 as PairingEngine>::G2Projective;
47
48#[derive(Debug, Clone, Copy, Eq, PartialEq)]
49pub struct G2(pub(crate) ZG2);
50
51impl Serialize for G1 {
52 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
53 where
54 S: Serializer,
55 {
56 let bytes = self
57 .serialize_to_contract_form()
58 .map_err(SerializationError::custom)?;
59
60 let mut tup = serializer.serialize_tuple(32)?;
61 for byte in &bytes {
62 tup.serialize_element(byte)?;
63 }
64 tup.end()
65 }
66}
67impl<'de> Deserialize<'de> for G1 {
68 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
69 where
70 D: Deserializer<'de>,
71 {
72 struct G1Visitor;
73
74 impl<'de> Visitor<'de> for G1Visitor {
75 type Value = G1;
76
77 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
78 formatter.write_str("a valid group element")
79 }
80
81 fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
82 where
83 S: SeqAccess<'de>,
84 {
85 let bytes: Vec<u8> = (0..32)
86 .map(|_| {
87 seq.next_element()?
88 .ok_or_else(|| DeserializeError::custom("could not read bytes"))
89 })
90 .collect::<Result<Vec<_>, _>>()?;
91
92 let ele =
93 G1::deserialize_from_contract_form(&bytes).map_err(DeserializeError::custom)?;
94 Ok(ele)
95 }
96 }
97
98 deserializer.deserialize_tuple(32, G1Visitor)
99 }
100}
101
102impl Serialize for G2 {
103 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104 where
105 S: Serializer,
106 {
107 let bytes = self
108 .serialize_to_contract_form()
109 .map_err(SerializationError::custom)?;
110
111 let mut tup = serializer.serialize_tuple(128)?;
112 for byte in &bytes {
113 tup.serialize_element(byte)?;
114 }
115 tup.end()
116 }
117}
118impl<'de> Deserialize<'de> for G2 {
119 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120 where
121 D: Deserializer<'de>,
122 {
123 struct G2Visitor;
124
125 impl<'de> Visitor<'de> for G2Visitor {
126 type Value = G2;
127
128 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
129 formatter.write_str("a valid group element")
130 }
131
132 fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
133 where
134 S: SeqAccess<'de>,
135 {
136 let bytes: Vec<u8> = (0..128)
137 .map(|_| {
138 seq.next_element()?
139 .ok_or_else(|| DeserializeError::custom("could not read bytes"))
140 })
141 .collect::<Result<Vec<_>, _>>()?;
142
143 let ele =
144 G2::deserialize_from_contract_form(&bytes).map_err(DeserializeError::custom)?;
145 Ok(ele)
146 }
147 }
148
149 deserializer.deserialize_tuple(128, G2Visitor)
150 }
151}
152
153#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
154pub struct GT(
155 #[serde(deserialize_with = "deserialize_field")]
156 #[serde(serialize_with = "serialize_field")]
157 <bn254::Bn254 as PairingEngine>::Fqk,
158);
159
160impl Element for Scalar {
161 type RHS = Scalar;
162
163 fn new() -> Self {
164 Self(Zero::zero())
165 }
166
167 fn one() -> Self {
168 Self(One::one())
169 }
170
171 fn add(&mut self, s2: &Self) {
172 self.0.add_assign(s2.0);
173 }
174
175 fn mul(&mut self, mul: &Scalar) {
176 self.0.mul_assign(mul.0)
177 }
178
179 fn rand<R: rand_core::RngCore>(rng: &mut R) -> Self {
180 Self(bn254::Fr::rand(rng))
181 }
182}
183
184impl Sc for Scalar {
185 fn set_int(&mut self, i: u64) {
186 *self = Self(bn254::Fr::from(i))
187 }
188
189 fn inverse(&self) -> Option<Self> {
190 Some(Self(Field::inverse(&self.0)?))
191 }
192
193 fn negate(&mut self) {
194 *self = Self(self.0.neg())
195 }
196
197 fn sub(&mut self, other: &Self) {
198 self.0.sub_assign(other.0);
199 }
200}
201
202impl fmt::Display for Scalar {
203 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204 write!(f, "{{{:?}}}", self.0)
205 }
206}
207
208impl Element for G1 {
210 type RHS = Scalar;
211
212 fn new() -> Self {
213 Self(Zero::zero())
214 }
215
216 fn one() -> Self {
217 Self(ZG1::prime_subgroup_generator())
218 }
219
220 fn rand<R: RngCore>(rng: &mut R) -> Self {
221 Self(ZG1::rand(rng))
222 }
223
224 fn add(&mut self, s2: &Self) {
225 self.0.add_assign(s2.0);
226 }
227
228 fn mul(&mut self, mul: &Scalar) {
229 self.0.mul_assign(mul.0);
230 }
231}
232
233impl Point for G1 {
235 type Error = BNError;
236
237 fn map(&mut self, data: &[u8]) -> Result<(), BNError> {
238 let hasher = TryAndIncrement::new(&Keccak256Hasher);
239
240 let hash = hasher.hash(&[], data)?;
241
242 *self = Self(hash);
243
244 Ok(())
245 }
246}
247
248impl fmt::Display for G1 {
249 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
250 write!(f, "{{{:?}}}", self.0)
251 }
252}
253
254impl Element for G2 {
256 type RHS = Scalar;
257
258 fn new() -> Self {
259 Self(Zero::zero())
260 }
261
262 fn one() -> Self {
263 Self(ZG2::prime_subgroup_generator())
264 }
265
266 fn rand<R: RngCore>(mut rng: &mut R) -> Self {
267 Self(ZG2::rand(&mut rng))
268 }
269
270 fn add(&mut self, s2: &Self) {
271 self.0.add_assign(s2.0);
272 }
273
274 fn mul(&mut self, mul: &Scalar) {
275 self.0.mul_assign(mul.0)
276 }
277}
278
279impl Point for G2 {
281 type Error = BNError;
282
283 fn map(&mut self, data: &[u8]) -> Result<(), BNError> {
284 let hasher = TryAndIncrement::new(&Keccak256Hasher);
285
286 let hash = hasher.hash(&[], data)?;
287
288 *self = Self(hash);
289
290 Ok(())
291 }
292}
293
294impl fmt::Display for G2 {
295 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
296 write!(f, "{{{:?}}}", self.0)
297 }
298}
299
300impl Element for GT {
301 type RHS = Scalar;
302
303 fn new() -> Self {
304 Self(One::one())
305 }
306 fn one() -> Self {
307 Self(One::one())
308 }
309 fn add(&mut self, s2: &Self) {
310 self.0.mul_assign(s2.0);
311 }
312 fn mul(&mut self, mul: &Scalar) {
313 let scalar = mul.0.into_repr();
314 let mut res = Self::one();
315 let mut temp = *self;
316 for b in ark_ff::BitIteratorLE::without_trailing_zeros(scalar) {
317 if b {
318 res.0.mul_assign(temp.0);
319 }
320 temp.0.square_in_place();
321 }
322 *self = res;
323 }
324 fn rand<R: RngCore>(rng: &mut R) -> Self {
325 Self(bn254::Fq12::rand(rng))
326 }
327}
328
329impl fmt::Display for GT {
330 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
331 write!(f, "{{{:?}}}", self.0)
332 }
333}
334
335pub type G1Curve = group::G1Curve<PairingCurve>;
336pub type G2Curve = group::G2Curve<PairingCurve>;
337
338#[derive(Clone, Debug, Serialize)]
339pub struct PairingCurve;
340
341impl PC for PairingCurve {
342 type Scalar = Scalar;
343 type G1 = G1;
344 type G2 = G2;
345 type GT = GT;
346
347 fn pair(a: &Self::G1, b: &Self::G2) -> Self::GT {
348 GT(<bn254::Bn254 as PairingEngine>::pairing(a.0, b.0))
349 }
350}
351
352fn deserialize_field<'de, D, C>(deserializer: D) -> Result<C, D::Error>
355where
356 D: Deserializer<'de>,
357 C: Field,
358{
359 struct FieldVisitor<C>(PhantomData<C>);
360
361 impl<'de, C> Visitor<'de> for FieldVisitor<C>
362 where
363 C: Field,
364 {
365 type Value = C;
366
367 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
368 formatter.write_str("a valid group element")
369 }
370
371 fn visit_seq<S>(self, mut seq: S) -> Result<C, S::Error>
372 where
373 S: SeqAccess<'de>,
374 {
375 let len = C::zero().serialized_size();
376 let bytes: Vec<u8> = (0..len)
377 .map(|_| {
378 seq.next_element()?
379 .ok_or_else(|| DeserializeError::custom("could not read bytes"))
380 })
381 .collect::<Result<Vec<_>, _>>()?;
382
383 let res = C::deserialize(&mut &bytes[..]).map_err(DeserializeError::custom)?;
384 Ok(res)
385 }
386 }
387
388 let visitor = FieldVisitor(PhantomData);
389 deserializer.deserialize_tuple(C::zero().serialized_size(), visitor)
390}
391
392fn serialize_field<S, C>(c: &C, s: S) -> Result<S::Ok, S::Error>
393where
394 S: Serializer,
395 C: Field,
396{
397 let len = c.serialized_size();
398 let mut bytes = Vec::with_capacity(len);
399 c.serialize(&mut bytes)
400 .map_err(SerializationError::custom)?;
401
402 let mut tup = s.serialize_tuple(len)?;
403 for byte in &bytes {
404 tup.serialize_element(byte)?;
405 }
406 tup.end()
407}
408
409#[derive(Clone, Debug)]
410pub struct BN254Curve;
411
412impl CurveType for BN254Curve {
413 type G1Curve = G1Curve;
414
415 type G2Curve = G2Curve;
416
417 type PairingCurve = PairingCurve;
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use static_assertions::assert_impl_all;

    // Compile-time checks that every BN254 element type can be serialized,
    // deserialized into an owned value, and cloned.
    assert_impl_all!(G1: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(G2: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(GT: Serialize, DeserializeOwned, Clone);
    assert_impl_all!(Scalar: Serialize, DeserializeOwned, Clone);

    /// Serializes a random `E` with bincode, asserts the encoding is `size` bytes,
    /// and asserts that decoding yields the original element.
    fn assert_bincode_round_trip<E: Element>(size: usize) {
        let mut rng = rand::thread_rng();
        let original = E::rand(&mut rng);

        let encoded = bincode::serialize(&original).unwrap();
        assert_eq!(encoded.len(), size);

        let decoded: E = bincode::deserialize(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn serialize_group() {
        for _round in 0..10 {
            serialize_group_test::<G1>(32);
            serialize_group_test::<G2>(128);
        }
    }

    fn serialize_group_test<E: Element>(size: usize) {
        // An empty buffer must never decode to a group element.
        assert!(bincode::deserialize::<E>(&[]).is_err());
        assert_bincode_round_trip::<E>(size);
    }

    #[test]
    fn serialize_field() {
        serialize_field_test::<GT>(384);
        serialize_field_test::<Scalar>(32);
    }

    fn serialize_field_test<E: Element>(size: usize) {
        assert_bincode_round_trip::<E>(size);
    }

    #[test]
    fn gt_exp() {
        let rng = &mut rand::thread_rng();
        let base = GT::rand(rng);

        // three = 1 + 1 + 1
        let mut three = Scalar::one();
        for _ in 0..2 {
            three.add(&Scalar::one());
        }

        // base^3 via exponentiation ...
        let mut exp = base;
        exp.mul(&three);

        // ... must equal base * base * base via the group operation.
        let mut res = base;
        res.add(&base);
        res.add(&base);

        assert_eq!(exp, res);
    }
}