1use alloc::string::{String, ToString};
7use core::{
8 fmt,
9 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
10 slice,
11};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use utils::{
16 AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
17 Serializable, SliceReader,
18};
19
20use super::{ExtensibleField, ExtensionOf, FieldElement};
21
/// An element of a degree-3 (cubic) extension of the base field `B`.
///
/// The element is stored as its three base-field coefficients `(a, b, c)`. All non-trivial
/// arithmetic (multiplication, squaring, Frobenius) is delegated to `B`'s
/// `ExtensibleField<3>` implementation, which therefore defines how the coefficients are
/// interpreted.
// `#[repr(C)]` fixes the layout to three consecutive `B` values; the slice and byte
// reinterpretation functions in this file (`slice_as_base_elements`, `elements_as_bytes`,
// `bytes_as_elements`, `as_bytes`) depend on that layout.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct CubeExtension<B: ExtensibleField<3>>(B, B, B);
34
35impl<B: ExtensibleField<3>> CubeExtension<B> {
36 pub const fn new(a: B, b: B, c: B) -> Self {
38 Self(a, b, c)
39 }
40
41 pub fn is_supported() -> bool {
43 <B as ExtensibleField<3>>::is_supported()
44 }
45
46 pub const fn to_base_elements(self) -> [B; 3] {
51 [self.0, self.1, self.2]
52 }
53}
54
55impl<B: ExtensibleField<3>> FieldElement for CubeExtension<B> {
56 type PositiveInteger = B::PositiveInteger;
57 type BaseField = B;
58
59 const EXTENSION_DEGREE: usize = 3;
60
61 const ELEMENT_BYTES: usize = B::ELEMENT_BYTES * Self::EXTENSION_DEGREE;
62 const IS_CANONICAL: bool = B::IS_CANONICAL;
63 const ZERO: Self = Self(B::ZERO, B::ZERO, B::ZERO);
64 const ONE: Self = Self(B::ONE, B::ZERO, B::ZERO);
65
66 #[inline]
70 fn double(self) -> Self {
71 Self(self.0.double(), self.1.double(), self.2.double())
72 }
73
74 #[inline]
75 fn square(self) -> Self {
76 let a = <B as ExtensibleField<3>>::square([self.0, self.1, self.2]);
77 Self(a[0], a[1], a[2])
78 }
79
80 #[inline]
81 fn inv(self) -> Self {
82 if self == Self::ZERO {
83 return self;
84 }
85
86 let x = [self.0, self.1, self.2];
87 let c1 = <B as ExtensibleField<3>>::frobenius(x);
88 let c2 = <B as ExtensibleField<3>>::frobenius(c1);
89 let numerator = <B as ExtensibleField<3>>::mul(c1, c2);
90
91 let norm = <B as ExtensibleField<3>>::mul(x, numerator);
92 debug_assert_eq!(norm[1], B::ZERO, "norm must be in the base field");
93 debug_assert_eq!(norm[2], B::ZERO, "norm must be in the base field");
94 let denom_inv = norm[0].inv();
95
96 Self(numerator[0] * denom_inv, numerator[1] * denom_inv, numerator[2] * denom_inv)
97 }
98
99 #[inline]
100 fn conjugate(&self) -> Self {
101 let result = <B as ExtensibleField<3>>::frobenius([self.0, self.1, self.2]);
102 Self(result[0], result[1], result[2])
103 }
104
105 fn base_element(&self, i: usize) -> Self::BaseField {
109 match i {
110 0 => self.0,
111 1 => self.1,
112 2 => self.2,
113 _ => panic!("element index must be smaller than 3, but was {i}"),
114 }
115 }
116
117 fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
118 let ptr = elements.as_ptr();
119 let len = elements.len() * Self::EXTENSION_DEGREE;
120 unsafe { slice::from_raw_parts(ptr as *const Self::BaseField, len) }
121 }
122
123 fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
124 assert!(
125 elements.len().is_multiple_of(Self::EXTENSION_DEGREE),
126 "number of base elements must be divisible by 3, but was {}",
127 elements.len()
128 );
129
130 let ptr = elements.as_ptr();
131 let len = elements.len() / Self::EXTENSION_DEGREE;
132 unsafe { slice::from_raw_parts(ptr as *const Self, len) }
133 }
134
135 fn elements_as_bytes(elements: &[Self]) -> &[u8] {
139 unsafe {
140 slice::from_raw_parts(
141 elements.as_ptr() as *const u8,
142 elements.len() * Self::ELEMENT_BYTES,
143 )
144 }
145 }
146
147 unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
148 if !bytes.len().is_multiple_of(Self::ELEMENT_BYTES) {
149 return Err(DeserializationError::InvalidValue(format!(
150 "number of bytes ({}) does not divide into whole number of field elements",
151 bytes.len(),
152 )));
153 }
154
155 let p = bytes.as_ptr();
156 let len = bytes.len() / Self::ELEMENT_BYTES;
157
158 if !(p as usize).is_multiple_of(Self::BaseField::ELEMENT_BYTES) {
160 return Err(DeserializationError::InvalidValue(
161 "slice memory alignment is not valid for this field element type".to_string(),
162 ));
163 }
164
165 Ok(slice::from_raw_parts(p as *const Self, len))
166 }
167}
168
169impl<B: ExtensibleField<3>> ExtensionOf<B> for CubeExtension<B> {
170 #[inline(always)]
171 fn mul_base(self, other: B) -> Self {
172 let result = <B as ExtensibleField<3>>::mul_base([self.0, self.1, self.2], other);
173 Self(result[0], result[1], result[2])
174 }
175}
176
177impl<B: ExtensibleField<3>> Randomizable for CubeExtension<B> {
178 const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
179
180 fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
181 Self::try_from(bytes).ok()
182 }
183}
184
185impl<B: ExtensibleField<3>> fmt::Display for CubeExtension<B> {
186 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187 write!(f, "({}, {}, {})", self.0, self.1, self.2)
188 }
189}
190
191impl<B: ExtensibleField<3>> Add for CubeExtension<B> {
195 type Output = Self;
196
197 #[inline]
198 fn add(self, rhs: Self) -> Self {
199 Self(self.0 + rhs.0, self.1 + rhs.1, self.2 + rhs.2)
200 }
201}
202
203impl<B: ExtensibleField<3>> AddAssign for CubeExtension<B> {
204 #[inline]
205 fn add_assign(&mut self, rhs: Self) {
206 *self = *self + rhs
207 }
208}
209
210impl<B: ExtensibleField<3>> Sub for CubeExtension<B> {
211 type Output = Self;
212
213 #[inline]
214 fn sub(self, rhs: Self) -> Self {
215 Self(self.0 - rhs.0, self.1 - rhs.1, self.2 - rhs.2)
216 }
217}
218
219impl<B: ExtensibleField<3>> SubAssign for CubeExtension<B> {
220 #[inline]
221 fn sub_assign(&mut self, rhs: Self) {
222 *self = *self - rhs;
223 }
224}
225
226impl<B: ExtensibleField<3>> Mul for CubeExtension<B> {
227 type Output = Self;
228
229 #[inline]
230 fn mul(self, rhs: Self) -> Self {
231 let result =
232 <B as ExtensibleField<3>>::mul([self.0, self.1, self.2], [rhs.0, rhs.1, rhs.2]);
233 Self(result[0], result[1], result[2])
234 }
235}
236
237impl<B: ExtensibleField<3>> MulAssign for CubeExtension<B> {
238 #[inline]
239 fn mul_assign(&mut self, rhs: Self) {
240 *self = *self * rhs
241 }
242}
243
244impl<B: ExtensibleField<3>> Div for CubeExtension<B> {
245 type Output = Self;
246
247 #[inline]
248 #[allow(clippy::suspicious_arithmetic_impl)]
249 fn div(self, rhs: Self) -> Self {
250 self * rhs.inv()
251 }
252}
253
254impl<B: ExtensibleField<3>> DivAssign for CubeExtension<B> {
255 #[inline]
256 fn div_assign(&mut self, rhs: Self) {
257 *self = *self / rhs
258 }
259}
260
261impl<B: ExtensibleField<3>> Neg for CubeExtension<B> {
262 type Output = Self;
263
264 #[inline]
265 fn neg(self) -> Self {
266 Self(-self.0, -self.1, -self.2)
267 }
268}
269
270impl<B: ExtensibleField<3>> From<B> for CubeExtension<B> {
274 fn from(value: B) -> Self {
275 Self(value, B::ZERO, B::ZERO)
276 }
277}
278
279impl<B: ExtensibleField<3>> From<u32> for CubeExtension<B> {
280 fn from(value: u32) -> Self {
281 Self(B::from(value), B::ZERO, B::ZERO)
282 }
283}
284
285impl<B: ExtensibleField<3>> From<u16> for CubeExtension<B> {
286 fn from(value: u16) -> Self {
287 Self(B::from(value), B::ZERO, B::ZERO)
288 }
289}
290
291impl<B: ExtensibleField<3>> From<u8> for CubeExtension<B> {
292 fn from(value: u8) -> Self {
293 Self(B::from(value), B::ZERO, B::ZERO)
294 }
295}
296
297impl<B: ExtensibleField<3>> TryFrom<u64> for CubeExtension<B> {
298 type Error = String;
299
300 fn try_from(value: u64) -> Result<Self, Self::Error> {
301 match B::try_from(value) {
302 Ok(elem) => Ok(Self::from(elem)),
303 Err(_) => Err(format!(
304 "invalid field element: value {value} is greater than or equal to the field modulus"
305 )),
306 }
307 }
308}
309
310impl<B: ExtensibleField<3>> TryFrom<u128> for CubeExtension<B> {
311 type Error = String;
312
313 fn try_from(value: u128) -> Result<Self, Self::Error> {
314 match B::try_from(value) {
315 Ok(elem) => Ok(Self::from(elem)),
316 Err(_) => Err(format!(
317 "invalid field element: value {value} is greater than or equal to the field modulus"
318 )),
319 }
320 }
321}
322
323impl<B: ExtensibleField<3>> TryFrom<&'_ [u8]> for CubeExtension<B> {
324 type Error = DeserializationError;
325
326 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
329 if bytes.len() < Self::ELEMENT_BYTES {
330 return Err(DeserializationError::InvalidValue(format!(
331 "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
332 Self::ELEMENT_BYTES,
333 bytes.len(),
334 )));
335 }
336 if bytes.len() > Self::ELEMENT_BYTES {
337 return Err(DeserializationError::InvalidValue(format!(
338 "too many bytes for a field element; expected {} bytes, but was {} bytes",
339 Self::ELEMENT_BYTES,
340 bytes.len(),
341 )));
342 }
343 let mut reader = SliceReader::new(bytes);
344 Self::read_from(&mut reader)
345 }
346}
347
348impl<B: ExtensibleField<3>> AsBytes for CubeExtension<B> {
349 fn as_bytes(&self) -> &[u8] {
350 let self_ptr: *const Self = self;
352 unsafe { slice::from_raw_parts(self_ptr as *const u8, Self::ELEMENT_BYTES) }
353 }
354}
355
356impl<B: ExtensibleField<3>> Serializable for CubeExtension<B> {
360 fn write_into<W: ByteWriter>(&self, target: &mut W) {
361 self.0.write_into(target);
362 self.1.write_into(target);
363 self.2.write_into(target);
364 }
365}
366
367impl<B: ExtensibleField<3>> Deserializable for CubeExtension<B> {
368 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
369 let value0 = B::read_from(source)?;
370 let value1 = B::read_from(source)?;
371 let value2 = B::read_from(source)?;
372 Ok(Self(value0, value1, value2))
373 }
374}
375
#[cfg(test)]
mod tests {
    use rand_utils::rand_value;

    use super::{CubeExtension, DeserializationError, FieldElement};
    use crate::field::f64::BaseElement;

    // BASIC ALGEBRA
    // --------------------------------------------------------------------------------------------

    /// `ZERO` is the additive identity, and addition is coefficient-wise.
    #[test]
    fn add() {
        // identity
        let r: CubeExtension<BaseElement> = rand_value();
        assert_eq!(r, r + CubeExtension::<BaseElement>::ZERO);

        // random values
        let r1: CubeExtension<BaseElement> = rand_value();
        let r2: CubeExtension<BaseElement> = rand_value();

        let expected = CubeExtension(r1.0 + r2.0, r1.1 + r2.1, r1.2 + r2.2);
        assert_eq!(expected, r1 + r2);
    }

    /// Subtracting `ZERO` is a no-op, and subtraction is coefficient-wise.
    #[test]
    fn sub() {
        // identity
        let r: CubeExtension<BaseElement> = rand_value();
        assert_eq!(r, r - CubeExtension::<BaseElement>::ZERO);

        // random values
        let r1: CubeExtension<BaseElement> = rand_value();
        let r2: CubeExtension<BaseElement> = rand_value();

        let expected = CubeExtension(r1.0 - r2.0, r1.1 - r2.1, r1.2 - r2.2);
        assert_eq!(expected, r1 - r2);
    }

    // INITIALIZATION / CONVERSIONS
    // --------------------------------------------------------------------------------------------

    /// `elements_as_bytes` exposes the coefficients' inner representations back to back.
    #[test]
    fn elements_as_bytes() {
        let source = vec![
            CubeExtension(BaseElement::new(1), BaseElement::new(2), BaseElement::new(3)),
            CubeExtension(BaseElement::new(4), BaseElement::new(5), BaseElement::new(6)),
        ];

        let mut expected = vec![];
        expected.extend_from_slice(&source[0].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[0].1.inner().to_le_bytes());
        expected.extend_from_slice(&source[0].2.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].1.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].2.inner().to_le_bytes());

        assert_eq!(expected, CubeExtension::<BaseElement>::elements_as_bytes(&source));
    }

    /// `bytes_as_elements` accepts aligned, whole-element input. It rejects a partial element
    /// and a misaligned start address.
    #[test]
    fn bytes_as_elements() {
        let elements = vec![
            CubeExtension(BaseElement::new(1), BaseElement::new(2), BaseElement::new(3)),
            CubeExtension(BaseElement::new(4), BaseElement::new(5), BaseElement::new(6)),
        ];

        // Back the byte buffer with `u64` words so that it is guaranteed to be 8-byte aligned;
        // a `Vec<u8>` only guarantees 1-byte alignment. The trailing word is an extra, partial
        // element.
        let words = [1u64, 2, 3, 4, 5, 6, 5].map(|v| BaseElement::new(v).inner());
        // SAFETY: `words` is initialized, lives until the end of this test, and any initialized
        // memory may be viewed as `u8`.
        let bytes: &[u8] = unsafe {
            core::slice::from_raw_parts(words.as_ptr().cast::<u8>(), core::mem::size_of_val(&words))
        };

        // two whole, aligned elements
        let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes[..48]) };
        assert!(result.is_ok());
        assert_eq!(elements, result.unwrap());

        // 56 bytes do not divide into whole 24-byte elements
        let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(bytes) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // correct length (48 bytes) but misaligned start: exercises the alignment check
        let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes[1..49]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
    }

    /// `slice_as_base_elements` flattens elements into their coefficients in storage order.
    #[test]
    fn as_base_elements() {
        let elements = vec![
            CubeExtension(BaseElement::new(1), BaseElement::new(2), BaseElement::new(3)),
            CubeExtension(BaseElement::new(4), BaseElement::new(5), BaseElement::new(6)),
        ];

        let expected = vec![
            BaseElement::new(1),
            BaseElement::new(2),
            BaseElement::new(3),
            BaseElement::new(4),
            BaseElement::new(5),
            BaseElement::new(6),
        ];

        assert_eq!(expected, CubeExtension::<BaseElement>::slice_as_base_elements(&elements));
    }
}