winter_math/field/extensions/
quadratic.rs1use alloc::string::{String, ToString};
7use core::{
8 fmt,
9 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
10 slice,
11};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use utils::{
16 AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
17 Serializable, SliceReader,
18};
19
20use super::{ExtensibleField, ExtensionOf, FieldElement};
21
22#[repr(C)]
31#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
32#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
33pub struct QuadExtension<B: ExtensibleField<2>>(B, B);
34
35impl<B: ExtensibleField<2>> QuadExtension<B> {
36 pub const fn new(a: B, b: B) -> Self {
38 Self(a, b)
39 }
40
41 pub fn is_supported() -> bool {
43 <B as ExtensibleField<2>>::is_supported()
44 }
45
46 pub const fn to_base_elements(self) -> [B; 2] {
51 [self.0, self.1]
52 }
53}
54
55impl<B: ExtensibleField<2>> FieldElement for QuadExtension<B> {
56 type PositiveInteger = B::PositiveInteger;
57 type BaseField = B;
58
59 const EXTENSION_DEGREE: usize = 2;
60
61 const ELEMENT_BYTES: usize = B::ELEMENT_BYTES * Self::EXTENSION_DEGREE;
62 const IS_CANONICAL: bool = B::IS_CANONICAL;
63 const ZERO: Self = Self(B::ZERO, B::ZERO);
64 const ONE: Self = Self(B::ONE, B::ZERO);
65
66 #[inline]
70 fn double(self) -> Self {
71 Self(self.0.double(), self.1.double())
72 }
73
74 #[inline]
75 fn square(self) -> Self {
76 let a = <B as ExtensibleField<2>>::square([self.0, self.1]);
77 Self(a[0], a[1])
78 }
79
80 #[inline]
81 fn inv(self) -> Self {
82 if self == Self::ZERO {
83 return self;
84 }
85
86 let x = [self.0, self.1];
87 let numerator = <B as ExtensibleField<2>>::frobenius(x);
88
89 let norm = <B as ExtensibleField<2>>::mul(x, numerator);
90 debug_assert_eq!(norm[1], B::ZERO, "norm must be in the base field");
91 let denom_inv = norm[0].inv();
92
93 Self(numerator[0] * denom_inv, numerator[1] * denom_inv)
94 }
95
96 #[inline]
97 fn conjugate(&self) -> Self {
98 let result = <B as ExtensibleField<2>>::frobenius([self.0, self.1]);
99 Self(result[0], result[1])
100 }
101
102 fn base_element(&self, i: usize) -> Self::BaseField {
106 match i {
107 0 => self.0,
108 1 => self.1,
109 _ => panic!("element index must be smaller than 2, but was {i}"),
110 }
111 }
112
113 fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
114 let ptr = elements.as_ptr();
115 let len = elements.len() * Self::EXTENSION_DEGREE;
116 unsafe { slice::from_raw_parts(ptr as *const Self::BaseField, len) }
117 }
118
119 fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
120 assert!(
121 elements.len() % Self::EXTENSION_DEGREE == 0,
122 "number of base elements must be divisible by 2, but was {}",
123 elements.len()
124 );
125
126 let ptr = elements.as_ptr();
127 let len = elements.len() / Self::EXTENSION_DEGREE;
128 unsafe { slice::from_raw_parts(ptr as *const Self, len) }
129 }
130
131 fn elements_as_bytes(elements: &[Self]) -> &[u8] {
135 unsafe {
136 slice::from_raw_parts(
137 elements.as_ptr() as *const u8,
138 elements.len() * Self::ELEMENT_BYTES,
139 )
140 }
141 }
142
143 unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
144 if bytes.len() % Self::ELEMENT_BYTES != 0 {
145 return Err(DeserializationError::InvalidValue(format!(
146 "number of bytes ({}) does not divide into whole number of field elements",
147 bytes.len(),
148 )));
149 }
150
151 let p = bytes.as_ptr();
152 let len = bytes.len() / Self::ELEMENT_BYTES;
153
154 if (p as usize) % Self::BaseField::ELEMENT_BYTES != 0 {
156 return Err(DeserializationError::InvalidValue(
157 "slice memory alignment is not valid for this field element type".to_string(),
158 ));
159 }
160
161 Ok(slice::from_raw_parts(p as *const Self, len))
162 }
163}
164
165impl<B: ExtensibleField<2>> ExtensionOf<B> for QuadExtension<B> {
166 #[inline(always)]
167 fn mul_base(self, other: B) -> Self {
168 let result = <B as ExtensibleField<2>>::mul_base([self.0, self.1], other);
169 Self(result[0], result[1])
170 }
171}
172
173impl<B: ExtensibleField<2>> Randomizable for QuadExtension<B> {
174 const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
175
176 fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
177 Self::try_from(bytes).ok()
178 }
179}
180
181impl<B: ExtensibleField<2>> fmt::Display for QuadExtension<B> {
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 write!(f, "({}, {})", self.0, self.1)
184 }
185}
186
187impl<B: ExtensibleField<2>> Add for QuadExtension<B> {
191 type Output = Self;
192
193 #[inline]
194 fn add(self, rhs: Self) -> Self {
195 Self(self.0 + rhs.0, self.1 + rhs.1)
196 }
197}
198
199impl<B: ExtensibleField<2>> AddAssign for QuadExtension<B> {
200 #[inline]
201 fn add_assign(&mut self, rhs: Self) {
202 *self = *self + rhs
203 }
204}
205
206impl<B: ExtensibleField<2>> Sub for QuadExtension<B> {
207 type Output = Self;
208
209 #[inline]
210 fn sub(self, rhs: Self) -> Self {
211 Self(self.0 - rhs.0, self.1 - rhs.1)
212 }
213}
214
215impl<B: ExtensibleField<2>> SubAssign for QuadExtension<B> {
216 #[inline]
217 fn sub_assign(&mut self, rhs: Self) {
218 *self = *self - rhs;
219 }
220}
221
222impl<B: ExtensibleField<2>> Mul for QuadExtension<B> {
223 type Output = Self;
224
225 #[inline]
226 fn mul(self, rhs: Self) -> Self {
227 let result = <B as ExtensibleField<2>>::mul([self.0, self.1], [rhs.0, rhs.1]);
228 Self(result[0], result[1])
229 }
230}
231
232impl<B: ExtensibleField<2>> MulAssign for QuadExtension<B> {
233 #[inline]
234 fn mul_assign(&mut self, rhs: Self) {
235 *self = *self * rhs
236 }
237}
238
239impl<B: ExtensibleField<2>> Div for QuadExtension<B> {
240 type Output = Self;
241
242 #[inline]
243 #[allow(clippy::suspicious_arithmetic_impl)]
244 fn div(self, rhs: Self) -> Self {
245 self * rhs.inv()
246 }
247}
248
249impl<B: ExtensibleField<2>> DivAssign for QuadExtension<B> {
250 #[inline]
251 fn div_assign(&mut self, rhs: Self) {
252 *self = *self / rhs
253 }
254}
255
256impl<B: ExtensibleField<2>> Neg for QuadExtension<B> {
257 type Output = Self;
258
259 #[inline]
260 fn neg(self) -> Self {
261 Self(-self.0, -self.1)
262 }
263}
264
265impl<B: ExtensibleField<2>> From<B> for QuadExtension<B> {
269 fn from(value: B) -> Self {
270 Self(value, B::ZERO)
271 }
272}
273
274impl<B: ExtensibleField<2>> From<u32> for QuadExtension<B> {
275 fn from(value: u32) -> Self {
276 Self(B::from(value), B::ZERO)
277 }
278}
279
280impl<B: ExtensibleField<2>> From<u16> for QuadExtension<B> {
281 fn from(value: u16) -> Self {
282 Self(B::from(value), B::ZERO)
283 }
284}
285
286impl<B: ExtensibleField<2>> From<u8> for QuadExtension<B> {
287 fn from(value: u8) -> Self {
288 Self(B::from(value), B::ZERO)
289 }
290}
291
292impl<B: ExtensibleField<2>> TryFrom<u64> for QuadExtension<B> {
293 type Error = String;
294
295 fn try_from(value: u64) -> Result<Self, Self::Error> {
296 match B::try_from(value) {
297 Ok(elem) => Ok(Self::from(elem)),
298 Err(_) => Err(format!(
299 "invalid field element: value {value} is greater than or equal to the field modulus"
300 )),
301 }
302 }
303}
304
305impl<B: ExtensibleField<2>> TryFrom<u128> for QuadExtension<B> {
306 type Error = String;
307
308 fn try_from(value: u128) -> Result<Self, Self::Error> {
309 match B::try_from(value) {
310 Ok(elem) => Ok(Self::from(elem)),
311 Err(_) => Err(format!(
312 "invalid field element: value {value} is greater than or equal to the field modulus"
313 )),
314 }
315 }
316}
317
318impl<B: ExtensibleField<2>> TryFrom<&'_ [u8]> for QuadExtension<B> {
319 type Error = DeserializationError;
320
321 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
324 if bytes.len() < Self::ELEMENT_BYTES {
325 return Err(DeserializationError::InvalidValue(format!(
326 "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
327 Self::ELEMENT_BYTES,
328 bytes.len(),
329 )));
330 }
331 if bytes.len() > Self::ELEMENT_BYTES {
332 return Err(DeserializationError::InvalidValue(format!(
333 "too many bytes for a field element; expected {} bytes, but was {} bytes",
334 Self::ELEMENT_BYTES,
335 bytes.len(),
336 )));
337 }
338 let mut reader = SliceReader::new(bytes);
339 Self::read_from(&mut reader)
340 }
341}
342
343impl<B: ExtensibleField<2>> AsBytes for QuadExtension<B> {
344 fn as_bytes(&self) -> &[u8] {
345 let self_ptr: *const Self = self;
347 unsafe { slice::from_raw_parts(self_ptr as *const u8, Self::ELEMENT_BYTES) }
348 }
349}
350
351impl<B: ExtensibleField<2>> Serializable for QuadExtension<B> {
355 fn write_into<W: ByteWriter>(&self, target: &mut W) {
356 self.0.write_into(target);
357 self.1.write_into(target);
358 }
359}
360
361impl<B: ExtensibleField<2>> Deserializable for QuadExtension<B> {
362 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
363 let value0 = B::read_from(source)?;
364 let value1 = B::read_from(source)?;
365 Ok(Self(value0, value1))
366 }
367}
368
#[cfg(test)]
mod tests {
    use rand_utils::rand_value;

    use super::{DeserializationError, FieldElement, QuadExtension};
    use crate::field::f64::BaseElement;

    // BASIC ALGEBRA
    // --------------------------------------------------------------------------------------------

    #[test]
    fn add() {
        // identity
        let r: QuadExtension<BaseElement> = rand_value();
        assert_eq!(r, r + QuadExtension::<BaseElement>::ZERO);

        // addition is coordinate-wise
        let r1: QuadExtension<BaseElement> = rand_value();
        let r2: QuadExtension<BaseElement> = rand_value();

        let expected = QuadExtension(r1.0 + r2.0, r1.1 + r2.1);
        assert_eq!(expected, r1 + r2);
    }

    #[test]
    fn sub() {
        // identity
        let r: QuadExtension<BaseElement> = rand_value();
        assert_eq!(r, r - QuadExtension::<BaseElement>::ZERO);

        // subtraction is coordinate-wise
        let r1: QuadExtension<BaseElement> = rand_value();
        let r2: QuadExtension<BaseElement> = rand_value();

        let expected = QuadExtension(r1.0 - r2.0, r1.1 - r2.1);
        assert_eq!(expected, r1 - r2);
    }

    // INITIALIZATION / CONVERSION
    // --------------------------------------------------------------------------------------------

    #[test]
    fn elements_as_bytes() {
        let source = vec![
            QuadExtension(BaseElement::new(1), BaseElement::new(2)),
            QuadExtension(BaseElement::new(3), BaseElement::new(4)),
        ];

        let mut expected = vec![];
        expected.extend_from_slice(&source[0].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[0].1.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].1.inner().to_le_bytes());

        assert_eq!(expected, QuadExtension::<BaseElement>::elements_as_bytes(&source));
    }

    #[test]
    fn bytes_as_elements() {
        let elements = vec![
            QuadExtension(BaseElement::new(1), BaseElement::new(2)),
            QuadExtension(BaseElement::new(3), BaseElement::new(4)),
        ];

        // Back the byte buffer with `u64` storage so it is guaranteed to be aligned for
        // `QuadExtension<BaseElement>`. A `Vec<u8>` only promises 1-byte alignment, which
        // would make the "valid input" assertion below depend on the allocator. The fifth
        // word is trailing data that does not form a whole extension element.
        let words: [u64; 5] = [1u64, 2, 3, 4, 5].map(|v| BaseElement::new(v).inner());
        // SAFETY: `words` holds `size_of_val(&words)` initialized bytes, `u8` has no alignment
        // requirement, and `words` outlives every use of `bytes`.
        let bytes: &[u8] = unsafe {
            core::slice::from_raw_parts(words.as_ptr() as *const u8, core::mem::size_of_val(&words))
        };

        // aligned, whole number of elements
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[..32]) };
        assert!(result.is_ok());
        assert_eq!(elements, result.unwrap());

        // length not divisible by the element size
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(bytes) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // misaligned and length not divisible by the element size
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[1..]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // valid length but misaligned: exercises the alignment check on its own
        let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[1..33]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
    }

    // UTILITIES
    // --------------------------------------------------------------------------------------------

    #[test]
    fn as_base_elements() {
        let elements = vec![
            QuadExtension(BaseElement::new(1), BaseElement::new(2)),
            QuadExtension(BaseElement::new(3), BaseElement::new(4)),
        ];

        let expected = vec![
            BaseElement::new(1),
            BaseElement::new(2),
            BaseElement::new(3),
            BaseElement::new(4),
        ];

        assert_eq!(expected, QuadExtension::<BaseElement>::slice_as_base_elements(&elements));
    }
}