../../.cargo/katex-header.html

winter_math/field/f128/
mod.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6//! An implementation of a 128-bit STARK-friendly prime field with modulus $2^{128} - 45 \cdot 2^{40} + 1$.
7//!
//! Operations in this field are implemented using Barrett reduction, and field elements are
//! stored in their canonical form using `u128` as the backing type. However, this field was not
//! chosen with any significant thought given to performance, and the implementations of most
//! operations are sub-optimal as well.
12
13use alloc::{
14    string::{String, ToString},
15    vec::Vec,
16};
17use core::{
18    fmt::{Debug, Display, Formatter},
19    mem,
20    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
21    slice,
22};
23
24#[cfg(feature = "serde")]
25use serde::{Deserialize, Serialize};
26use utils::{
27    AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
28    Serializable,
29};
30
31use super::{ExtensibleField, FieldElement, StarkField};
32
33#[cfg(test)]
34mod tests;
35
36// CONSTANTS
37// ================================================================================================
38
// Field modulus = 2^128 - 45 * 2^40 + 1 = 340282366920938463463374557953744961537.
// 2^128 itself does not fit into a u128, so the value is spelled as u128::MAX - 45 * 2^40 + 2.
const M: u128 = u128::MAX - 45 * (1u128 << 40) + 2;

// Primitive 2^40-th root of unity in the field.
const G: u128 = 23953097886125630542083529559205016746;

// Size in bytes of one field element, i.e. the size of the `u128` backing type.
const ELEMENT_BYTES: usize = mem::size_of::<u128>();
47
48// FIELD ELEMENT
49// ================================================================================================
50
/// Represents a base field element.
///
/// Internal values are stored in their canonical form in the range [0, M). The backing type is
/// `u128`.
///
/// The struct is `#[repr(transparent)]` so it is guaranteed to have exactly the size and
/// alignment of `u128`. The raw byte/slice reinterpretations performed by
/// `FieldElement::elements_as_bytes`, `FieldElement::bytes_as_elements`, and `AsBytes::as_bytes`
/// depend on that guarantee, which the default `repr(Rust)` layout does not provide.
#[derive(Copy, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[repr(transparent)]
pub struct BaseElement(u128);
59
60impl BaseElement {
61    /// Creates a new field element from a u128 value. If the value is greater than or equal to
62    /// the field modulus, modular reduction is silently performed. This function can also be used
63    /// to initialize constants.
64    pub const fn new(value: u128) -> Self {
65        BaseElement(if value < M { value } else { value - M })
66    }
67}
68
69impl FieldElement for BaseElement {
70    type PositiveInteger = u128;
71    type BaseField = Self;
72
73    const EXTENSION_DEGREE: usize = 1;
74
75    const ZERO: Self = BaseElement(0);
76    const ONE: Self = BaseElement(1);
77
78    const ELEMENT_BYTES: usize = ELEMENT_BYTES;
79
80    const IS_CANONICAL: bool = true;
81
82    // ALGEBRA
83    // --------------------------------------------------------------------------------------------
84
85    fn inv(self) -> Self {
86        BaseElement(inv(self.0))
87    }
88
89    fn conjugate(&self) -> Self {
90        BaseElement(self.0)
91    }
92
93    // BASE ELEMENT CONVERSIONS
94    // --------------------------------------------------------------------------------------------
95
96    fn base_element(&self, i: usize) -> Self::BaseField {
97        match i {
98            0 => *self,
99            _ => panic!("element index must be 0, but was {i}"),
100        }
101    }
102
103    fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
104        elements
105    }
106
107    fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
108        elements
109    }
110
111    // SERIALIZATION / DESERIALIZATION
112    // --------------------------------------------------------------------------------------------
113
114    fn elements_as_bytes(elements: &[Self]) -> &[u8] {
115        // TODO: take endianness into account
116        let p = elements.as_ptr();
117        let len = elements.len() * Self::ELEMENT_BYTES;
118        unsafe { slice::from_raw_parts(p as *const u8, len) }
119    }
120
121    unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
122        if !bytes.len().is_multiple_of(Self::ELEMENT_BYTES) {
123            return Err(DeserializationError::InvalidValue(format!(
124                "number of bytes ({}) does not divide into whole number of field elements",
125                bytes.len(),
126            )));
127        }
128
129        let p = bytes.as_ptr();
130        let len = bytes.len() / Self::ELEMENT_BYTES;
131
132        if !(p as usize).is_multiple_of(mem::align_of::<u128>()) {
133            return Err(DeserializationError::InvalidValue(
134                "slice memory alignment is not valid for this field element type".to_string(),
135            ));
136        }
137
138        Ok(slice::from_raw_parts(p as *const Self, len))
139    }
140}
141
142impl StarkField for BaseElement {
143    /// sage: MODULUS = 2^128 - 45 * 2^40 + 1 \
144    /// sage: GF(MODULUS).is_prime_field() \
145    /// True \
146    /// sage: GF(MODULUS).order() \
147    /// 340282366920938463463374557953744961537
148    const MODULUS: Self::PositiveInteger = M;
149    const MODULUS_BITS: u32 = 128;
150
151    /// sage: GF(MODULUS).primitive_element() \
152    /// 3
153    const GENERATOR: Self = BaseElement(3);
154
155    /// sage: is_odd((MODULUS - 1) / 2^40) \
156    /// True
157    const TWO_ADICITY: u32 = 40;
158
159    /// sage: k = (MODULUS - 1) / 2^40 \
160    /// sage: GF(MODULUS).primitive_element()^k \
161    /// 23953097886125630542083529559205016746
162    const TWO_ADIC_ROOT_OF_UNITY: Self = BaseElement(G);
163
164    fn get_modulus_le_bytes() -> Vec<u8> {
165        Self::MODULUS.to_le_bytes().to_vec()
166    }
167
168    #[inline]
169    fn as_int(&self) -> Self::PositiveInteger {
170        self.0
171    }
172}
173
174impl Randomizable for BaseElement {
175    const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
176
177    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
178        Self::try_from(bytes).ok()
179    }
180}
181
182impl Debug for BaseElement {
183    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
184        write!(f, "{self}")
185    }
186}
187
188impl Display for BaseElement {
189    fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
190        write!(f, "{}", self.0)
191    }
192}
193
194// OVERLOADED OPERATORS
195// ================================================================================================
196
197impl Add for BaseElement {
198    type Output = Self;
199
200    fn add(self, rhs: Self) -> Self {
201        Self(add(self.0, rhs.0))
202    }
203}
204
205impl AddAssign for BaseElement {
206    fn add_assign(&mut self, rhs: Self) {
207        *self = *self + rhs
208    }
209}
210
211impl Sub for BaseElement {
212    type Output = Self;
213
214    fn sub(self, rhs: Self) -> Self {
215        Self(sub(self.0, rhs.0))
216    }
217}
218
219impl SubAssign for BaseElement {
220    fn sub_assign(&mut self, rhs: Self) {
221        *self = *self - rhs;
222    }
223}
224
225impl Mul for BaseElement {
226    type Output = Self;
227
228    fn mul(self, rhs: Self) -> Self {
229        Self(mul(self.0, rhs.0))
230    }
231}
232
233impl MulAssign for BaseElement {
234    fn mul_assign(&mut self, rhs: Self) {
235        *self = *self * rhs
236    }
237}
238
239impl Div for BaseElement {
240    type Output = Self;
241
242    fn div(self, rhs: Self) -> Self {
243        Self(mul(self.0, inv(rhs.0)))
244    }
245}
246
247impl DivAssign for BaseElement {
248    fn div_assign(&mut self, rhs: Self) {
249        *self = *self / rhs
250    }
251}
252
253impl Neg for BaseElement {
254    type Output = Self;
255
256    fn neg(self) -> Self {
257        Self(sub(0, self.0))
258    }
259}
260
261// QUADRATIC EXTENSION
262// ================================================================================================
263
264/// Defines a quadratic extension of the base field over an irreducible polynomial x<sup>2</sup> -
265/// x - 1. Thus, an extension element is defined as α + β * φ, where φ is a root of this polynomial,
266/// and α and β are base field elements.
267impl ExtensibleField<2> for BaseElement {
268    #[inline(always)]
269    fn mul(a: [Self; 2], b: [Self; 2]) -> [Self; 2] {
270        let z = a[0] * b[0];
271        [z + a[1] * b[1], (a[0] + a[1]) * (b[0] + b[1]) - z]
272    }
273
274    #[inline(always)]
275    fn mul_base(a: [Self; 2], b: Self) -> [Self; 2] {
276        [a[0] * b, a[1] * b]
277    }
278
279    #[inline(always)]
280    fn frobenius(x: [Self; 2]) -> [Self; 2] {
281        [x[0] + x[1], Self::ZERO - x[1]]
282    }
283}
284
285// CUBIC EXTENSION
286// ================================================================================================
287
288/// Cubic extension for this field is not implemented as quadratic extension already provides
289/// sufficient security level.
290impl ExtensibleField<3> for BaseElement {
291    fn mul(_a: [Self; 3], _b: [Self; 3]) -> [Self; 3] {
292        unimplemented!()
293    }
294
295    #[inline(always)]
296    fn mul_base(_a: [Self; 3], _b: Self) -> [Self; 3] {
297        unimplemented!()
298    }
299
300    #[inline(always)]
301    fn frobenius(_x: [Self; 3]) -> [Self; 3] {
302        unimplemented!()
303    }
304
305    fn is_supported() -> bool {
306        false
307    }
308}
309
310// TYPE CONVERSIONS
311// ================================================================================================
312
313impl From<u64> for BaseElement {
314    /// Converts a 64-bit value into a field element.
315    fn from(value: u64) -> Self {
316        BaseElement(value as u128)
317    }
318}
319
320impl From<u32> for BaseElement {
321    /// Converts a 32-bit value into a field element.
322    fn from(value: u32) -> Self {
323        BaseElement(value as u128)
324    }
325}
326
327impl From<u16> for BaseElement {
328    /// Converts a 16-bit value into a field element.
329    fn from(value: u16) -> Self {
330        BaseElement(value as u128)
331    }
332}
333
334impl From<u8> for BaseElement {
335    /// Converts an 8-bit value into a field element.
336    fn from(value: u8) -> Self {
337        BaseElement(value as u128)
338    }
339}
340
341impl TryFrom<u128> for BaseElement {
342    type Error = String;
343
344    fn try_from(value: u128) -> Result<Self, Self::Error> {
345        if value >= M {
346            Err(format!(
347                "invalid field element: value {value} is greater than or equal to the field modulus"
348            ))
349        } else {
350            Ok(Self::new(value))
351        }
352    }
353}
354
355impl TryFrom<&'_ [u8]> for BaseElement {
356    type Error = String;
357
358    /// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
359    /// is not a valid field element. The bytes are assumed to be in little-endian byte order.
360    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
361        let value =
362            bytes.try_into().map(u128::from_le_bytes).map_err(|error| format!("{error}"))?;
363        if value >= M {
364            return Err(format!(
365                "cannot convert bytes into a field element: \
366                value {value} is greater or equal to the field modulus"
367            ));
368        }
369        Ok(BaseElement(value))
370    }
371}
372
373impl AsBytes for BaseElement {
374    fn as_bytes(&self) -> &[u8] {
375        // TODO: take endianness into account
376        let self_ptr: *const BaseElement = self;
377        unsafe { slice::from_raw_parts(self_ptr as *const u8, BaseElement::ELEMENT_BYTES) }
378    }
379}
380
381// SERIALIZATION / DESERIALIZATION
382// ------------------------------------------------------------------------------------------------
383
384impl Serializable for BaseElement {
385    fn write_into<W: ByteWriter>(&self, target: &mut W) {
386        target.write_bytes(&self.0.to_le_bytes());
387    }
388
389    fn get_size_hint(&self) -> usize {
390        self.0.get_size_hint()
391    }
392}
393
394impl Deserializable for BaseElement {
395    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
396        let value = source.read_u128()?;
397        if value >= M {
398            return Err(DeserializationError::InvalidValue(format!(
399                "invalid field element: value {value} is greater than or equal to the field modulus"
400            )));
401        }
402        Ok(BaseElement(value))
403    }
404}
405
406// FINITE FIELD ARITHMETIC
407// ================================================================================================
408
409/// Computes (a + b) % m; a and b are assumed to be valid field elements.
410fn add(a: u128, b: u128) -> u128 {
411    let z = M - b;
412    if a < z {
413        M - z + a
414    } else {
415        a - z
416    }
417}
418
419/// Computes (a - b) % m; a and b are assumed to be valid field elements.
420fn sub(a: u128, b: u128) -> u128 {
421    if a < b {
422        M - b + a
423    } else {
424        a - b
425    }
426}
427
/// Computes (a * b) % m; a and b are assumed to be valid field elements.
///
/// `b` is split into 64-bit halves (`b = b_hi * 2^64 + b_lo`), and the product is built from two
/// 128x64-bit partial products held as 192-bit limb triples `(lo, mid, hi)`:
///
/// 1. `x = a * b_hi` is partially reduced by `mul_reduce`, which subtracts `hi * m`. If the
///    result still has a 1 in its top limb, `m` is subtracted once more so `x` fits in 128 bits.
/// 2. `y = a * b_lo + (x << 64)` is accumulated. A carry out of 192 bits is removed by
///    subtracting `m << 64`.
/// 3. `y` is partially reduced the same way. A final conditional subtraction of `m` brings the
///    result into the canonical range `[0, m)`.
///
/// Because `m = 2^128 - (45 * 2^40 - 1)`, subtracting `hi * m` from a 192-bit value leaves the
/// low 128 bits plus `hi * (45 * 2^40 - 1)`. That is why only a single extra overflow bit has to
/// be handled after each reduction.
fn mul(a: u128, b: u128) -> u128 {
    let (x0, x1, x2) = mul_128x64(a, (b >> 64) as u64); // x = a * b_hi
    let (mut x0, mut x1, x2) = mul_reduce(x0, x1, x2); // x = x - (x >> 128) * m
    if x2 == 1 {
        // if there was an overflow beyond 128 bits, subtract
        // modulus from the result to make sure it fits into
        // 128 bits; this can potentially be removed in favor
        // of checking overflow later
        let (t0, t1) = sub_modulus(x0, x1); // x = x - m
        x0 = t0;
        x1 = t1;
    }

    let (y0, y1, y2) = mul_128x64(a, b as u64); // y = a * b_lo

    // Add x (shifted left by one 64-bit limb) into the upper two limbs of y, propagating the
    // carry; y3 is the carry out of the 192-bit sum.
    let (mut y1, carry) = add64_with_carry(y1, x0, 0); // y = y + (x << 64)
    let (mut y2, y3) = add64_with_carry(y2, x1, carry);
    if y3 == 1 {
        // if there was an overflow beyond 192 bits, subtract
        // modulus * 2^64 from the result to make sure it fits
        // into 192 bits; this can potentially replace the
        // previous overflow check (but needs to be proven)
        let (t0, t1) = sub_modulus(y1, y2); // y = y - (m << 64)
        y1 = t0;
        y2 = t1;
    }

    let (mut z0, mut z1, z2) = mul_reduce(y0, y1, y2); // z = y - (y >> 128) * m

    // make sure z is smaller than m: either z overflowed 128 bits (z2 == 1), or z >= m, which is
    // detected limb-wise. The high limb of m is u64::MAX, so z1 can never exceed it and
    // equality plus a low-limb comparison is sufficient.
    if z2 == 1 || (z1 == (M >> 64) as u64 && z0 >= (M as u64)) {
        let (t0, t1) = sub_modulus(z0, z1); // z = z - m
        z0 = t0;
        z1 = t1;
    }

    ((z1 as u128) << 64) + (z0 as u128)
}
467
/// Computes y such that (x * y) % m = 1 except for when x = 0; in such a case,
/// 0 is returned; x is assumed to be a valid field element.
///
/// The inverse is found with a binary (shift-and-subtract) variant of the extended Euclidean
/// algorithm:
///
/// - `v` starts at `m`, and `u` starts at `x`, or at `x + m` when `x` is even, so `u` is odd.
/// - The loop repeatedly subtracts the smaller of `u` and `v` from the larger and shifts out
///   factors of two. Each step is mirrored on the accumulators `d` and `a`.
/// - Values that may exceed 128 bits (`u`, `d`, `a`) are kept as 192-bit little-endian limb
///   triples `(lo, mid, hi)`.
/// - Before an accumulator is halved, `m` is added to it if it is odd, so the shift is an exact
///   division by two modulo `m`.
fn inv(x: u128) -> u128 {
    if x == 0 {
        return 0;
    };

    // initialize v, a, u, and d variables
    let mut v = M;
    let (mut a0, mut a1, mut a2) = (0, 0, 0);
    let (mut u0, mut u1, mut u2) = if x & 1 == 1 {
        // u = x
        (x as u64, (x >> 64) as u64, 0)
    } else {
        // u = x + m (m is odd, so this makes u odd; it may need more than 128 bits)
        add_192x192(x as u64, (x >> 64) as u64, 0, M as u64, (M >> 64) as u64, 0)
    };
    // d = m - 1, i.e. -1 modulo m
    let (mut d0, mut d1, mut d2) = ((M as u64) - 1, (M >> 64) as u64, 0);

    // compute the inverse; the loop ends once v has been reduced to 1, at which point
    // a holds the result (possibly not yet reduced below m)
    while v != 1 {
        while u2 > 0 || ((u0 as u128) + ((u1 as u128) << 64)) > v {
            // u > v
            // u = u - v
            let (t0, t1, t2) = sub_192x192(u0, u1, u2, v as u64, (v >> 64) as u64, 0);
            u0 = t0;
            u1 = t1;
            u2 = t2;

            // d = d + a
            let (t0, t1, t2) = add_192x192(d0, d1, d2, a0, a1, a2);
            d0 = t0;
            d1 = t1;
            d2 = t2;

            // strip factors of two from u, halving d in lockstep
            while u0 & 1 == 0 {
                if d0 & 1 == 1 {
                    // d = d + m (makes d even so the shift below is exact modulo m)
                    let (t0, t1, t2) = add_192x192(d0, d1, d2, M as u64, (M >> 64) as u64, 0);
                    d0 = t0;
                    d1 = t1;
                    d2 = t2;
                }

                // u = u >> 1
                u0 = (u0 >> 1) | ((u1 & 1) << 63);
                u1 = (u1 >> 1) | ((u2 & 1) << 63);
                u2 >>= 1;

                // d = d >> 1
                d0 = (d0 >> 1) | ((d1 & 1) << 63);
                d1 = (d1 >> 1) | ((d2 & 1) << 63);
                d2 >>= 1;
            }
        }

        // v = v - u (u is less than v at this point)
        v -= (u0 as u128) + ((u1 as u128) << 64);

        // a = a + d
        let (t0, t1, t2) = add_192x192(a0, a1, a2, d0, d1, d2);
        a0 = t0;
        a1 = t1;
        a2 = t2;

        // strip factors of two from v, halving a in lockstep
        while v & 1 == 0 {
            if a0 & 1 == 1 {
                // a = a + m (makes a even so the shift below is exact modulo m)
                let (t0, t1, t2) = add_192x192(a0, a1, a2, M as u64, (M >> 64) as u64, 0);
                a0 = t0;
                a1 = t1;
                a2 = t2;
            }

            v >>= 1;

            // a = a >> 1
            a0 = (a0 >> 1) | ((a1 & 1) << 63);
            a1 = (a1 >> 1) | ((a2 & 1) << 63);
            a2 >>= 1;
        }
    }

    // a = a mod m: the accumulator may exceed m (and even 128 bits), so subtract m
    // until it is canonical
    let mut a = (a0 as u128) + ((a1 as u128) << 64);
    while a2 > 0 || a >= M {
        let (t0, t1, t2) = sub_192x192(a0, a1, a2, M as u64, (M >> 64) as u64, 0);
        a0 = t0;
        a1 = t1;
        a2 = t2;
        a = (a0 as u128) + ((a1 as u128) << 64);
    }

    a
}
564
565// HELPER FUNCTIONS
566// ================================================================================================
567
/// Multiplies a 128-bit value by a 64-bit value, returning the 192-bit product as little-endian
/// 64-bit limbs `(lo, mid, hi)`.
#[inline]
fn mul_128x64(a: u128, b: u64) -> (u64, u64, u64) {
    let b_wide = u128::from(b);
    // Low half of `a` times `b`: at most (2^64 - 1)^2, fits in u128.
    let low_product = u128::from(a as u64) * b_wide;
    // High half of `a` times `b`, plus the carry from the low product; still fits in u128.
    let high_product = (a >> 64) * b_wide + (low_product >> 64);
    (low_product as u64, high_product as u64, (high_product >> 64) as u64)
}
575
576#[inline]
577fn mul_reduce(z0: u64, z1: u64, z2: u64) -> (u64, u64, u64) {
578    let (q0, q1, q2) = mul_by_modulus(z2);
579    let (z0, z1, z2) = sub_192x192(z0, z1, z2, q0, q1, q2);
580    (z0, z1, z2)
581}
582
583#[inline]
584fn mul_by_modulus(a: u64) -> (u64, u64, u64) {
585    let a_lo = (a as u128).wrapping_mul(M);
586    let a_hi = if a == 0 { 0 } else { a - 1 };
587    (a_lo as u64, (a_lo >> 64) as u64, a_hi)
588}
589
590#[inline]
591fn sub_modulus(a_lo: u64, a_hi: u64) -> (u64, u64) {
592    let mut z = 0u128.wrapping_sub(M);
593    z = z.wrapping_add(a_lo as u128);
594    z = z.wrapping_add((a_hi as u128) << 64);
595    (z as u64, (z >> 64) as u64)
596}
597
/// Subtracts two 192-bit values given as little-endian 64-bit limbs, wrapping modulo 2^192.
#[inline]
fn sub_192x192(a0: u64, a1: u64, a2: u64, b0: u64, b1: u64, b2: u64) -> (u64, u64, u64) {
    let (lo, borrow_lo) = a0.overflowing_sub(b0);
    // At most one of the two middle-limb subtractions can borrow.
    let (mid_partial, borrow_mid_a) = a1.overflowing_sub(b1);
    let (mid, borrow_mid_b) = mid_partial.overflowing_sub(u64::from(borrow_lo));
    let borrow_mid = borrow_mid_a | borrow_mid_b;
    let hi = a2.wrapping_sub(b2).wrapping_sub(u64::from(borrow_mid));
    (lo, mid, hi)
}
605
/// Adds two 192-bit values given as little-endian 64-bit limbs, wrapping modulo 2^192.
#[inline]
fn add_192x192(a0: u64, a1: u64, a2: u64, b0: u64, b1: u64, b2: u64) -> (u64, u64, u64) {
    let (lo, carry_lo) = a0.overflowing_add(b0);
    // At most one of the two middle-limb additions can carry.
    let (mid_partial, carry_mid_a) = a1.overflowing_add(b1);
    let (mid, carry_mid_b) = mid_partial.overflowing_add(u64::from(carry_lo));
    let carry_mid = carry_mid_a | carry_mid_b;
    let hi = a2.wrapping_add(b2).wrapping_add(u64::from(carry_mid));
    (lo, mid, hi)
}
613
/// Adds two 64-bit limbs plus an incoming carry and returns `(low 64 bits, carry out)`.
///
/// The sum is formed in 128-bit arithmetic, so it cannot overflow. The carry out is everything
/// above bit 63.
#[inline]
const fn add64_with_carry(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let wide = a as u128 + b as u128 + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}