// Source file: zksync_ff_derive/lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5extern crate syn;
6#[macro_use]
7extern crate quote;
8
9extern crate num_bigint;
10extern crate num_integer;
11extern crate num_traits;
12
13use num_bigint::BigUint;
14use num_integer::Integer;
15use num_traits::{One, ToPrimitive, Zero};
16use quote::TokenStreamExt;
17use std::str::FromStr;
18
19mod utils;
20use utils::*;
21
22#[cfg(feature = "asm")]
23mod asm;
24
/// Derives `PrimeField` using the assembly-backed arithmetic backend.
///
/// This entry point only registers the derive (and its helper attributes) when the
/// `asm` feature is enabled; the expansion itself is produced by
/// `asm::prime_field_asm_impl`.
#[cfg(feature = "asm")]
#[proc_macro_derive(PrimeFieldAsm, attributes(PrimeFieldModulus, PrimeFieldGenerator, UseADX))]
pub fn prime_field_asm(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    asm::prime_field_asm_impl(input)
}
30
31#[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus, PrimeFieldGenerator, OptimisticCIOSMultiplication, OptimisticCIOSSquaring))]
32pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33    // Parse the type definition
34    let ast: syn::DeriveInput = syn::parse(input).unwrap();
35
36    // The struct we're deriving for is a wrapper around a "Repr" type we must construct.
37    let repr_ident = fetch_wrapped_ident(&ast.data).expect("PrimeField derive only operates over tuple structs of a single item");
38
39    // We're given the modulus p of the prime field
40    let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs)
41        .expect("Please supply a PrimeFieldModulus attribute")
42        .parse()
43        .expect("PrimeFieldModulus should be a number");
44
45    // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
46    // nonresidue.
47    let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
48        .expect("Please supply a PrimeFieldGenerator attribute")
49        .parse()
50        .expect("PrimeFieldGenerator should be a number");
51
52    // User may opt-in for feature to generate CIOS based multiplication operation
53    let opt_in_cios_mul: Option<bool> = fetch_attr("OptimisticCIOSMultiplication", &ast.attrs).map(|el| el.parse().expect("OptimisticCIOSMultiplication should be `true` or `false`"));
54
55    // User may opt-in for feature to generate CIOS based squaring operation
56    let opt_in_cios_square: Option<bool> = fetch_attr("OptimisticCIOSSquaring", &ast.attrs).map(|el| el.parse().expect("OptimisticCIOSSquaring should be `true` or `false`"));
57
58    // The arithmetic in this library only works if the modulus*2 is smaller than the backing
59    // representation. Compute the number of limbs we need.
60    let mut limbs = 1;
61    {
62        let mod2 = (&modulus) << 1; // modulus * 2
63        let mut cur = BigUint::one() << 64; // always 64-bit limbs for now
64        while cur < mod2 {
65            limbs += 1;
66            cur = cur << 64;
67        }
68    }
69
70    let modulus_limbs = biguint_to_real_u64_vec(modulus.clone(), limbs);
71    let top_limb = modulus_limbs.last().unwrap().clone().to_u64().unwrap();
72    let can_use_optimistic_cios_mul = {
73        let mut can_use = if let Some(cios) = opt_in_cios_mul { cios } else { false };
74        if top_limb == 0 {
75            can_use = false;
76        }
77
78        if top_limb > (std::u64::MAX / 2) - 1 {
79            can_use = false;
80        }
81        can_use
82    };
83
84    let can_use_optimistic_cios_sqr = {
85        let mut can_use = if let Some(cios) = opt_in_cios_square { cios } else { false };
86        if top_limb == 0 {
87            can_use = false;
88        }
89
90        if top_limb > (std::u64::MAX / 4) - 1 {
91            assert!(!can_use, "can not use optimistic CIOS for this modulus");
92            can_use = false;
93        }
94        can_use
95    };
96
97    let mut gen = proc_macro2::TokenStream::new();
98
99    let (constants_impl, sqrt_impl) = prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs, generator);
100
101    gen.extend(constants_impl);
102    gen.extend(prime_field_repr_impl(&repr_ident, limbs));
103    gen.extend(prime_field_impl(&ast.ident, &repr_ident, can_use_optimistic_cios_mul, can_use_optimistic_cios_sqr, limbs));
104    gen.extend(sqrt_impl);
105
106    // Return the generated impl
107    gen.into()
108}
109
110/// Fetches the ident being wrapped by the type we're deriving.
111fn fetch_wrapped_ident(body: &syn::Data) -> Option<syn::Ident> {
112    match body {
113        &syn::Data::Struct(ref variant_data) => match variant_data.fields {
114            syn::Fields::Unnamed(ref fields) => {
115                if fields.unnamed.len() == 1 {
116                    match fields.unnamed[0].ty {
117                        syn::Type::Path(ref path) => {
118                            if path.path.segments.len() == 1 {
119                                return Some(path.path.segments[0].ident.clone());
120                            }
121                        }
122                        _ => {}
123                    }
124                }
125            }
126            _ => {}
127        },
128        _ => {}
129    };
130
131    None
132}
133
134/// Fetch an attribute string from the derived struct.
135fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
136    for attr in attrs {
137        if let Ok(meta) = attr.parse_meta() {
138            match meta {
139                syn::Meta::NameValue(nv) => {
140                    if nv.path.is_ident(name) {
141                        match nv.lit {
142                            syn::Lit::Str(ref s) => return Some(s.value()),
143                            _ => {
144                                panic!("attribute {} should be a string", name);
145                            }
146                        }
147                    }
148                }
149                _ => {
150                    panic!("attribute {} should be a string", name);
151                }
152            }
153        }
154    }
155
156    None
157}
158
159// Implement PrimeFieldRepr for the wrapped ident `repr` with `limbs` limbs.
160fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
161    quote! {
162
163        #[derive(Copy, Clone, PartialEq, Eq, Default, ::serde::Serialize, ::serde::Deserialize)]
164        pub struct #repr(
165            pub [u64; #limbs]
166        );
167
168        impl ::std::fmt::Debug for #repr
169        {
170            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
171                write!(f, "0x")?;
172                for i in self.0.iter().rev() {
173                    write!(f, "{:016x}", *i)?;
174                }
175
176                Ok(())
177            }
178        }
179
180        impl crate::ff::Rand for #repr {
181            #[inline(always)]
182            fn rand<R: crate::ff::rand::Rng + ?Sized>(rng: &mut R) -> Self {
183                #repr(rng.gen())
184            }
185        }
186
187        impl ::std::fmt::Display for #repr {
188            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
189                write!(f, "0x")?;
190                for i in self.0.iter().rev() {
191                    write!(f, "{:016x}", *i)?;
192                }
193
194                Ok(())
195            }
196        }
197
198        impl std::hash::Hash for #repr {
199            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
200                for limb in self.0.iter() {
201                    limb.hash(state);
202                }
203            }
204        }
205
206        impl AsRef<[u64]> for #repr {
207            #[inline(always)]
208            fn as_ref(&self) -> &[u64] {
209                &self.0
210            }
211        }
212
213        impl AsMut<[u64]> for #repr {
214            #[inline(always)]
215            fn as_mut(&mut self) -> &mut [u64] {
216                &mut self.0
217            }
218        }
219
220        impl From<u64> for #repr {
221            #[inline(always)]
222            fn from(val: u64) -> #repr {
223                use std::default::Default;
224
225                let mut repr = Self::default();
226                repr.0[0] = val;
227                repr
228            }
229        }
230
231        impl Ord for #repr {
232            #[inline(always)]
233            fn cmp(&self, other: &#repr) -> ::std::cmp::Ordering {
234                for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) {
235                    if a < b {
236                        return ::std::cmp::Ordering::Less
237                    } else if a > b {
238                        return ::std::cmp::Ordering::Greater
239                    }
240                }
241
242                ::std::cmp::Ordering::Equal
243            }
244        }
245
246        impl PartialOrd for #repr {
247            #[inline(always)]
248            fn partial_cmp(&self, other: &#repr) -> Option<::std::cmp::Ordering> {
249                Some(self.cmp(other))
250            }
251        }
252
253        impl crate::ff::PrimeFieldRepr for #repr {
254            #[inline(always)]
255            fn is_odd(&self) -> bool {
256                self.0[0] & 1 == 1
257            }
258
259            #[inline(always)]
260            fn is_even(&self) -> bool {
261                !self.is_odd()
262            }
263
264            #[inline(always)]
265            fn is_zero(&self) -> bool {
266                self.0.iter().all(|&e| e == 0)
267            }
268
269            #[inline(always)]
270            fn shr(&mut self, mut n: u32) {
271                if n as usize >= 64 * #limbs {
272                    *self = Self::from(0);
273                    return;
274                }
275
276                while n >= 64 {
277                    let mut t = 0;
278                    for i in self.0.iter_mut().rev() {
279                        ::std::mem::swap(&mut t, i);
280                    }
281                    n -= 64;
282                }
283
284                if n > 0 {
285                    let mut t = 0;
286                    for i in self.0.iter_mut().rev() {
287                        let t2 = *i << (64 - n);
288                        *i >>= n;
289                        *i |= t;
290                        t = t2;
291                    }
292                }
293            }
294
295            #[inline(always)]
296            fn div2(&mut self) {
297                let mut t = 0;
298                for i in self.0.iter_mut().rev() {
299                    let t2 = *i << 63;
300                    *i >>= 1;
301                    *i |= t;
302                    t = t2;
303                }
304            }
305
306            #[inline(always)]
307            fn mul2(&mut self) {
308                let mut last = 0;
309                for i in &mut self.0 {
310                    let tmp = *i >> 63;
311                    *i <<= 1;
312                    *i |= last;
313                    last = tmp;
314                }
315            }
316
317            #[inline(always)]
318            fn shl(&mut self, mut n: u32) {
319                if n as usize >= 64 * #limbs {
320                    *self = Self::from(0);
321                    return;
322                }
323
324                while n >= 64 {
325                    let mut t = 0;
326                    for i in &mut self.0 {
327                        ::std::mem::swap(&mut t, i);
328                    }
329                    n -= 64;
330                }
331
332                if n > 0 {
333                    let mut t = 0;
334                    for i in &mut self.0 {
335                        let t2 = *i >> (64 - n);
336                        *i <<= n;
337                        *i |= t;
338                        t = t2;
339                    }
340                }
341            }
342
343            #[inline(always)]
344            fn num_bits(&self) -> u32 {
345                let mut ret = (#limbs as u32) * 64;
346                for i in self.0.iter().rev() {
347                    let leading = i.leading_zeros();
348                    ret -= leading;
349                    if leading != 64 {
350                        break;
351                    }
352                }
353
354                ret
355            }
356
357            #[inline(always)]
358            fn add_nocarry(&mut self, other: &#repr) {
359                let mut carry = 0;
360
361                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
362                    *a = crate::ff::adc(*a, *b, &mut carry);
363                }
364            }
365
366            #[inline(always)]
367            fn sub_noborrow(&mut self, other: &#repr) {
368                let mut borrow = 0;
369
370                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
371                    *a = crate::ff::sbb(*a, *b, &mut borrow);
372                }
373            }
374        }
375    }
376}
377
378fn prime_field_constants_and_sqrt(name: &syn::Ident, repr: &syn::Ident, modulus: BigUint, limbs: usize, generator: BigUint) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
379    let modulus_num_bits = biguint_num_bits(modulus.clone());
380
381    // The number of bits we should "shave" from a randomly sampled reputation, i.e.,
382    // if our modulus is 381 bits and our representation is 384 bits, we should shave
383    // 3 bits from the beginning of a randomly sampled 384 bit representation to
384    // reduce the cost of rejection sampling.
385    let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
386    let repr_shave_mask = if repr_shave_bits == 64 { 0u64 } else { 0xffffffffffffffffu64 >> repr_shave_bits };
387
388    // Compute R = 2**(64 * limbs) mod m
389    let r = (BigUint::one() << (limbs * 64)) % &modulus;
390
391    // modulus - 1 = 2^s * t
392    let mut s: u32 = 0;
393    let mut t = &modulus - BigUint::from_str("1").unwrap();
394    while t.is_even() {
395        t = t >> 1;
396        s += 1;
397    }
398
399    // Compute 2^s root of unity given the generator
400    let root_of_unity = biguint_to_u64_vec((generator.clone().modpow(&t, &modulus) * &r) % &modulus, limbs);
401    let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs);
402
403    let mod_minus_1_over_2 = biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs);
404    let legendre_impl = quote! {
405        fn legendre(&self) -> crate::ff::LegendreSymbol {
406            // s = self^((modulus - 1) // 2)
407            let s = self.pow(#mod_minus_1_over_2);
408            if s == Self::zero() {
409                crate::ff::LegendreSymbol::Zero
410            } else if s == Self::one() {
411                crate::ff::LegendreSymbol::QuadraticResidue
412            } else {
413                crate::ff::LegendreSymbol::QuadraticNonResidue
414            }
415        }
416    };
417
418    let sqrt_impl = if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
419        let mod_minus_3_over_4 = biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs);
420
421        // Compute -R as (m - r)
422        let rneg = biguint_to_u64_vec(&modulus - &r, limbs);
423
424        quote! {
425            impl crate::ff::SqrtField for #name {
426                #legendre_impl
427
428                fn sqrt(&self) -> Option<Self> {
429                    // Shank's algorithm for q mod 4 = 3
430                    // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2)
431
432                    let mut a1 = self.pow(#mod_minus_3_over_4);
433
434                    let mut a0 = a1;
435                    a0.square();
436                    a0.mul_assign(self);
437
438                    if a0.0 == #repr(#rneg) {
439                        None
440                    } else {
441                        a1.mul_assign(self);
442                        Some(a1)
443                    }
444                }
445            }
446        }
447    } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
448        let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs);
449        let t = biguint_to_u64_vec(t.clone(), limbs);
450
451        quote! {
452            impl crate::ff::SqrtField for #name {
453                #legendre_impl
454
455                fn sqrt(&self) -> Option<Self> {
456                    // Tonelli-Shank's algorithm for q mod 16 = 1
457                    // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
458
459                    match self.legendre() {
460                        crate::ff::LegendreSymbol::Zero => Some(*self),
461                        crate::ff::LegendreSymbol::QuadraticNonResidue => None,
462                        crate::ff::LegendreSymbol::QuadraticResidue => {
463                            let mut c = #name(ROOT_OF_UNITY);
464                            let mut r = self.pow(#t_plus_1_over_2);
465                            let mut t = self.pow(#t);
466                            let mut m = S;
467
468                            while t != Self::one() {
469                                let mut i = 1;
470                                {
471                                    let mut t2i = t;
472                                    t2i.square();
473                                    loop {
474                                        if t2i == Self::one() {
475                                            break;
476                                        }
477                                        t2i.square();
478                                        i += 1;
479                                    }
480                                }
481
482                                for _ in 0..(m - i - 1) {
483                                    c.square();
484                                }
485                                r.mul_assign(&c);
486                                c.square();
487                                t.mul_assign(&c);
488                                m = i;
489                            }
490
491                            Some(r)
492                        }
493                    }
494                }
495            }
496        }
497    } else {
498        quote! {}
499    };
500
501    // Compute R^2 mod m
502    let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs);
503
504    let r = biguint_to_u64_vec(r, limbs);
505    let modulus = biguint_to_real_u64_vec(modulus, limbs);
506
507    // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1
508    let mut inv = 1u64;
509    for _ in 0..63 {
510        inv = inv.wrapping_mul(inv);
511        inv = inv.wrapping_mul(modulus[0]);
512    }
513    inv = inv.wrapping_neg();
514
515    (
516        quote! {
517            /// This is the modulus m of the prime field
518            const MODULUS: #repr = #repr([#(#modulus,)*]);
519
520            /// The number of bits needed to represent the modulus.
521            const MODULUS_BITS: u32 = #modulus_num_bits;
522
523            /// The number of bits that must be shaved from the beginning of
524            /// the representation when randomly sampling.
525            const REPR_SHAVE_BITS: u32 = #repr_shave_bits;
526
527            /// Precalculated mask to shave bits from the top limb in random sampling
528            const TOP_LIMB_SHAVE_MASK: u64 = #repr_shave_mask;
529
530            /// 2^{limbs*64} mod m
531            const R: #repr = #repr(#r);
532
533            /// 2^{limbs*64*2} mod m
534            const R2: #repr = #repr(#r2);
535
536            /// -(m^{-1} mod m) mod m
537            const INV: u64 = #inv;
538
539            /// Multiplicative generator of `MODULUS` - 1 order, also quadratic
540            /// nonresidue.
541            const GENERATOR: #repr = #repr(#generator);
542
543            /// 2^s * t = MODULUS - 1 with t odd
544            const S: u32 = #s;
545
546            /// 2^s root of unity computed by GENERATOR^t
547            const ROOT_OF_UNITY: #repr = #repr(#root_of_unity);
548        },
549        sqrt_impl,
550    )
551}
552
553// Returns r{n} as an ident.
554fn get_temp(n: usize) -> syn::Ident {
555    syn::Ident::new(&format!("r{}", n), proc_macro2::Span::call_site())
556}
557
558fn get_temp_with_literal(literal: &str, n: usize) -> syn::Ident {
559    syn::Ident::new(&format!("{}{}", literal, n), proc_macro2::Span::call_site())
560}
561
562/// Implement PrimeField for the derived type.
563fn prime_field_impl(name: &syn::Ident, repr: &syn::Ident, can_use_cios_mul: bool, can_use_cios_sqr: bool, limbs: usize) -> proc_macro2::TokenStream {
564    // The parameter list for the mont_reduce() internal method.
565    // r0: u64, mut r1: u64, mut r2: u64, ...
566    let mut mont_paramlist = proc_macro2::TokenStream::new();
567    mont_paramlist.append_separated(
568        (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
569            if i != 0 {
570                quote! {mut #x: u64}
571            } else {
572                quote! {#x: u64}
573            }
574        }),
575        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
576    );
577
578    // Implement montgomery reduction for some number of limbs
579    fn mont_impl(limbs: usize) -> proc_macro2::TokenStream {
580        let mut gen = proc_macro2::TokenStream::new();
581
582        for i in 0..limbs {
583            {
584                let temp = get_temp(i);
585                gen.extend(quote! {
586                    let k = #temp.wrapping_mul(INV);
587                    let mut carry = 0;
588                    crate::ff::mac_with_carry(#temp, k, MODULUS.0[0], &mut carry);
589                });
590            }
591
592            for j in 1..limbs {
593                let temp = get_temp(i + j);
594                gen.extend(quote! {
595                    #temp = crate::ff::mac_with_carry(#temp, k, MODULUS.0[#j], &mut carry);
596                });
597            }
598
599            let temp = get_temp(i + limbs);
600
601            if i == 0 {
602                gen.extend(quote! {
603                    #temp = crate::ff::adc(#temp, 0, &mut carry);
604                });
605            } else {
606                gen.extend(quote! {
607                    #temp = crate::ff::adc(#temp, carry2, &mut carry);
608                });
609            }
610
611            if i != (limbs - 1) {
612                gen.extend(quote! {
613                    let carry2 = carry;
614                });
615            }
616        }
617
618        for i in 0..limbs {
619            let temp = get_temp(limbs + i);
620
621            gen.extend(quote! {
622                (self.0).0[#i] = #temp;
623            });
624        }
625
626        gen
627    }
628
629    fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
630        let mut gen = proc_macro2::TokenStream::new();
631
632        for i in 0..(limbs - 1) {
633            gen.extend(quote! {
634                let mut carry = 0;
635            });
636
637            for j in (i + 1)..limbs {
638                let temp = get_temp(i + j);
639                if i == 0 {
640                    gen.extend(quote! {
641                        let #temp = crate::ff::mac_with_carry(0, (#a.0).0[#i], (#a.0).0[#j], &mut carry);
642                    });
643                } else {
644                    gen.extend(quote! {
645                        let #temp = crate::ff::mac_with_carry(#temp, (#a.0).0[#i], (#a.0).0[#j], &mut carry);
646                    });
647                }
648            }
649
650            let temp = get_temp(i + limbs);
651
652            gen.extend(quote! {
653                let #temp = carry;
654            });
655        }
656
657        if limbs != 1 {
658            for i in 1..(limbs * 2) {
659                let temp0 = get_temp(limbs * 2 - i);
660                let temp1 = get_temp(limbs * 2 - i - 1);
661
662                if i == 1 {
663                    gen.extend(quote! {
664                        let #temp0 = #temp1 >> 63;
665                    });
666                } else if i == (limbs * 2 - 1) {
667                    gen.extend(quote! {
668                        let #temp0 = #temp0 << 1;
669                    });
670                } else {
671                    gen.extend(quote! {
672                        let #temp0 = (#temp0 << 1) | (#temp1 >> 63);
673                    });
674                }
675            }
676        } else {
677            gen.extend(quote! {
678                let r1 = 0;
679            });
680        }
681
682        gen.extend(quote! {
683            let mut carry = 0;
684        });
685
686        for i in 0..limbs {
687            let temp0 = get_temp(i * 2);
688            let temp1 = get_temp(i * 2 + 1);
689            if i == 0 {
690                gen.extend(quote! {
691                    let #temp0 = crate::ff::mac_with_carry(0, (#a.0).0[#i], (#a.0).0[#i], &mut carry);
692                });
693            } else {
694                gen.extend(quote! {
695                    let #temp0 = crate::ff::mac_with_carry(#temp0, (#a.0).0[#i], (#a.0).0[#i], &mut carry);
696                });
697            }
698
699            gen.extend(quote! {
700                let #temp1 = crate::ff::adc(#temp1, 0, &mut carry);
701            });
702        }
703
704        let mut mont_calling = proc_macro2::TokenStream::new();
705        mont_calling.append_separated((0..(limbs * 2)).map(|i| get_temp(i)), proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
706
707        gen.extend(quote! {
708            self.mont_reduce(#mont_calling);
709        });
710
711        gen
712    }
713
714    fn mul_impl(a: proc_macro2::TokenStream, b: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
715        let mut gen = proc_macro2::TokenStream::new();
716
717        for i in 0..limbs {
718            gen.extend(quote! {
719                let mut carry = 0;
720            });
721
722            for j in 0..limbs {
723                let temp = get_temp(i + j);
724
725                if i == 0 {
726                    gen.extend(quote! {
727                        let #temp = crate::ff::mac_with_carry(0, (#a.0).0[#i], (#b.0).0[#j], &mut carry);
728                    });
729                } else {
730                    gen.extend(quote! {
731                        let #temp = crate::ff::mac_with_carry(#temp, (#a.0).0[#i], (#b.0).0[#j], &mut carry);
732                    });
733                }
734            }
735
736            let temp = get_temp(i + limbs);
737
738            gen.extend(quote! {
739                let #temp = carry;
740            });
741        }
742
743        let mut mont_calling = proc_macro2::TokenStream::new();
744        mont_calling.append_separated((0..(limbs * 2)).map(|i| get_temp(i)), proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
745
746        gen.extend(quote! {
747            self.mont_reduce(#mont_calling);
748        });
749
750        gen
751    }
752
753    fn optimistic_cios_mul_impl(a: proc_macro2::TokenStream, b: proc_macro2::TokenStream, name: &syn::Ident, repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
754        let mut gen = proc_macro2::TokenStream::new();
755
756        let mut other_limbs_set = proc_macro2::TokenStream::new();
757        other_limbs_set.append_separated((0..limbs).map(|i| get_temp_with_literal("b", i)), proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
758
759        gen.extend(quote! {
760            let [#other_limbs_set] = (#b.0).0;
761        });
762
763        for i in 0..limbs {
764            gen.extend(quote! {
765                let a = (#a.0).0[#i];
766            });
767
768            let temp = get_temp(0);
769
770            let b = get_temp_with_literal("b", 0);
771
772            if i == 0 {
773                gen.extend(quote! {
774                    let (#temp, carry) = crate::ff::full_width_mul(a, #b);
775                });
776            } else {
777                gen.extend(quote! {
778                    let (#temp, carry) = crate::ff::mac_by_value(#temp, a, #b);
779                });
780            }
781            gen.extend(quote! {
782                let m = r0.wrapping_mul(INV);
783                let red_carry = crate::ff::mac_by_value_return_carry_only(#temp, m, MODULUS.0[0]);
784            });
785
786            for j in 1..limbs {
787                let temp = get_temp(j);
788
789                let b = get_temp_with_literal("b", j);
790
791                if i == 0 {
792                    gen.extend(quote! {
793                        let (#temp, carry) = crate::ff::mac_by_value(carry, a, #b);
794                    });
795                } else {
796                    gen.extend(quote! {
797                        let (#temp, carry) = crate::ff::mac_with_carry_by_value(#temp, a, #b, carry);
798                    });
799                }
800
801                let temp_prev = get_temp(j - 1);
802
803                gen.extend(quote! {
804                    let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#j], red_carry);
805                });
806            }
807
808            let temp = get_temp(limbs - 1);
809            gen.extend(quote! {
810                let #temp = red_carry + carry;
811            });
812        }
813
814        let mut limbs_set = proc_macro2::TokenStream::new();
815        limbs_set.append_separated((0..limbs).map(|i| get_temp(i)), proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
816
817        gen.extend(quote! {
818            *self = #name(#repr([#limbs_set]));
819            self.reduce();
820        });
821
822        gen
823    }
824
825    fn optimistic_cios_sqr_impl(a: proc_macro2::TokenStream, name: &syn::Ident, repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
826        let mut gen = proc_macro2::TokenStream::new();
827
828        let mut this_limbs_set = proc_macro2::TokenStream::new();
829        this_limbs_set.append_separated((0..limbs).map(|i| get_temp_with_literal("a", i)), proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
830
831        gen.extend(quote! {
832            let [#this_limbs_set] = (#a.0).0;
833        });
834
835        for i in 0..limbs {
836            for red_idx in 0..i {
837                if red_idx == 0 {
838                    let temp = get_temp(0);
839
840                    gen.extend(quote! {
841                        let m = r0.wrapping_mul(INV);
842                        let red_carry = crate::ff::mac_by_value_return_carry_only(#temp, m, MODULUS.0[0]);
843                    });
844                } else {
845                    let temp = get_temp(red_idx);
846                    let temp_prev = get_temp(red_idx - 1);
847                    gen.extend(quote! {
848                        let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#red_idx], red_carry);
849                    });
850                }
851            }
852            let a = get_temp_with_literal("a", i);
853
854            // single square step
855            if i == 0 {
856                // for a first pass just square and reduce
857                let temp = get_temp(0);
858
859                gen.extend(quote! {
860                    let (#temp, carry) = crate::ff::full_width_mul(#a, #a);
861                    let m = r0.wrapping_mul(INV);
862                    let red_carry = crate::ff::mac_by_value_return_carry_only(#temp, m, MODULUS.0[0]);
863                });
864            } else {
865                // for next passes square, add previous value and reduce
866                let temp = get_temp(i);
867                let temp_prev = get_temp(i - 1);
868                gen.extend(quote! {
869                    let (#temp, carry) = crate::ff::mac_by_value(#temp, #a, #a);
870
871                });
872
873                if i == limbs - 1 {
874                    gen.extend(quote! {
875                        let (#temp_prev, #temp) = crate::ff::mac_with_low_and_high_carry_by_value(
876                            red_carry, m, MODULUS.0[#i], #temp, carry
877                        );
878                    });
879                } else {
880                    gen.extend(quote! {
881                        let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#i], red_carry);
882                    });
883                }
884            }
885
886            // continue with propagation and reduction
887            for j in (i + 1)..limbs {
888                let b = get_temp_with_literal("a", j);
889
890                let temp = get_temp(j);
891
892                if i == 0 {
893                    if j == limbs - 1 {
894                        let temp_prev = get_temp(j - 1);
895
896                        gen.extend(quote! {
897                            let (#temp, carry) = crate::ff::mul_double_add_low_and_high_carry_by_value_ignore_superhi(
898                                #a, #b, carry, superhi
899                            );
900
901                            let (#temp_prev, #temp) = crate::ff::mac_with_low_and_high_carry_by_value(
902                                red_carry, m, MODULUS.0[#j], #temp, carry
903                            );
904                        });
905                    } else {
906                        if j == i + 1 {
907                            gen.extend(quote! {
908                                let (#temp, carry, superhi) = crate::ff::mul_double_add_by_value(
909                                    carry, #a, #b,
910                                );
911                            });
912                        } else {
913                            gen.extend(quote! {
914                                let (#temp, carry, superhi) = crate::ff::mul_double_add_low_and_high_carry_by_value(
915                                    #a, #b, carry, superhi
916                                );
917                            });
918                        }
919
920                        let temp_prev = get_temp(j - 1);
921
922                        gen.extend(quote! {
923                            let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#j], red_carry);
924                        });
925                    }
926                } else {
927                    if j == limbs - 1 {
928                        let temp_prev = get_temp(j - 1);
929
930                        if j == i + 1 {
931                            gen.extend(quote! {
932                                let (#temp, carry) = crate::ff::mul_double_add_add_carry_by_value_ignore_superhi(
933                                    #temp, #a, #b, carry
934                                );
935                            });
936                        } else {
937                            gen.extend(quote! {
938                                let (#temp, carry) = crate::ff::mul_double_add_add_low_and_high_carry_by_value_ignore_superhi(
939                                    #temp, #a, #b, carry, superhi
940                                );
941                            });
942                        }
943
944                        gen.extend(quote! {
945                            let (#temp_prev, #temp) = crate::ff::mac_with_low_and_high_carry_by_value(
946                                red_carry, m, MODULUS.0[#j], #temp, carry
947                            );
948                        });
949                    } else {
950                        if j == i + 1 {
951                            gen.extend(quote! {
952                                let (#temp, carry, superhi) = crate::ff::mul_double_add_add_carry_by_value(
953                                    #temp, #a, #b, carry
954                                );
955                            });
956                        } else {
957                            gen.extend(quote! {
958                                let (#temp, carry, superhi) = crate::ff::mul_double_add_add_low_and_high_carry_by_value_ignore_superhi(
959                                    #temp, #a, #b, carry, superhi
960                                );
961                            });
962                        }
963                        let temp_prev = get_temp(j - 1);
964
965                        gen.extend(quote! {
966                            let (#temp_prev, red_carry) = mac_with_carry_by_value(#temp, m, MODULUS.0[#j], red_carry);
967                        });
968                    }
969                }
970            }
971
972            // let temp = get_temp(limbs-1);
973
974            // gen.extend(quote!{
975            //     let #temp = red_carry + carry;
976            // });
977        }
978
979        let mut limbs_set = proc_macro2::TokenStream::new();
980        limbs_set.append_separated((0..limbs).map(|i| get_temp(i)), proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
981
982        gen.extend(quote! {
983            *self = #name(#repr([#limbs_set]));
984            self.reduce();
985        });
986
987        gen
988    }
989    let multiply_impl = if can_use_cios_mul {
990        optimistic_cios_mul_impl(quote! {self}, quote! {other}, name, repr, limbs)
991    } else {
992        mul_impl(quote! {self}, quote! {other}, limbs)
993    };
994    let squaring_impl = if can_use_cios_sqr {
995        optimistic_cios_sqr_impl(quote! {self}, name, repr, limbs)
996    } else {
997        sqr_impl(quote! {self}, limbs)
998    };
999
1000    let top_limb_index = limbs - 1;
1001
1002    let montgomery_impl = mont_impl(limbs);
1003
1004    // (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ...
1005    let mut into_repr_params = proc_macro2::TokenStream::new();
1006    into_repr_params.append_separated(
1007        (0..limbs).map(|i| quote! { (self.0).0[#i] }).chain((0..limbs).map(|_| quote! {0})),
1008        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
1009    );
1010
1011    quote! {
1012        impl ::std::marker::Copy for #name { }
1013
1014        impl ::std::clone::Clone for #name {
1015            fn clone(&self) -> #name {
1016                *self
1017            }
1018        }
1019
1020        impl ::std::cmp::PartialEq for #name {
1021            fn eq(&self, other: &#name) -> bool {
1022                self.0 == other.0
1023            }
1024        }
1025
1026        impl ::std::cmp::Eq for #name { }
1027
1028        impl ::std::fmt::Debug for #name
1029        {
1030            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
1031                write!(f, "{}({:?})", stringify!(#name), self.into_repr())
1032            }
1033        }
1034
1035        /// Elements are ordered lexicographically.
1036        impl Ord for #name {
1037            #[inline(always)]
1038            fn cmp(&self, other: &#name) -> ::std::cmp::Ordering {
1039                self.into_repr().cmp(&other.into_repr())
1040            }
1041        }
1042
1043        impl PartialOrd for #name {
1044            #[inline(always)]
1045            fn partial_cmp(&self, other: &#name) -> Option<::std::cmp::Ordering> {
1046                Some(self.cmp(other))
1047            }
1048        }
1049
1050        impl ::std::fmt::Display for #name {
1051            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
1052                write!(f, "{}({})", stringify!(#name), self.into_repr())
1053            }
1054        }
1055
1056        impl crate::ff::Rand for #name {
1057            /// Computes a uniformly random element using rejection sampling.
1058            fn rand<R: crate::ff::rand::Rng + ?Sized>(rng: &mut R) -> Self {
1059                loop {
1060                    let mut tmp = #name(<#repr as crate::ff::Rand>::rand(rng));
1061
1062                    // Mask away the unused bits at the beginning.
1063                    tmp.0.as_mut()[#top_limb_index] &= TOP_LIMB_SHAVE_MASK;
1064
1065                    if tmp.is_valid() {
1066                        return tmp
1067                    }
1068                }
1069            }
1070        }
1071
1072        impl From<#name> for #repr {
1073            fn from(e: #name) -> #repr {
1074                e.into_repr()
1075            }
1076        }
1077
1078        impl crate::ff::PrimeField for #name {
1079            type Repr = #repr;
1080
1081            fn from_repr(r: #repr) -> Result<#name, crate::ff::PrimeFieldDecodingError> {
1082                let mut r = #name(r);
1083                if r.is_valid() {
1084                    r.mul_assign(&#name(R2));
1085
1086                    Ok(r)
1087                } else {
1088                    Err(crate::ff::PrimeFieldDecodingError::NotInField(format!("{}", r.0)))
1089                }
1090            }
1091
1092            fn from_raw_repr(r: #repr) -> Result<Self, crate::ff::PrimeFieldDecodingError> {
1093                let mut r = #name(r);
1094                if r.is_valid() {
1095                    Ok(r)
1096                } else {
1097                    Err(crate::ff::PrimeFieldDecodingError::NotInField(format!("{}", r.0)))
1098                }
1099            }
1100
1101            fn into_repr(&self) -> #repr {
1102                let mut r = *self;
1103                r.mont_reduce(
1104                    #into_repr_params
1105                );
1106
1107                r.0
1108            }
1109
1110            fn into_raw_repr(&self) -> #repr {
1111                let r = *self;
1112
1113                r.0
1114            }
1115
1116            fn char() -> #repr {
1117                MODULUS
1118            }
1119
1120            const NUM_BITS: u32 = MODULUS_BITS;
1121
1122            const CAPACITY: u32 = Self::NUM_BITS - 1;
1123
1124            fn multiplicative_generator() -> Self {
1125                #name(GENERATOR)
1126            }
1127
1128            const S: u32 = S;
1129
1130            fn root_of_unity() -> Self {
1131                #name(ROOT_OF_UNITY)
1132            }
1133
1134        }
1135
1136        impl crate::ff::Field for #name {
1137            #[inline]
1138            fn zero() -> Self {
1139                #name(#repr::from(0))
1140            }
1141
1142            #[inline]
1143            fn one() -> Self {
1144                #name(R)
1145            }
1146
1147            #[inline]
1148            fn is_zero(&self) -> bool {
1149                self.0.is_zero()
1150            }
1151
1152            #[inline]
1153            fn add_assign(&mut self, other: &#name) {
1154                // This cannot exceed the backing capacity.
1155                self.0.add_nocarry(&other.0);
1156
1157                // However, it may need to be reduced.
1158                self.reduce();
1159            }
1160
1161            #[inline]
1162            fn double(&mut self) {
1163                // This cannot exceed the backing capacity.
1164                self.0.mul2();
1165
1166                // However, it may need to be reduced.
1167                self.reduce();
1168            }
1169
1170            #[inline]
1171            fn sub_assign(&mut self, other: &#name) {
1172                // If `other` is larger than `self`, we'll need to add the modulus to self first.
1173                if other.0 > self.0 {
1174                    self.0.add_nocarry(&MODULUS);
1175                }
1176
1177                self.0.sub_noborrow(&other.0);
1178            }
1179
1180            #[inline]
1181            fn negate(&mut self) {
1182                if !self.is_zero() {
1183                    let mut tmp = MODULUS;
1184                    tmp.sub_noborrow(&self.0);
1185                    self.0 = tmp;
1186                }
1187            }
1188
1189            fn inverse(&self) -> Option<Self> {
1190                if self.is_zero() {
1191                    None
1192                } else {
1193                    // Guajardo Kumar Paar Pelzl
1194                    // Efficient Software-Implementation of Finite Fields with Applications to Cryptography
1195                    // Algorithm 16 (BEA for Inversion in Fp)
1196
1197                    let one = #repr::from(1);
1198
1199                    let mut u = self.0;
1200                    let mut v = MODULUS;
1201                    let mut b = #name(R2); // Avoids unnecessary reduction step.
1202                    let mut c = Self::zero();
1203
1204                    while u != one && v != one {
1205                        while u.is_even() {
1206                            u.div2();
1207
1208                            if b.0.is_even() {
1209                                b.0.div2();
1210                            } else {
1211                                b.0.add_nocarry(&MODULUS);
1212                                b.0.div2();
1213                            }
1214                        }
1215
1216                        while v.is_even() {
1217                            v.div2();
1218
1219                            if c.0.is_even() {
1220                                c.0.div2();
1221                            } else {
1222                                c.0.add_nocarry(&MODULUS);
1223                                c.0.div2();
1224                            }
1225                        }
1226
1227                        if v < u {
1228                            u.sub_noborrow(&v);
1229                            b.sub_assign(&c);
1230                        } else {
1231                            v.sub_noborrow(&u);
1232                            c.sub_assign(&b);
1233                        }
1234                    }
1235
1236                    if u == one {
1237                        Some(b)
1238                    } else {
1239                        Some(c)
1240                    }
1241                }
1242            }
1243
1244            #[inline(always)]
1245            fn frobenius_map(&mut self, _: usize) {
1246                // This has no effect in a prime field.
1247            }
1248
1249            #[inline]
1250            fn mul_assign(&mut self, other: &#name)
1251            {
1252                #multiply_impl
1253            }
1254
1255            #[inline]
1256            fn square(&mut self)
1257            {
1258                #squaring_impl
1259            }
1260        }
1261
1262        impl std::default::Default for #name {
1263            fn default() -> Self {
1264                Self::zero()
1265            }
1266        }
1267
1268        impl std::hash::Hash for #name {
1269            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1270                for limb in self.0.as_ref().iter() {
1271                    limb.hash(state);
1272                }
1273            }
1274        }
1275
1276        impl #name {
1277            /// Determines if the element is really in the field. This is only used
1278            /// internally.
1279            #[inline(always)]
1280            fn is_valid(&self) -> bool {
1281                self.0 < MODULUS
1282            }
1283
1284            /// Subtracts the modulus from this element if this element is not in the
1285            /// field. Only used interally.
1286            #[inline(always)]
1287            fn reduce(&mut self) {
1288                if !self.is_valid() {
1289                    self.0.sub_noborrow(&MODULUS);
1290                }
1291            }
1292
1293            #[inline(always)]
1294            fn mont_reduce(
1295                &mut self,
1296                #mont_paramlist
1297            )
1298            {
1299                // The Montgomery reduction here is based on Algorithm 14.32 in
1300                // Handbook of Applied Cryptography
1301                // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
1302
1303                #montgomery_impl
1304
1305                self.reduce();
1306            }
1307        }
1308
1309        impl ::serde::Serialize for #name {
1310            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1311                where S: ::serde::Serializer
1312            {
1313                let repr = self.into_repr();
1314                repr.serialize(serializer)
1315            }
1316        }
1317
1318        impl<'de> ::serde::Deserialize<'de> for #name {
1319            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1320            where D: ::serde::Deserializer<'de>
1321            {
1322                let repr = #repr::deserialize(deserializer)?;
1323                let new = Self::from_repr(repr).map_err(::serde::de::Error::custom)?;
1324
1325                Ok(new)
1326            }
1327        }
1328    }
1329}