1use alloc::{
13 string::{String, ToString},
14 vec::Vec,
15};
16use core::{
17 fmt::{Debug, Display, Formatter},
18 mem,
19 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
20 slice,
21};
22
23#[cfg(feature = "serde")]
24use serde::{Deserialize, Serialize};
25use utils::{
26 AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
27 Serializable,
28};
29
30use super::{ExtensibleField, FieldElement, StarkField};
31
#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// Field modulus: M = 2^62 - 111 * 2^39 + 1.
const M: u64 = 4611624995532046337;

/// 2^128 mod M, i.e. R^2 mod M for the Montgomery radix R = 2^64. `BaseElement::new` multiplies
/// by it (with the Montgomery `mul`) to move a plain integer into Montgomery form.
const R2: u64 = 630444561284293700;

/// 2^192 mod M, i.e. R^3 mod M. `inv` Montgomery-multiplies by it to turn the plain inverse of a
/// Montgomery-form value back into the Montgomery form of the inverse element.
const R3: u64 = 732984146687909319;

/// -M^{-1} mod 2^64: the Montgomery reduction factor used by `mul`.
const U: u128 = 4611624995532046335;

/// Number of bytes in the encoding of a single field element (the size of a `u64`).
const ELEMENT_BYTES: usize = core::mem::size_of::<u64>();

/// Canonical integer behind `StarkField::TWO_ADIC_ROOT_OF_UNITY` (wrapped via
/// `BaseElement::new`); presumably a primitive 2^39-th root of unity, matching `TWO_ADICITY`
/// — TODO confirm.
const G: u64 = 4421547261963328785;
55
56#[derive(Copy, Clone, Default)]
64#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
65#[cfg_attr(feature = "serde", serde(try_from = "u64", into = "u64"))]
66pub struct BaseElement(u64);
67
68impl BaseElement {
69 pub const fn new(value: u64) -> BaseElement {
72 let z = mul(value, R2);
76 BaseElement(z)
77 }
78}
79
80impl FieldElement for BaseElement {
81 type PositiveInteger = u64;
82 type BaseField = Self;
83
84 const EXTENSION_DEGREE: usize = 1;
85
86 const ZERO: Self = BaseElement::new(0);
87 const ONE: Self = BaseElement::new(1);
88
89 const ELEMENT_BYTES: usize = ELEMENT_BYTES;
90 const IS_CANONICAL: bool = false;
91
92 #[inline]
96 fn double(self) -> Self {
97 let z = self.0 << 1;
98 let q = (z >> 62) * M;
99 Self(z - q)
100 }
101
102 fn exp(self, power: Self::PositiveInteger) -> Self {
103 let mut b = self;
104
105 if power == 0 {
106 return Self::ONE;
107 } else if b == Self::ZERO {
108 return Self::ZERO;
109 }
110
111 let mut r = if power & 1 == 1 { b } else { Self::ONE };
112 for i in 1..64 - power.leading_zeros() {
113 b = b.square();
114 if (power >> i) & 1 == 1 {
115 r *= b;
116 }
117 }
118
119 r
120 }
121
122 fn inv(self) -> Self {
123 BaseElement(inv(self.0))
124 }
125
126 fn conjugate(&self) -> Self {
127 BaseElement(self.0)
128 }
129
130 fn base_element(&self, i: usize) -> Self::BaseField {
134 match i {
135 0 => *self,
136 _ => panic!("element index must be 0, but was {i}"),
137 }
138 }
139
140 fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
141 elements
142 }
143
144 fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
145 elements
146 }
147
148 fn elements_as_bytes(elements: &[Self]) -> &[u8] {
152 let p = elements.as_ptr();
154 let len = elements.len() * Self::ELEMENT_BYTES;
155 unsafe { slice::from_raw_parts(p as *const u8, len) }
156 }
157
158 unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
159 if bytes.len() % Self::ELEMENT_BYTES != 0 {
160 return Err(DeserializationError::InvalidValue(format!(
161 "number of bytes ({}) does not divide into whole number of field elements",
162 bytes.len(),
163 )));
164 }
165
166 let p = bytes.as_ptr();
167 let len = bytes.len() / Self::ELEMENT_BYTES;
168
169 if (p as usize) % mem::align_of::<u64>() != 0 {
170 return Err(DeserializationError::InvalidValue(
171 "slice memory alignment is not valid for this field element type".to_string(),
172 ));
173 }
174
175 Ok(slice::from_raw_parts(p as *const Self, len))
176 }
177}
178
179impl StarkField for BaseElement {
180 const MODULUS: Self::PositiveInteger = M;
186 const MODULUS_BITS: u32 = 62;
187
188 const GENERATOR: Self = BaseElement::new(3);
191
192 const TWO_ADICITY: u32 = 39;
195
196 const TWO_ADIC_ROOT_OF_UNITY: Self = BaseElement::new(G);
200
201 fn get_modulus_le_bytes() -> Vec<u8> {
202 Self::MODULUS.to_le_bytes().to_vec()
203 }
204
205 #[inline]
206 fn as_int(&self) -> Self::PositiveInteger {
207 let result = mul(self.0, 1);
209 normalize(result)
211 }
212}
213
214impl Randomizable for BaseElement {
215 const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
216
217 fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
218 Self::try_from(bytes).ok()
219 }
220}
221
222impl Debug for BaseElement {
223 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
224 write!(f, "{}", self)
225 }
226}
227
228impl Display for BaseElement {
229 fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
230 write!(f, "{}", self.as_int())
231 }
232}
233
234impl PartialEq for BaseElement {
238 #[inline]
239 fn eq(&self, other: &Self) -> bool {
240 normalize(self.0) == normalize(other.0)
243 }
244}
245
246impl Eq for BaseElement {}
247
248impl Add for BaseElement {
252 type Output = Self;
253
254 fn add(self, rhs: Self) -> Self {
255 Self(add(self.0, rhs.0))
256 }
257}
258
259impl AddAssign for BaseElement {
260 fn add_assign(&mut self, rhs: Self) {
261 *self = *self + rhs
262 }
263}
264
265impl Sub for BaseElement {
266 type Output = Self;
267
268 fn sub(self, rhs: Self) -> Self {
269 Self(sub(self.0, rhs.0))
270 }
271}
272
273impl SubAssign for BaseElement {
274 fn sub_assign(&mut self, rhs: Self) {
275 *self = *self - rhs;
276 }
277}
278
279impl Mul for BaseElement {
280 type Output = Self;
281
282 fn mul(self, rhs: Self) -> Self {
283 Self(mul(self.0, rhs.0))
284 }
285}
286
287impl MulAssign for BaseElement {
288 fn mul_assign(&mut self, rhs: Self) {
289 *self = *self * rhs
290 }
291}
292
293impl Div for BaseElement {
294 type Output = Self;
295
296 fn div(self, rhs: Self) -> Self {
297 Self(mul(self.0, inv(rhs.0)))
298 }
299}
300
301impl DivAssign for BaseElement {
302 fn div_assign(&mut self, rhs: Self) {
303 *self = *self / rhs
304 }
305}
306
307impl Neg for BaseElement {
308 type Output = Self;
309
310 fn neg(self) -> Self {
311 Self(sub(0, self.0))
312 }
313}
314
315impl ExtensibleField<2> for BaseElement {
322 #[inline(always)]
323 fn mul(a: [Self; 2], b: [Self; 2]) -> [Self; 2] {
324 let z = a[0] * b[0];
325 [z + a[1] * b[1], (a[0] + a[1]) * (b[0] + b[1]) - z]
326 }
327
328 #[inline(always)]
329 fn mul_base(a: [Self; 2], b: Self) -> [Self; 2] {
330 [a[0] * b, a[1] * b]
331 }
332
333 #[inline(always)]
334 fn frobenius(x: [Self; 2]) -> [Self; 2] {
335 [x[0] + x[1], -x[1]]
336 }
337}
338
// Cubic extension arithmetic. As `mul` below shows, products are reduced using x^3 = -2x - 2,
// i.e. the extension is taken modulo the polynomial x^3 + 2x + 2.
impl ExtensibleField<3> for BaseElement {
    /// Multiplies `a0 + a1*x + a2*x^2` by `b0 + b1*x + b2*x^2`.
    ///
    /// Uses six base-field multiplications (Karatsuba-style). Each local variable's name spells
    /// out the combination of partial products it holds, e.g. `a0b0_a1b1 = a0*b0 + a1*b1`.
    #[inline(always)]
    fn mul(a: [Self; 3], b: [Self; 3]) -> [Self; 3] {
        let a0b0 = a[0] * b[0];
        let a1b1 = a[1] * b[1];
        let a2b2 = a[2] * b[2];

        let a0b0_a0b1_a1b0_a1b1 = (a[0] + a[1]) * (b[0] + b[1]);
        let minus_a0b0_a0b2_a2b0_minus_a2b2 = (a[0] - a[2]) * (b[2] - b[0]);
        let a1b1_minus_a1b2_minus_a2b1_a2b2 = (a[1] - a[2]) * (b[1] - b[2]);
        let a0b0_a1b1 = a0b0 + a1b1;

        // -2 * (a1*b2 + a2*b1): the x^3 coefficient folded back in via x^3 = -2x - 2.
        let minus_2a1b2_minus_2a2b1 = (a1b1_minus_a1b2_minus_a2b1_a2b2 - a1b1 - a2b2).double();

        // Constant term: a0*b0 - 2*(a1*b2 + a2*b1).
        let a0b0_minus_2a1b2_minus_2a2b1 = a0b0 + minus_2a1b2_minus_2a2b1;
        // x term: a0*b1 + a1*b0 - 2*(a1*b2 + a2*b1) - 2*a2*b2 (the last from x^4 = -2x^2 - 2x).
        let a0b1_a1b0_minus_2a1b2_minus_2a2b1_minus_2a2b2 =
            a0b0_a0b1_a1b0_a1b1 + minus_2a1b2_minus_2a2b1 - a2b2.double() - a0b0_a1b1;
        // x^2 term: a0*b2 + a1*b1 + a2*b0 - 2*a2*b2.
        let a0b2_a1b1_a2b0_minus_2a2b2 = minus_a0b0_a0b2_a2b0_minus_a2b2 + a0b0_a1b1 - a2b2;
        [
            a0b0_minus_2a1b2_minus_2a2b1,
            a0b1_a1b0_minus_2a1b2_minus_2a2b1_minus_2a2b2,
            a0b2_a1b1_a2b0_minus_2a2b2,
        ]
    }

    /// Multiplies each coordinate of `a` by the base-field element `b`.
    #[inline(always)]
    fn mul_base(a: [Self; 3], b: Self) -> [Self; 3] {
        [a[0] * b, a[1] * b, a[2] * b]
    }

    /// Applies the Frobenius map as a fixed linear map on the coordinates of `x`.
    ///
    /// NOTE(review): the constants are presumably the coordinates of x^M (column applied to
    /// `x[1]`) and x^(2M) (column applied to `x[2]`) in this extension — TODO confirm.
    #[inline(always)]
    fn frobenius(x: [Self; 3]) -> [Self; 3] {
        [
            x[0] + BaseElement::new(2061766055618274781) * x[1]
                + BaseElement::new(786836585661389001) * x[2],
            BaseElement::new(2868591307402993000) * x[1]
                + BaseElement::new(3336695525575160559) * x[2],
            BaseElement::new(2699230790596717670) * x[1]
                + BaseElement::new(1743033688129053336) * x[2],
        ]
    }
}
391
392impl From<u32> for BaseElement {
396 fn from(value: u32) -> Self {
398 BaseElement::new(value as u64)
399 }
400}
401
402impl From<u16> for BaseElement {
403 fn from(value: u16) -> Self {
405 BaseElement::new(value as u64)
406 }
407}
408
409impl From<u8> for BaseElement {
410 fn from(value: u8) -> Self {
412 BaseElement::new(value as u64)
413 }
414}
415
416impl From<BaseElement> for u128 {
417 fn from(value: BaseElement) -> Self {
418 value.as_int() as u128
419 }
420}
421
422impl From<BaseElement> for u64 {
423 fn from(value: BaseElement) -> Self {
424 value.as_int()
425 }
426}
427
428impl TryFrom<u64> for BaseElement {
429 type Error = String;
430
431 fn try_from(value: u64) -> Result<Self, Self::Error> {
432 if value >= M {
433 Err(format!(
434 "invalid field element: value {value} is greater than or equal to the field modulus"
435 ))
436 } else {
437 Ok(Self::new(value))
438 }
439 }
440}
441
442impl TryFrom<u128> for BaseElement {
443 type Error = String;
444
445 fn try_from(value: u128) -> Result<Self, Self::Error> {
446 if value >= M as u128 {
447 Err(format!(
448 "invalid field element: value {value} is greater than or equal to the field modulus"
449 ))
450 } else {
451 Ok(Self::new(value as u64))
452 }
453 }
454}
455
456impl TryFrom<[u8; 8]> for BaseElement {
457 type Error = String;
458
459 fn try_from(bytes: [u8; 8]) -> Result<Self, Self::Error> {
460 let value = u64::from_le_bytes(bytes);
461 Self::try_from(value)
462 }
463}
464
465impl TryFrom<&'_ [u8]> for BaseElement {
466 type Error = DeserializationError;
467
468 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
472 if bytes.len() < ELEMENT_BYTES {
473 return Err(DeserializationError::InvalidValue(format!(
474 "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
475 ELEMENT_BYTES,
476 bytes.len(),
477 )));
478 }
479 if bytes.len() > ELEMENT_BYTES {
480 return Err(DeserializationError::InvalidValue(format!(
481 "too many bytes for a field element; expected {} bytes, but was {} bytes",
482 ELEMENT_BYTES,
483 bytes.len(),
484 )));
485 }
486 let value = bytes
487 .try_into()
488 .map(u64::from_le_bytes)
489 .map_err(|error| DeserializationError::UnknownError(format!("{error}")))?;
490 if value >= M {
491 return Err(DeserializationError::InvalidValue(format!(
492 "invalid field element: value {value} is greater than or equal to the field modulus"
493 )));
494 }
495 Ok(BaseElement::new(value))
496 }
497}
498
499impl AsBytes for BaseElement {
500 fn as_bytes(&self) -> &[u8] {
501 let self_ptr: *const BaseElement = self;
503 unsafe { slice::from_raw_parts(self_ptr as *const u8, ELEMENT_BYTES) }
504 }
505}
506
507impl Serializable for BaseElement {
511 fn write_into<W: ByteWriter>(&self, target: &mut W) {
512 target.write_bytes(&self.as_int().to_le_bytes());
514 }
515
516 fn get_size_hint(&self) -> usize {
517 self.as_int().get_size_hint()
518 }
519}
520
521impl Deserializable for BaseElement {
522 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
523 let value = source.read_u64()?;
524 if value >= M {
525 return Err(DeserializationError::InvalidValue(format!(
526 "invalid field element: value {value} is greater than or equal to the field modulus"
527 )));
528 }
529 Ok(BaseElement::new(value))
530 }
531}
532
533#[inline(always)]
539fn add(a: u64, b: u64) -> u64 {
540 let z = a + b;
541 let q = (z >> 62) * M;
542 z - q
543}
544
545#[inline(always)]
548fn sub(a: u64, b: u64) -> u64 {
549 if a < b {
550 2 * M - b + a
551 } else {
552 a - b
553 }
554}
555
556#[inline(always)]
559const fn mul(a: u64, b: u64) -> u64 {
560 let z = (a as u128) * (b as u128);
561 let q = (((z as u64) as u128) * U) as u64;
562 let z = z + (q as u128) * (M as u128);
563 (z >> 64) as u64
564}
565
566#[inline(always)]
569#[allow(clippy::many_single_char_names)]
570fn inv(x: u64) -> u64 {
571 if x == 0 {
572 return 0;
573 };
574
575 let mut a: u128 = 0;
576 let mut u: u128 = if x & 1 == 1 {
577 x as u128
578 } else {
579 (x as u128) + (M as u128)
580 };
581 let mut v: u128 = M as u128;
582 let mut d = (M as u128) - 1;
583
584 while v != 1 {
585 while v < u {
586 u -= v;
587 d += a;
588 while u & 1 == 0 {
589 if d & 1 == 1 {
590 d += M as u128;
591 }
592 u >>= 1;
593 d >>= 1;
594 }
595 }
596
597 v -= u;
598 a += d;
599
600 while v & 1 == 0 {
601 if a & 1 == 1 {
602 a += M as u128;
603 }
604 v >>= 1;
605 a >>= 1;
606 }
607 }
608
609 while a > (M as u128) {
610 a -= M as u128;
611 }
612
613 mul(a as u64, R3)
614}
615
616#[inline(always)]
621fn normalize(value: u64) -> u64 {
622 if value >= M {
623 value - M
624 } else {
625 value
626 }
627}