../../.cargo/katex-header.html

winter_math/field/extensions/
cubic.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::string::{String, ToString};
7use core::{
8    fmt,
9    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
10    slice,
11};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use utils::{
16    AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
17    Serializable, SliceReader,
18};
19
20use super::{ExtensibleField, ExtensionOf, FieldElement};
21
// CUBIC EXTENSION FIELD
23// ================================================================================================
24
/// Represents an element in a cubic extension of a [StarkField](crate::StarkField).
///
/// The extension element is defined as α + β * φ + γ * φ^2, where φ is a root of an irreducible
/// polynomial defined by the implementation of the [ExtensibleField] trait, and α, β, γ are base
/// field elements.
///
/// The three coefficients are stored in the order (α, β, γ). The struct is `#[repr(C)]`, which the
/// slice-reinterpreting helpers in this file rely on to view a `CubeExtension<B>` as three
/// consecutive `B` values.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct CubeExtension<B: ExtensibleField<3>>(B, B, B);
34
35impl<B: ExtensibleField<3>> CubeExtension<B> {
36    /// Returns a new extension element instantiated from the provided base elements.
37    pub const fn new(a: B, b: B, c: B) -> Self {
38        Self(a, b, c)
39    }
40
41    /// Returns true if the base field specified by B type parameter supports cubic extensions.
42    pub fn is_supported() -> bool {
43        <B as ExtensibleField<3>>::is_supported()
44    }
45
46    /// Returns an array of base field elements comprising this extension field element.
47    ///
48    /// The order of abase elements in the returned array is the same as the order in which
49    /// the elements are provided to the [CubeExtension::new()] constructor.
50    pub const fn to_base_elements(self) -> [B; 3] {
51        [self.0, self.1, self.2]
52    }
53}
54
55impl<B: ExtensibleField<3>> FieldElement for CubeExtension<B> {
56    type PositiveInteger = B::PositiveInteger;
57    type BaseField = B;
58
59    const EXTENSION_DEGREE: usize = 3;
60
61    const ELEMENT_BYTES: usize = B::ELEMENT_BYTES * Self::EXTENSION_DEGREE;
62    const IS_CANONICAL: bool = B::IS_CANONICAL;
63    const ZERO: Self = Self(B::ZERO, B::ZERO, B::ZERO);
64    const ONE: Self = Self(B::ONE, B::ZERO, B::ZERO);
65
66    // ALGEBRA
67    // --------------------------------------------------------------------------------------------
68
69    #[inline]
70    fn double(self) -> Self {
71        Self(self.0.double(), self.1.double(), self.2.double())
72    }
73
74    #[inline]
75    fn square(self) -> Self {
76        let a = <B as ExtensibleField<3>>::square([self.0, self.1, self.2]);
77        Self(a[0], a[1], a[2])
78    }
79
80    #[inline]
81    fn inv(self) -> Self {
82        if self == Self::ZERO {
83            return self;
84        }
85
86        let x = [self.0, self.1, self.2];
87        let c1 = <B as ExtensibleField<3>>::frobenius(x);
88        let c2 = <B as ExtensibleField<3>>::frobenius(c1);
89        let numerator = <B as ExtensibleField<3>>::mul(c1, c2);
90
91        let norm = <B as ExtensibleField<3>>::mul(x, numerator);
92        debug_assert_eq!(norm[1], B::ZERO, "norm must be in the base field");
93        debug_assert_eq!(norm[2], B::ZERO, "norm must be in the base field");
94        let denom_inv = norm[0].inv();
95
96        Self(numerator[0] * denom_inv, numerator[1] * denom_inv, numerator[2] * denom_inv)
97    }
98
99    #[inline]
100    fn conjugate(&self) -> Self {
101        let result = <B as ExtensibleField<3>>::frobenius([self.0, self.1, self.2]);
102        Self(result[0], result[1], result[2])
103    }
104
105    // BASE ELEMENT CONVERSIONS
106    // --------------------------------------------------------------------------------------------
107
108    fn base_element(&self, i: usize) -> Self::BaseField {
109        match i {
110            0 => self.0,
111            1 => self.1,
112            2 => self.2,
113            _ => panic!("element index must be smaller than 3, but was {i}"),
114        }
115    }
116
117    fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
118        let ptr = elements.as_ptr();
119        let len = elements.len() * Self::EXTENSION_DEGREE;
120        unsafe { slice::from_raw_parts(ptr as *const Self::BaseField, len) }
121    }
122
123    fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
124        assert!(
125            elements.len().is_multiple_of(Self::EXTENSION_DEGREE),
126            "number of base elements must be divisible by 3, but was {}",
127            elements.len()
128        );
129
130        let ptr = elements.as_ptr();
131        let len = elements.len() / Self::EXTENSION_DEGREE;
132        unsafe { slice::from_raw_parts(ptr as *const Self, len) }
133    }
134
135    // SERIALIZATION / DESERIALIZATION
136    // --------------------------------------------------------------------------------------------
137
138    fn elements_as_bytes(elements: &[Self]) -> &[u8] {
139        unsafe {
140            slice::from_raw_parts(
141                elements.as_ptr() as *const u8,
142                elements.len() * Self::ELEMENT_BYTES,
143            )
144        }
145    }
146
147    unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
148        if !bytes.len().is_multiple_of(Self::ELEMENT_BYTES) {
149            return Err(DeserializationError::InvalidValue(format!(
150                "number of bytes ({}) does not divide into whole number of field elements",
151                bytes.len(),
152            )));
153        }
154
155        let p = bytes.as_ptr();
156        let len = bytes.len() / Self::ELEMENT_BYTES;
157
158        // make sure the bytes are aligned on the boundary consistent with base element alignment
159        if !(p as usize).is_multiple_of(Self::BaseField::ELEMENT_BYTES) {
160            return Err(DeserializationError::InvalidValue(
161                "slice memory alignment is not valid for this field element type".to_string(),
162            ));
163        }
164
165        Ok(slice::from_raw_parts(p as *const Self, len))
166    }
167}
168
169impl<B: ExtensibleField<3>> ExtensionOf<B> for CubeExtension<B> {
170    #[inline(always)]
171    fn mul_base(self, other: B) -> Self {
172        let result = <B as ExtensibleField<3>>::mul_base([self.0, self.1, self.2], other);
173        Self(result[0], result[1], result[2])
174    }
175}
176
177impl<B: ExtensibleField<3>> Randomizable for CubeExtension<B> {
178    const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
179
180    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
181        Self::try_from(bytes).ok()
182    }
183}
184
185impl<B: ExtensibleField<3>> fmt::Display for CubeExtension<B> {
186    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187        write!(f, "({}, {}, {})", self.0, self.1, self.2)
188    }
189}
190
191// OVERLOADED OPERATORS
192// ------------------------------------------------------------------------------------------------
193
194impl<B: ExtensibleField<3>> Add for CubeExtension<B> {
195    type Output = Self;
196
197    #[inline]
198    fn add(self, rhs: Self) -> Self {
199        Self(self.0 + rhs.0, self.1 + rhs.1, self.2 + rhs.2)
200    }
201}
202
203impl<B: ExtensibleField<3>> AddAssign for CubeExtension<B> {
204    #[inline]
205    fn add_assign(&mut self, rhs: Self) {
206        *self = *self + rhs
207    }
208}
209
210impl<B: ExtensibleField<3>> Sub for CubeExtension<B> {
211    type Output = Self;
212
213    #[inline]
214    fn sub(self, rhs: Self) -> Self {
215        Self(self.0 - rhs.0, self.1 - rhs.1, self.2 - rhs.2)
216    }
217}
218
219impl<B: ExtensibleField<3>> SubAssign for CubeExtension<B> {
220    #[inline]
221    fn sub_assign(&mut self, rhs: Self) {
222        *self = *self - rhs;
223    }
224}
225
226impl<B: ExtensibleField<3>> Mul for CubeExtension<B> {
227    type Output = Self;
228
229    #[inline]
230    fn mul(self, rhs: Self) -> Self {
231        let result =
232            <B as ExtensibleField<3>>::mul([self.0, self.1, self.2], [rhs.0, rhs.1, rhs.2]);
233        Self(result[0], result[1], result[2])
234    }
235}
236
237impl<B: ExtensibleField<3>> MulAssign for CubeExtension<B> {
238    #[inline]
239    fn mul_assign(&mut self, rhs: Self) {
240        *self = *self * rhs
241    }
242}
243
244impl<B: ExtensibleField<3>> Div for CubeExtension<B> {
245    type Output = Self;
246
247    #[inline]
248    #[allow(clippy::suspicious_arithmetic_impl)]
249    fn div(self, rhs: Self) -> Self {
250        self * rhs.inv()
251    }
252}
253
254impl<B: ExtensibleField<3>> DivAssign for CubeExtension<B> {
255    #[inline]
256    fn div_assign(&mut self, rhs: Self) {
257        *self = *self / rhs
258    }
259}
260
261impl<B: ExtensibleField<3>> Neg for CubeExtension<B> {
262    type Output = Self;
263
264    #[inline]
265    fn neg(self) -> Self {
266        Self(-self.0, -self.1, -self.2)
267    }
268}
269
270// TYPE CONVERSIONS
271// ------------------------------------------------------------------------------------------------
272
273impl<B: ExtensibleField<3>> From<B> for CubeExtension<B> {
274    fn from(value: B) -> Self {
275        Self(value, B::ZERO, B::ZERO)
276    }
277}
278
279impl<B: ExtensibleField<3>> From<u32> for CubeExtension<B> {
280    fn from(value: u32) -> Self {
281        Self(B::from(value), B::ZERO, B::ZERO)
282    }
283}
284
285impl<B: ExtensibleField<3>> From<u16> for CubeExtension<B> {
286    fn from(value: u16) -> Self {
287        Self(B::from(value), B::ZERO, B::ZERO)
288    }
289}
290
291impl<B: ExtensibleField<3>> From<u8> for CubeExtension<B> {
292    fn from(value: u8) -> Self {
293        Self(B::from(value), B::ZERO, B::ZERO)
294    }
295}
296
297impl<B: ExtensibleField<3>> TryFrom<u64> for CubeExtension<B> {
298    type Error = String;
299
300    fn try_from(value: u64) -> Result<Self, Self::Error> {
301        match B::try_from(value) {
302            Ok(elem) => Ok(Self::from(elem)),
303            Err(_) => Err(format!(
304                "invalid field element: value {value} is greater than or equal to the field modulus"
305            )),
306        }
307    }
308}
309
310impl<B: ExtensibleField<3>> TryFrom<u128> for CubeExtension<B> {
311    type Error = String;
312
313    fn try_from(value: u128) -> Result<Self, Self::Error> {
314        match B::try_from(value) {
315            Ok(elem) => Ok(Self::from(elem)),
316            Err(_) => Err(format!(
317                "invalid field element: value {value} is greater than or equal to the field modulus"
318            )),
319        }
320    }
321}
322
323impl<B: ExtensibleField<3>> TryFrom<&'_ [u8]> for CubeExtension<B> {
324    type Error = DeserializationError;
325
326    /// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
327    /// is not a valid field element. The bytes are assumed to be in little-endian byte order.
328    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
329        if bytes.len() < Self::ELEMENT_BYTES {
330            return Err(DeserializationError::InvalidValue(format!(
331                "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
332                Self::ELEMENT_BYTES,
333                bytes.len(),
334            )));
335        }
336        if bytes.len() > Self::ELEMENT_BYTES {
337            return Err(DeserializationError::InvalidValue(format!(
338                "too many bytes for a field element; expected {} bytes, but was {} bytes",
339                Self::ELEMENT_BYTES,
340                bytes.len(),
341            )));
342        }
343        let mut reader = SliceReader::new(bytes);
344        Self::read_from(&mut reader)
345    }
346}
347
348impl<B: ExtensibleField<3>> AsBytes for CubeExtension<B> {
349    fn as_bytes(&self) -> &[u8] {
350        // TODO: take endianness into account
351        let self_ptr: *const Self = self;
352        unsafe { slice::from_raw_parts(self_ptr as *const u8, Self::ELEMENT_BYTES) }
353    }
354}
355
356// SERIALIZATION / DESERIALIZATION
357// ------------------------------------------------------------------------------------------------
358
359impl<B: ExtensibleField<3>> Serializable for CubeExtension<B> {
360    fn write_into<W: ByteWriter>(&self, target: &mut W) {
361        self.0.write_into(target);
362        self.1.write_into(target);
363        self.2.write_into(target);
364    }
365}
366
367impl<B: ExtensibleField<3>> Deserializable for CubeExtension<B> {
368    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
369        let value0 = B::read_from(source)?;
370        let value1 = B::read_from(source)?;
371        let value2 = B::read_from(source)?;
372        Ok(Self(value0, value1, value2))
373    }
374}
375
376// TESTS
377// ================================================================================================
378
#[cfg(test)]
mod tests {
    use rand_utils::rand_value;

    use super::{CubeExtension, DeserializationError, FieldElement};
    use crate::field::f64::BaseElement;

    /// The concrete extension field exercised by all tests in this module.
    type Ext = CubeExtension<BaseElement>;

    // HELPERS
    // --------------------------------------------------------------------------------------------

    /// Builds an extension element from three integer coefficients.
    fn ext(a: u64, b: u64, c: u64) -> Ext {
        Ext::new(BaseElement::new(a), BaseElement::new(b), BaseElement::new(c))
    }

    /// Returns the two fixed elements shared by the serialization and utility tests.
    fn sample_elements() -> Vec<Ext> {
        vec![ext(1, 2, 3), ext(4, 5, 6)]
    }

    /// Concatenates the little-endian bytes of the inner values of all coefficients, in order.
    fn to_le_bytes(elements: &[Ext]) -> Vec<u8> {
        elements
            .iter()
            .flat_map(|element| element.to_base_elements())
            .flat_map(|coeff| coeff.inner().to_le_bytes())
            .collect()
    }

    // BASIC ALGEBRA
    // --------------------------------------------------------------------------------------------

    #[test]
    fn add() {
        // adding zero must leave the element unchanged
        let r: Ext = rand_value();
        assert_eq!(r, r + Ext::ZERO);

        // addition of random values must be coefficient-wise
        let r1: Ext = rand_value();
        let r2: Ext = rand_value();

        let expected = Ext::new(r1.0 + r2.0, r1.1 + r2.1, r1.2 + r2.2);
        assert_eq!(expected, r1 + r2);
    }

    #[test]
    fn sub() {
        // subtracting zero must leave the element unchanged
        let r: Ext = rand_value();
        assert_eq!(r, r - Ext::ZERO);

        // subtraction of random values must be coefficient-wise
        let r1: Ext = rand_value();
        let r2: Ext = rand_value();

        let expected = Ext::new(r1.0 - r2.0, r1.1 - r2.1, r1.2 - r2.2);
        assert_eq!(expected, r1 - r2);
    }

    // SERIALIZATION / DESERIALIZATION
    // --------------------------------------------------------------------------------------------

    #[test]
    fn elements_as_bytes() {
        let source = sample_elements();
        let expected = to_le_bytes(&source);

        assert_eq!(expected, Ext::elements_as_bytes(&source));
    }

    #[test]
    fn bytes_as_elements() {
        let elements = sample_elements();

        // two full elements (48 bytes) followed by one extra base element (8 bytes)
        let mut bytes = to_le_bytes(&elements);
        bytes.extend_from_slice(&BaseElement::new(5).inner().to_le_bytes());

        // exactly two elements
        // NOTE(review): this relies on the byte vector's allocation being 8-byte aligned
        let result = unsafe { Ext::bytes_as_elements(&bytes[..48]) };
        assert!(result.is_ok());
        assert_eq!(elements, result.unwrap());

        // length is not a multiple of the element size
        let result = unsafe { Ext::bytes_as_elements(&bytes) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // slice shifted by one byte
        let result = unsafe { Ext::bytes_as_elements(&bytes[1..]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
    }

    // UTILITIES
    // --------------------------------------------------------------------------------------------

    #[test]
    fn as_base_elements() {
        let elements = sample_elements();

        let expected: Vec<BaseElement> = (1..=6).map(BaseElement::new).collect();

        assert_eq!(expected, Ext::slice_as_base_elements(&elements));
    }
}