../../.cargo/katex-header.html

winter_math/field/extensions/
quadratic.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::string::{String, ToString};
7use core::{
8    fmt,
9    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
10    slice,
11};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use utils::{
16    AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
17    Serializable, SliceReader,
18};
19
20use super::{ExtensibleField, ExtensionOf, FieldElement};
21
22// QUADRATIC EXTENSION FIELD
23// ================================================================================================
24
25/// Represents an element in a quadratic extension of a [StarkField](crate::StarkField).
26///
27/// The extension element is defined as α + β * φ, where φ is a root of in irreducible polynomial
28/// defined by the implementation of the [ExtensibleField] trait, and α and β are base field
29/// elements.
30#[repr(C)]
31#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
32#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
33pub struct QuadExtension<B: ExtensibleField<2>>(B, B);
34
35impl<B: ExtensibleField<2>> QuadExtension<B> {
36    /// Returns a new extension element instantiated from the provided base elements.
37    pub const fn new(a: B, b: B) -> Self {
38        Self(a, b)
39    }
40
41    /// Returns true if the base field specified by B type parameter supports quadratic extensions.
42    pub fn is_supported() -> bool {
43        <B as ExtensibleField<2>>::is_supported()
44    }
45
46    /// Returns an array of base field elements comprising this extension field element.
47    ///
48    /// The order of abase elements in the returned array is the same as the order in which
49    /// the elements are provided to the [QuadExtension::new()] constructor.
50    pub const fn to_base_elements(self) -> [B; 2] {
51        [self.0, self.1]
52    }
53}
54
55impl<B: ExtensibleField<2>> FieldElement for QuadExtension<B> {
56    type PositiveInteger = B::PositiveInteger;
57    type BaseField = B;
58
59    const EXTENSION_DEGREE: usize = 2;
60
61    const ELEMENT_BYTES: usize = B::ELEMENT_BYTES * Self::EXTENSION_DEGREE;
62    const IS_CANONICAL: bool = B::IS_CANONICAL;
63    const ZERO: Self = Self(B::ZERO, B::ZERO);
64    const ONE: Self = Self(B::ONE, B::ZERO);
65
66    // ALGEBRA
67    // --------------------------------------------------------------------------------------------
68
69    #[inline]
70    fn double(self) -> Self {
71        Self(self.0.double(), self.1.double())
72    }
73
74    #[inline]
75    fn square(self) -> Self {
76        let a = <B as ExtensibleField<2>>::square([self.0, self.1]);
77        Self(a[0], a[1])
78    }
79
80    #[inline]
81    fn inv(self) -> Self {
82        if self == Self::ZERO {
83            return self;
84        }
85
86        let x = [self.0, self.1];
87        let numerator = <B as ExtensibleField<2>>::frobenius(x);
88
89        let norm = <B as ExtensibleField<2>>::mul(x, numerator);
90        debug_assert_eq!(norm[1], B::ZERO, "norm must be in the base field");
91        let denom_inv = norm[0].inv();
92
93        Self(numerator[0] * denom_inv, numerator[1] * denom_inv)
94    }
95
96    #[inline]
97    fn conjugate(&self) -> Self {
98        let result = <B as ExtensibleField<2>>::frobenius([self.0, self.1]);
99        Self(result[0], result[1])
100    }
101
102    // BASE ELEMENT CONVERSIONS
103    // --------------------------------------------------------------------------------------------
104
105    fn base_element(&self, i: usize) -> Self::BaseField {
106        match i {
107            0 => self.0,
108            1 => self.1,
109            _ => panic!("element index must be smaller than 2, but was {i}"),
110        }
111    }
112
113    fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
114        let ptr = elements.as_ptr();
115        let len = elements.len() * Self::EXTENSION_DEGREE;
116        unsafe { slice::from_raw_parts(ptr as *const Self::BaseField, len) }
117    }
118
119    fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
120        assert!(
121            elements.len() % Self::EXTENSION_DEGREE == 0,
122            "number of base elements must be divisible by 2, but was {}",
123            elements.len()
124        );
125
126        let ptr = elements.as_ptr();
127        let len = elements.len() / Self::EXTENSION_DEGREE;
128        unsafe { slice::from_raw_parts(ptr as *const Self, len) }
129    }
130
131    // SERIALIZATION / DESERIALIZATION
132    // --------------------------------------------------------------------------------------------
133
134    fn elements_as_bytes(elements: &[Self]) -> &[u8] {
135        unsafe {
136            slice::from_raw_parts(
137                elements.as_ptr() as *const u8,
138                elements.len() * Self::ELEMENT_BYTES,
139            )
140        }
141    }
142
143    unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
144        if bytes.len() % Self::ELEMENT_BYTES != 0 {
145            return Err(DeserializationError::InvalidValue(format!(
146                "number of bytes ({}) does not divide into whole number of field elements",
147                bytes.len(),
148            )));
149        }
150
151        let p = bytes.as_ptr();
152        let len = bytes.len() / Self::ELEMENT_BYTES;
153
154        // make sure the bytes are aligned on the boundary consistent with base element alignment
155        if (p as usize) % Self::BaseField::ELEMENT_BYTES != 0 {
156            return Err(DeserializationError::InvalidValue(
157                "slice memory alignment is not valid for this field element type".to_string(),
158            ));
159        }
160
161        Ok(slice::from_raw_parts(p as *const Self, len))
162    }
163}
164
165impl<B: ExtensibleField<2>> ExtensionOf<B> for QuadExtension<B> {
166    #[inline(always)]
167    fn mul_base(self, other: B) -> Self {
168        let result = <B as ExtensibleField<2>>::mul_base([self.0, self.1], other);
169        Self(result[0], result[1])
170    }
171}
172
173impl<B: ExtensibleField<2>> Randomizable for QuadExtension<B> {
174    const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
175
176    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
177        Self::try_from(bytes).ok()
178    }
179}
180
181impl<B: ExtensibleField<2>> fmt::Display for QuadExtension<B> {
182    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183        write!(f, "({}, {})", self.0, self.1)
184    }
185}
186
187// OVERLOADED OPERATORS
188// ------------------------------------------------------------------------------------------------
189
190impl<B: ExtensibleField<2>> Add for QuadExtension<B> {
191    type Output = Self;
192
193    #[inline]
194    fn add(self, rhs: Self) -> Self {
195        Self(self.0 + rhs.0, self.1 + rhs.1)
196    }
197}
198
199impl<B: ExtensibleField<2>> AddAssign for QuadExtension<B> {
200    #[inline]
201    fn add_assign(&mut self, rhs: Self) {
202        *self = *self + rhs
203    }
204}
205
206impl<B: ExtensibleField<2>> Sub for QuadExtension<B> {
207    type Output = Self;
208
209    #[inline]
210    fn sub(self, rhs: Self) -> Self {
211        Self(self.0 - rhs.0, self.1 - rhs.1)
212    }
213}
214
215impl<B: ExtensibleField<2>> SubAssign for QuadExtension<B> {
216    #[inline]
217    fn sub_assign(&mut self, rhs: Self) {
218        *self = *self - rhs;
219    }
220}
221
222impl<B: ExtensibleField<2>> Mul for QuadExtension<B> {
223    type Output = Self;
224
225    #[inline]
226    fn mul(self, rhs: Self) -> Self {
227        let result = <B as ExtensibleField<2>>::mul([self.0, self.1], [rhs.0, rhs.1]);
228        Self(result[0], result[1])
229    }
230}
231
232impl<B: ExtensibleField<2>> MulAssign for QuadExtension<B> {
233    #[inline]
234    fn mul_assign(&mut self, rhs: Self) {
235        *self = *self * rhs
236    }
237}
238
239impl<B: ExtensibleField<2>> Div for QuadExtension<B> {
240    type Output = Self;
241
242    #[inline]
243    #[allow(clippy::suspicious_arithmetic_impl)]
244    fn div(self, rhs: Self) -> Self {
245        self * rhs.inv()
246    }
247}
248
249impl<B: ExtensibleField<2>> DivAssign for QuadExtension<B> {
250    #[inline]
251    fn div_assign(&mut self, rhs: Self) {
252        *self = *self / rhs
253    }
254}
255
256impl<B: ExtensibleField<2>> Neg for QuadExtension<B> {
257    type Output = Self;
258
259    #[inline]
260    fn neg(self) -> Self {
261        Self(-self.0, -self.1)
262    }
263}
264
265// TYPE CONVERSIONS
266// ------------------------------------------------------------------------------------------------
267
268impl<B: ExtensibleField<2>> From<B> for QuadExtension<B> {
269    fn from(value: B) -> Self {
270        Self(value, B::ZERO)
271    }
272}
273
274impl<B: ExtensibleField<2>> From<u32> for QuadExtension<B> {
275    fn from(value: u32) -> Self {
276        Self(B::from(value), B::ZERO)
277    }
278}
279
280impl<B: ExtensibleField<2>> From<u16> for QuadExtension<B> {
281    fn from(value: u16) -> Self {
282        Self(B::from(value), B::ZERO)
283    }
284}
285
286impl<B: ExtensibleField<2>> From<u8> for QuadExtension<B> {
287    fn from(value: u8) -> Self {
288        Self(B::from(value), B::ZERO)
289    }
290}
291
292impl<B: ExtensibleField<2>> TryFrom<u64> for QuadExtension<B> {
293    type Error = String;
294
295    fn try_from(value: u64) -> Result<Self, Self::Error> {
296        match B::try_from(value) {
297            Ok(elem) => Ok(Self::from(elem)),
298            Err(_) => Err(format!(
299                "invalid field element: value {value} is greater than or equal to the field modulus"
300            )),
301        }
302    }
303}
304
305impl<B: ExtensibleField<2>> TryFrom<u128> for QuadExtension<B> {
306    type Error = String;
307
308    fn try_from(value: u128) -> Result<Self, Self::Error> {
309        match B::try_from(value) {
310            Ok(elem) => Ok(Self::from(elem)),
311            Err(_) => Err(format!(
312                "invalid field element: value {value} is greater than or equal to the field modulus"
313            )),
314        }
315    }
316}
317
318impl<B: ExtensibleField<2>> TryFrom<&'_ [u8]> for QuadExtension<B> {
319    type Error = DeserializationError;
320
321    /// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
322    /// is not a valid field element. The bytes are assumed to be in little-endian byte order.
323    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
324        if bytes.len() < Self::ELEMENT_BYTES {
325            return Err(DeserializationError::InvalidValue(format!(
326                "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
327                Self::ELEMENT_BYTES,
328                bytes.len(),
329            )));
330        }
331        if bytes.len() > Self::ELEMENT_BYTES {
332            return Err(DeserializationError::InvalidValue(format!(
333                "too many bytes for a field element; expected {} bytes, but was {} bytes",
334                Self::ELEMENT_BYTES,
335                bytes.len(),
336            )));
337        }
338        let mut reader = SliceReader::new(bytes);
339        Self::read_from(&mut reader)
340    }
341}
342
343impl<B: ExtensibleField<2>> AsBytes for QuadExtension<B> {
344    fn as_bytes(&self) -> &[u8] {
345        // TODO: take endianness into account
346        let self_ptr: *const Self = self;
347        unsafe { slice::from_raw_parts(self_ptr as *const u8, Self::ELEMENT_BYTES) }
348    }
349}
350
351// SERIALIZATION / DESERIALIZATION
352// ------------------------------------------------------------------------------------------------
353
354impl<B: ExtensibleField<2>> Serializable for QuadExtension<B> {
355    fn write_into<W: ByteWriter>(&self, target: &mut W) {
356        self.0.write_into(target);
357        self.1.write_into(target);
358    }
359}
360
361impl<B: ExtensibleField<2>> Deserializable for QuadExtension<B> {
362    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
363        let value0 = B::read_from(source)?;
364        let value1 = B::read_from(source)?;
365        Ok(Self(value0, value1))
366    }
367}
368
369// TESTS
370// ================================================================================================
371
#[cfg(test)]
mod tests {
    use rand_utils::rand_value;

    use super::{DeserializationError, FieldElement, QuadExtension};
    use crate::field::f64::BaseElement;

    // BASIC ALGEBRA
    // --------------------------------------------------------------------------------------------

    #[test]
    fn add() {
        // identity
        let r: QuadExtension<BaseElement> = rand_value();
        assert_eq!(r, r + QuadExtension::<BaseElement>::ZERO);

        // test random values
        let r1: QuadExtension<BaseElement> = rand_value();
        let r2: QuadExtension<BaseElement> = rand_value();

        let expected = QuadExtension(r1.0 + r2.0, r1.1 + r2.1);
        assert_eq!(expected, r1 + r2);
    }

    #[test]
    fn sub() {
        // identity
        let r: QuadExtension<BaseElement> = rand_value();
        assert_eq!(r, r - QuadExtension::<BaseElement>::ZERO);

        // test random values
        let r1: QuadExtension<BaseElement> = rand_value();
        let r2: QuadExtension<BaseElement> = rand_value();

        let expected = QuadExtension(r1.0 - r2.0, r1.1 - r2.1);
        assert_eq!(expected, r1 - r2);
    }

    #[test]
    fn mul_inv_div() {
        let zero = QuadExtension::<BaseElement>::ZERO;
        let one = QuadExtension::<BaseElement>::ONE;

        // the inverse of zero is defined to be zero
        assert_eq!(zero, zero.inv());

        // multiplicative identity
        let r: QuadExtension<BaseElement> = rand_value();
        assert_eq!(r, r * one);

        // a non-zero element times its inverse is one, and so is the element divided by itself
        if r != zero {
            assert_eq!(one, r * r.inv());
            assert_eq!(one, r / r);
        }
    }

    // SERIALIZATION / DESERIALIZATION
    // --------------------------------------------------------------------------------------------

    #[test]
    fn elements_as_bytes() {
        let source = vec![
            QuadExtension(BaseElement::new(1), BaseElement::new(2)),
            QuadExtension(BaseElement::new(3), BaseElement::new(4)),
        ];

        let mut expected = vec![];
        expected.extend_from_slice(&source[0].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[0].1.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].1.inner().to_le_bytes());

        assert_eq!(expected, QuadExtension::<BaseElement>::elements_as_bytes(&source));
    }

    #[test]
    fn bytes_as_elements() {
        let elements = vec![
            QuadExtension(BaseElement::new(1), BaseElement::new(2)),
            QuadExtension(BaseElement::new(3), BaseElement::new(4)),
        ];

        let mut bytes = vec![];
        bytes.extend_from_slice(&elements[0].0.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[0].1.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[1].0.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[1].1.inner().to_le_bytes());
        bytes.extend_from_slice(&BaseElement::new(5).inner().to_le_bytes());

        // well-formed: 32 bytes starting at the (allocator-aligned) start of the buffer
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[..32]) };
        assert!(result.is_ok());
        assert_eq!(elements, result.unwrap());

        // wrong length: 40 bytes is not a whole number of 16-byte elements
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // misaligned: a correctly sized (32-byte) slice starting one byte in, so that the
        // alignment check, not the length check, is what rejects it
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[1..33]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
    }

    // UTILITIES
    // --------------------------------------------------------------------------------------------

    #[test]
    fn as_base_elements() {
        let elements = vec![
            QuadExtension(BaseElement::new(1), BaseElement::new(2)),
            QuadExtension(BaseElement::new(3), BaseElement::new(4)),
        ];

        let expected = vec![
            BaseElement::new(1),
            BaseElement::new(2),
            BaseElement::new(3),
            BaseElement::new(4),
        ];

        assert_eq!(expected, QuadExtension::<BaseElement>::slice_as_base_elements(&elements));
    }
}