../../.cargo/katex-header.html

winter_math/field/f62/
mod.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6//! An implementation of a 62-bit STARK-friendly prime field with modulus $2^{62} - 111 \cdot 2^{39} + 1$.
7//!
8//! All operations in this field are implemented using Montgomery arithmetic. It supports very
9//! fast modular arithmetic including branchless multiplication and addition. Base elements are
10//! stored in the Montgomery form using `u64` as the backing type.
11
12use alloc::{
13    string::{String, ToString},
14    vec::Vec,
15};
16use core::{
17    fmt::{Debug, Display, Formatter},
18    mem,
19    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
20    slice,
21};
22
23#[cfg(feature = "serde")]
24use serde::{Deserialize, Serialize};
25use utils::{
26    AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
27    Serializable,
28};
29
30use super::{ExtensibleField, FieldElement, StarkField};
31
32#[cfg(test)]
33mod tests;
34
35// CONSTANTS
36// ================================================================================================
37
/// Field modulus M = 2^62 - 111 * 2^39 + 1 (= 4611624995532046337).
const M: u64 = (1 << 62) - 111 * (1 << 39) + 1;

/// -M^{-1} mod 2^64; the Montgomery reduction constant used by `mul`.
const U: u128 = 4611624995532046335;

/// R^2 mod M with R = 2^64 (i.e. 2^128 mod M); multiplying by it in Montgomery form converts
/// a canonical integer into Montgomery representation.
const R2: u64 = 630444561284293700;

/// R^3 mod M with R = 2^64 (i.e. 2^192 mod M); used to bring the raw inverse computed in `inv`
/// back into Montgomery representation.
const R3: u64 = 732984146687909319;

/// Canonical (non-Montgomery) value of a 2^39-th root of unity, computed as
/// 3^((M - 1) / 2^39) mod M.
const G: u64 = 4421547261963328785;

/// Size in bytes of a serialized field element (the width of the `u64` backing type).
const ELEMENT_BYTES: usize = core::mem::size_of::<u64>();
55
56// FIELD ELEMENT
57// ================================================================================================
58
59/// Represents base field element in the field.
60///
61/// Internal values are stored in Montgomery representation and can be in the range [0; 2M). The
62/// backing type is `u64`.
63#[derive(Copy, Clone, Default)]
64#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
65#[cfg_attr(feature = "serde", serde(try_from = "u64", into = "u64"))]
66pub struct BaseElement(u64);
67
68impl BaseElement {
69    /// Creates a new field element from the provided `value`; the value is converted into
70    /// Montgomery representation.
71    pub const fn new(value: u64) -> BaseElement {
72        // multiply the value with R2 to convert to Montgomery representation; this is OK because
73        // given the value of R2, the product of R2 and `value` is guaranteed to be in the range
74        // [0, 4M^2 - 4M + 1)
75        let z = mul(value, R2);
76        BaseElement(z)
77    }
78}
79
80impl FieldElement for BaseElement {
81    type PositiveInteger = u64;
82    type BaseField = Self;
83
84    const EXTENSION_DEGREE: usize = 1;
85
86    const ZERO: Self = BaseElement::new(0);
87    const ONE: Self = BaseElement::new(1);
88
89    const ELEMENT_BYTES: usize = ELEMENT_BYTES;
90    const IS_CANONICAL: bool = false;
91
92    // ALGEBRA
93    // --------------------------------------------------------------------------------------------
94
95    #[inline]
96    fn double(self) -> Self {
97        let z = self.0 << 1;
98        let q = (z >> 62) * M;
99        Self(z - q)
100    }
101
102    fn exp(self, power: Self::PositiveInteger) -> Self {
103        let mut b = self;
104
105        if power == 0 {
106            return Self::ONE;
107        } else if b == Self::ZERO {
108            return Self::ZERO;
109        }
110
111        let mut r = if power & 1 == 1 { b } else { Self::ONE };
112        for i in 1..64 - power.leading_zeros() {
113            b = b.square();
114            if (power >> i) & 1 == 1 {
115                r *= b;
116            }
117        }
118
119        r
120    }
121
122    fn inv(self) -> Self {
123        BaseElement(inv(self.0))
124    }
125
126    fn conjugate(&self) -> Self {
127        BaseElement(self.0)
128    }
129
130    // BASE ELEMENT CONVERSIONS
131    // --------------------------------------------------------------------------------------------
132
133    fn base_element(&self, i: usize) -> Self::BaseField {
134        match i {
135            0 => *self,
136            _ => panic!("element index must be 0, but was {i}"),
137        }
138    }
139
140    fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
141        elements
142    }
143
144    fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
145        elements
146    }
147
148    // SERIALIZATION / DESERIALIZATION
149    // --------------------------------------------------------------------------------------------
150
151    fn elements_as_bytes(elements: &[Self]) -> &[u8] {
152        // TODO: take endianness into account
153        let p = elements.as_ptr();
154        let len = elements.len() * Self::ELEMENT_BYTES;
155        unsafe { slice::from_raw_parts(p as *const u8, len) }
156    }
157
158    unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
159        if bytes.len() % Self::ELEMENT_BYTES != 0 {
160            return Err(DeserializationError::InvalidValue(format!(
161                "number of bytes ({}) does not divide into whole number of field elements",
162                bytes.len(),
163            )));
164        }
165
166        let p = bytes.as_ptr();
167        let len = bytes.len() / Self::ELEMENT_BYTES;
168
169        if (p as usize) % mem::align_of::<u64>() != 0 {
170            return Err(DeserializationError::InvalidValue(
171                "slice memory alignment is not valid for this field element type".to_string(),
172            ));
173        }
174
175        Ok(slice::from_raw_parts(p as *const Self, len))
176    }
177}
178
179impl StarkField for BaseElement {
180    /// sage: MODULUS = 2^62 - 111 * 2^39 + 1 \
181    /// sage: GF(MODULUS).is_prime_field() \
182    /// True \
183    /// sage: GF(MODULUS).order() \
184    /// 4611624995532046337
185    const MODULUS: Self::PositiveInteger = M;
186    const MODULUS_BITS: u32 = 62;
187
188    /// sage: GF(MODULUS).primitive_element() \
189    /// 3
190    const GENERATOR: Self = BaseElement::new(3);
191
192    /// sage: is_odd((MODULUS - 1) / 2^39) \
193    /// True
194    const TWO_ADICITY: u32 = 39;
195
196    /// sage: k = (MODULUS - 1) / 2^39 \
197    /// sage: GF(MODULUS).primitive_element()^k \
198    /// 4421547261963328785
199    const TWO_ADIC_ROOT_OF_UNITY: Self = BaseElement::new(G);
200
201    fn get_modulus_le_bytes() -> Vec<u8> {
202        Self::MODULUS.to_le_bytes().to_vec()
203    }
204
205    #[inline]
206    fn as_int(&self) -> Self::PositiveInteger {
207        // convert from Montgomery representation by multiplying by 1
208        let result = mul(self.0, 1);
209        // since the result of multiplication can be in [0, 2M), we need to normalize it
210        normalize(result)
211    }
212}
213
214impl Randomizable for BaseElement {
215    const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
216
217    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
218        Self::try_from(bytes).ok()
219    }
220}
221
222impl Debug for BaseElement {
223    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
224        write!(f, "{}", self)
225    }
226}
227
228impl Display for BaseElement {
229    fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
230        write!(f, "{}", self.as_int())
231    }
232}
233
234// EQUALITY CHECKS
235// ================================================================================================
236
237impl PartialEq for BaseElement {
238    #[inline]
239    fn eq(&self, other: &Self) -> bool {
240        // since either of the elements can be in [0, 2M) range, we normalize them first to be
241        // in [0, M) range and then compare them.
242        normalize(self.0) == normalize(other.0)
243    }
244}
245
246impl Eq for BaseElement {}
247
248// OVERLOADED OPERATORS
249// ================================================================================================
250
251impl Add for BaseElement {
252    type Output = Self;
253
254    fn add(self, rhs: Self) -> Self {
255        Self(add(self.0, rhs.0))
256    }
257}
258
259impl AddAssign for BaseElement {
260    fn add_assign(&mut self, rhs: Self) {
261        *self = *self + rhs
262    }
263}
264
265impl Sub for BaseElement {
266    type Output = Self;
267
268    fn sub(self, rhs: Self) -> Self {
269        Self(sub(self.0, rhs.0))
270    }
271}
272
273impl SubAssign for BaseElement {
274    fn sub_assign(&mut self, rhs: Self) {
275        *self = *self - rhs;
276    }
277}
278
279impl Mul for BaseElement {
280    type Output = Self;
281
282    fn mul(self, rhs: Self) -> Self {
283        Self(mul(self.0, rhs.0))
284    }
285}
286
287impl MulAssign for BaseElement {
288    fn mul_assign(&mut self, rhs: Self) {
289        *self = *self * rhs
290    }
291}
292
293impl Div for BaseElement {
294    type Output = Self;
295
296    fn div(self, rhs: Self) -> Self {
297        Self(mul(self.0, inv(rhs.0)))
298    }
299}
300
301impl DivAssign for BaseElement {
302    fn div_assign(&mut self, rhs: Self) {
303        *self = *self / rhs
304    }
305}
306
307impl Neg for BaseElement {
308    type Output = Self;
309
310    fn neg(self) -> Self {
311        Self(sub(0, self.0))
312    }
313}
314
315// QUADRATIC EXTENSION
316// ================================================================================================
317
318/// Defines a quadratic extension of the base field over an irreducible polynomial x<sup>2</sup> -
319/// x - 1. Thus, an extension element is defined as α + β * φ, where φ is a root of this polynomial,
320/// and α and β are base field elements.
321impl ExtensibleField<2> for BaseElement {
322    #[inline(always)]
323    fn mul(a: [Self; 2], b: [Self; 2]) -> [Self; 2] {
324        let z = a[0] * b[0];
325        [z + a[1] * b[1], (a[0] + a[1]) * (b[0] + b[1]) - z]
326    }
327
328    #[inline(always)]
329    fn mul_base(a: [Self; 2], b: Self) -> [Self; 2] {
330        [a[0] * b, a[1] * b]
331    }
332
333    #[inline(always)]
334    fn frobenius(x: [Self; 2]) -> [Self; 2] {
335        [x[0] + x[1], -x[1]]
336    }
337}
338
339// CUBIC EXTENSION
340// ================================================================================================
341
/// Defines a cubic extension of the base field over an irreducible polynomial x<sup>3</sup> +
/// 2x + 2. Thus, an extension element is defined as α + β * φ + γ * φ^2, where φ is a root of this
/// polynomial, and α, β and γ are base field elements.
///
/// Because φ^3 = -2φ - 2 and φ^4 = -2φ^2 - 2φ, a product whose unreduced coefficients are
/// c0..c4 (c0 = a0b0, c1 = a0b1 + a1b0, c2 = a0b2 + a1b1 + a2b0, c3 = a1b2 + a2b1, c4 = a2b2)
/// reduces to [c0 - 2c3, c1 - 2c3 - 2c4, c2 - 2c4].
impl ExtensibleField<3> for BaseElement {
    /// Multiplies two cubic extension elements; the local variable names spell out which terms
    /// each intermediate value contains.
    #[inline(always)]
    fn mul(a: [Self; 3], b: [Self; 3]) -> [Self; 3] {
        // performs multiplication in the extension field using 6 multiplications, 8 additions,
        // and 9 subtractions in the base field. overall, a single multiplication in the extension
        // field is roughly equal to 12 multiplications in the base field.
        let a0b0 = a[0] * b[0];
        let a1b1 = a[1] * b[1];
        let a2b2 = a[2] * b[2];

        // Karatsuba-style cross products; each one mixes two coordinate pairs
        let a0b0_a0b1_a1b0_a1b1 = (a[0] + a[1]) * (b[0] + b[1]);
        let minus_a0b0_a0b2_a2b0_minus_a2b2 = (a[0] - a[2]) * (b[2] - b[0]);
        let a1b1_minus_a1b2_minus_a2b1_a2b2 = (a[1] - a[2]) * (b[1] - b[2]);
        let a0b0_a1b1 = a0b0 + a1b1;

        // -2 * c3 = -2(a1b2 + a2b1), the contribution of the φ^3 term to coordinates 0 and 1
        let minus_2a1b2_minus_2a2b1 = (a1b1_minus_a1b2_minus_a2b1_a2b2 - a1b1 - a2b2).double();

        // coordinate 0: c0 - 2c3
        let a0b0_minus_2a1b2_minus_2a2b1 = a0b0 + minus_2a1b2_minus_2a2b1;
        // coordinate 1: c1 - 2c3 - 2c4
        let a0b1_a1b0_minus_2a1b2_minus_2a2b1_minus_2a2b2 =
            a0b0_a0b1_a1b0_a1b1 + minus_2a1b2_minus_2a2b1 - a2b2.double() - a0b0_a1b1;
        // coordinate 2: c2 - 2c4
        let a0b2_a1b1_a2b0_minus_2a2b2 = minus_a0b0_a0b2_a2b0_minus_a2b2 + a0b0_a1b1 - a2b2;
        [
            a0b0_minus_2a1b2_minus_2a2b1,
            a0b1_a1b0_minus_2a1b2_minus_2a2b1_minus_2a2b2,
            a0b2_a1b1_a2b0_minus_2a2b2,
        ]
    }

    /// Multiplies each coordinate of `a` by the base element `b`.
    #[inline(always)]
    fn mul_base(a: [Self; 3], b: Self) -> [Self; 3] {
        [a[0] * b, a[1] * b, a[2] * b]
    }

    /// Applies the Frobenius map as a fixed linear map on the coordinates (α, β, γ); α is
    /// carried through unchanged and only β and γ are mixed.
    #[inline(always)]
    fn frobenius(x: [Self; 3]) -> [Self; 3] {
        // coefficients were computed using SageMath
        [
            x[0] + BaseElement::new(2061766055618274781) * x[1]
                + BaseElement::new(786836585661389001) * x[2],
            BaseElement::new(2868591307402993000) * x[1]
                + BaseElement::new(3336695525575160559) * x[2],
            BaseElement::new(2699230790596717670) * x[1]
                + BaseElement::new(1743033688129053336) * x[2],
        ]
    }
}
391
392// TYPE CONVERSIONS
393// ================================================================================================
394
395impl From<u32> for BaseElement {
396    /// Converts a 32-bit value into a field element.
397    fn from(value: u32) -> Self {
398        BaseElement::new(value as u64)
399    }
400}
401
402impl From<u16> for BaseElement {
403    /// Converts a 16-bit value into a field element.
404    fn from(value: u16) -> Self {
405        BaseElement::new(value as u64)
406    }
407}
408
409impl From<u8> for BaseElement {
410    /// Converts an 8-bit value into a field element.
411    fn from(value: u8) -> Self {
412        BaseElement::new(value as u64)
413    }
414}
415
416impl From<BaseElement> for u128 {
417    fn from(value: BaseElement) -> Self {
418        value.as_int() as u128
419    }
420}
421
422impl From<BaseElement> for u64 {
423    fn from(value: BaseElement) -> Self {
424        value.as_int()
425    }
426}
427
428impl TryFrom<u64> for BaseElement {
429    type Error = String;
430
431    fn try_from(value: u64) -> Result<Self, Self::Error> {
432        if value >= M {
433            Err(format!(
434                "invalid field element: value {value} is greater than or equal to the field modulus"
435            ))
436        } else {
437            Ok(Self::new(value))
438        }
439    }
440}
441
442impl TryFrom<u128> for BaseElement {
443    type Error = String;
444
445    fn try_from(value: u128) -> Result<Self, Self::Error> {
446        if value >= M as u128 {
447            Err(format!(
448                "invalid field element: value {value} is greater than or equal to the field modulus"
449            ))
450        } else {
451            Ok(Self::new(value as u64))
452        }
453    }
454}
455
456impl TryFrom<[u8; 8]> for BaseElement {
457    type Error = String;
458
459    fn try_from(bytes: [u8; 8]) -> Result<Self, Self::Error> {
460        let value = u64::from_le_bytes(bytes);
461        Self::try_from(value)
462    }
463}
464
465impl TryFrom<&'_ [u8]> for BaseElement {
466    type Error = DeserializationError;
467
468    /// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
469    /// is not a valid field element. The bytes are assumed to encode the element in the canonical
470    /// representation in little-endian byte order.
471    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
472        if bytes.len() < ELEMENT_BYTES {
473            return Err(DeserializationError::InvalidValue(format!(
474                "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
475                ELEMENT_BYTES,
476                bytes.len(),
477            )));
478        }
479        if bytes.len() > ELEMENT_BYTES {
480            return Err(DeserializationError::InvalidValue(format!(
481                "too many bytes for a field element; expected {} bytes, but was {} bytes",
482                ELEMENT_BYTES,
483                bytes.len(),
484            )));
485        }
486        let value = bytes
487            .try_into()
488            .map(u64::from_le_bytes)
489            .map_err(|error| DeserializationError::UnknownError(format!("{error}")))?;
490        if value >= M {
491            return Err(DeserializationError::InvalidValue(format!(
492                "invalid field element: value {value} is greater than or equal to the field modulus"
493            )));
494        }
495        Ok(BaseElement::new(value))
496    }
497}
498
499impl AsBytes for BaseElement {
500    fn as_bytes(&self) -> &[u8] {
501        // TODO: take endianness into account
502        let self_ptr: *const BaseElement = self;
503        unsafe { slice::from_raw_parts(self_ptr as *const u8, ELEMENT_BYTES) }
504    }
505}
506
507// SERIALIZATION / DESERIALIZATION
508// ------------------------------------------------------------------------------------------------
509
510impl Serializable for BaseElement {
511    fn write_into<W: ByteWriter>(&self, target: &mut W) {
512        // convert from Montgomery representation into canonical representation
513        target.write_bytes(&self.as_int().to_le_bytes());
514    }
515
516    fn get_size_hint(&self) -> usize {
517        self.as_int().get_size_hint()
518    }
519}
520
521impl Deserializable for BaseElement {
522    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
523        let value = source.read_u64()?;
524        if value >= M {
525            return Err(DeserializationError::InvalidValue(format!(
526                "invalid field element: value {value} is greater than or equal to the field modulus"
527            )));
528        }
529        Ok(BaseElement::new(value))
530    }
531}
532
533// FINITE FIELD ARITHMETIC
534// ================================================================================================
535
536/// Computes (a + b) reduced by M such that the output is in [0, 2M) range; a and b are assumed to
537/// be in [0, 2M).
538#[inline(always)]
539fn add(a: u64, b: u64) -> u64 {
540    let z = a + b;
541    let q = (z >> 62) * M;
542    z - q
543}
544
545/// Computes (a - b) reduced by M such that the output is in [0, 2M) range; a and b are assumed to
546/// be in [0, 2M).
547#[inline(always)]
548fn sub(a: u64, b: u64) -> u64 {
549    if a < b {
550        2 * M - b + a
551    } else {
552        a - b
553    }
554}
555
556/// Computes (a * b) reduced by M such that the output is in [0, 2M) range; a and b are assumed to
557/// be in [0, 2M).
558#[inline(always)]
559const fn mul(a: u64, b: u64) -> u64 {
560    let z = (a as u128) * (b as u128);
561    let q = (((z as u64) as u128) * U) as u64;
562    let z = z + (q as u128) * (M as u128);
563    (z >> 64) as u64
564}
565
566/// Computes y such that (x * y) % M = 1 except for when when x = 0; in such a case, 0 is returned;
567/// x is assumed to in [0, 2M) range, and the output will also be in [0, 2M) range.
568#[inline(always)]
569#[allow(clippy::many_single_char_names)]
570fn inv(x: u64) -> u64 {
571    if x == 0 {
572        return 0;
573    };
574
575    let mut a: u128 = 0;
576    let mut u: u128 = if x & 1 == 1 {
577        x as u128
578    } else {
579        (x as u128) + (M as u128)
580    };
581    let mut v: u128 = M as u128;
582    let mut d = (M as u128) - 1;
583
584    while v != 1 {
585        while v < u {
586            u -= v;
587            d += a;
588            while u & 1 == 0 {
589                if d & 1 == 1 {
590                    d += M as u128;
591                }
592                u >>= 1;
593                d >>= 1;
594            }
595        }
596
597        v -= u;
598        a += d;
599
600        while v & 1 == 0 {
601            if a & 1 == 1 {
602                a += M as u128;
603            }
604            v >>= 1;
605            a >>= 1;
606        }
607    }
608
609    while a > (M as u128) {
610        a -= M as u128;
611    }
612
613    mul(a as u64, R3)
614}
615
616// HELPER FUNCTIONS
617// ================================================================================================
618
619/// Reduces any value in [0, 2M) range to [0, M) range
620#[inline(always)]
621fn normalize(value: u64) -> u64 {
622    if value >= M {
623        value - M
624    } else {
625        value
626    }
627}