1use alloc::{
20 string::{String, ToString},
21 vec::Vec,
22};
23use core::{
24 fmt::{Debug, Display, Formatter},
25 mem,
26 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
27 slice,
28};
29
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32use utils::{
33 AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
34 Serializable,
35};
36
37use super::{ExtensibleField, FieldElement, StarkField};
38
39#[cfg(test)]
40mod tests;
41
/// Field modulus M = 2^64 - 2^32 + 1.
const M: u64 = 0xFFFF_FFFF_0000_0001;

/// R^2 mod M for the Montgomery radix R = 2^64. Reducing `value * R2` yields `value * R mod M`,
/// i.e. the Montgomery form of `value`.
const R2: u64 = 0xFFFF_FFFE_0000_0001;

/// Number of bytes used to store one base field element.
const ELEMENT_BYTES: usize = (u64::BITS / 8) as usize;
53
/// An element of the prime field with modulus M = 2^64 - 2^32 + 1.
///
/// The wrapped `u64` holds the element in Montgomery form (`value * 2^64 mod M`). Use
/// [`BaseElement::new`] and [`BaseElement::as_int`] to convert from and to integers.
///
/// `#[repr(transparent)]` guarantees that this type has exactly the size and alignment of `u64`.
/// The unsafe slice reinterpretations in this module (`elements_as_bytes`, `bytes_as_elements`,
/// `as_bytes`) depend on that layout, and the default Rust representation does not guarantee it.
#[derive(Copy, Clone, Default)]
#[repr(transparent)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(feature = "serde", serde(try_from = "u64", into = "u64"))]
pub struct BaseElement(u64);
65
66impl BaseElement {
67 pub const fn new(value: u64) -> BaseElement {
73 Self(mont_red_cst((value as u128) * (R2 as u128)))
74 }
75
76 pub const fn from_mont(value: u64) -> BaseElement {
79 BaseElement(value)
80 }
81
82 pub const fn inner(&self) -> u64 {
84 self.0
85 }
86
87 #[inline(always)]
89 pub const fn as_int(&self) -> u64 {
90 mont_to_int(self.0)
91 }
92
93 #[inline(always)]
96 pub fn exp7(self) -> Self {
97 let x2 = self.square();
98 let x4 = x2.square();
99 let x3 = x2 * self;
100 x3 * x4
101 }
102
103 #[inline(always)]
106 pub const fn mul_small(self, rhs: u32) -> Self {
107 let s = (self.inner() as u128) * (rhs as u128);
108 let s_hi = (s >> 64) as u64;
109 let s_lo = s as u64;
110 let z = (s_hi << 32) - s_hi;
111 let (res, over) = s_lo.overflowing_add(z);
112
113 BaseElement::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64))
114 }
115}
116
117impl FieldElement for BaseElement {
118 type PositiveInteger = u64;
119 type BaseField = Self;
120
121 const EXTENSION_DEGREE: usize = 1;
122
123 const ZERO: Self = Self::new(0);
124 const ONE: Self = Self::new(1);
125
126 const ELEMENT_BYTES: usize = ELEMENT_BYTES;
127 const IS_CANONICAL: bool = false;
128
129 #[inline]
133 fn double(self) -> Self {
134 let ret = (self.0 as u128) << 1;
135 let (result, over) = (ret as u64, (ret >> 64) as u64);
136 Self(result.wrapping_sub(M * over))
137 }
138
139 #[inline]
140 fn exp(self, power: Self::PositiveInteger) -> Self {
141 let mut b: Self;
142 let mut r = Self::ONE;
143 for i in (0..64).rev() {
144 r = r.square();
145 b = r;
146 b *= self;
147 let mask = -(((power >> i) & 1 == 1) as i64) as u64;
149 r.0 ^= mask & (r.0 ^ b.0);
150 }
151
152 r
153 }
154
155 #[inline]
156 #[allow(clippy::many_single_char_names)]
157 fn inv(self) -> Self {
158 let t2 = self.square() * self;
163
164 let t3 = t2.square() * self;
166
167 let t6 = exp_acc::<3>(t3, t3);
169
170 let t12 = exp_acc::<6>(t6, t6);
172
173 let t24 = exp_acc::<12>(t12, t12);
175
176 let t30 = exp_acc::<6>(t24, t6);
178 let t31 = t30.square() * self;
179
180 let t63 = exp_acc::<32>(t31, t31);
182
183 t63.square() * self
185 }
186
187 fn conjugate(&self) -> Self {
188 Self(self.0)
189 }
190
191 fn base_element(&self, i: usize) -> Self::BaseField {
195 match i {
196 0 => *self,
197 _ => panic!("element index must be 0, but was {i}"),
198 }
199 }
200
201 fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
202 elements
203 }
204
205 fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
206 elements
207 }
208
209 fn elements_as_bytes(elements: &[Self]) -> &[u8] {
213 let p = elements.as_ptr();
215 let len = elements.len() * Self::ELEMENT_BYTES;
216 unsafe { slice::from_raw_parts(p as *const u8, len) }
217 }
218
219 unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
220 if !bytes.len().is_multiple_of(Self::ELEMENT_BYTES) {
221 return Err(DeserializationError::InvalidValue(format!(
222 "number of bytes ({}) does not divide into whole number of field elements",
223 bytes.len(),
224 )));
225 }
226
227 let p = bytes.as_ptr();
228 let len = bytes.len() / Self::ELEMENT_BYTES;
229
230 if !(p as usize).is_multiple_of(mem::align_of::<u64>()) {
231 return Err(DeserializationError::InvalidValue(
232 "slice memory alignment is not valid for this field element type".to_string(),
233 ));
234 }
235
236 Ok(slice::from_raw_parts(p as *const Self, len))
237 }
238}
239
240impl StarkField for BaseElement {
241 const MODULUS: Self::PositiveInteger = M;
247 const MODULUS_BITS: u32 = 64;
248
249 const GENERATOR: Self = Self::new(7);
252
253 const TWO_ADICITY: u32 = 32;
256
257 const TWO_ADIC_ROOT_OF_UNITY: Self = Self::new(7277203076849721926);
268
269 fn get_modulus_le_bytes() -> Vec<u8> {
270 M.to_le_bytes().to_vec()
271 }
272
273 #[inline]
274 fn as_int(&self) -> Self::PositiveInteger {
275 mont_to_int(self.0)
276 }
277}
278
279impl Randomizable for BaseElement {
280 const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
281
282 fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
283 Self::try_from(bytes).ok()
284 }
285}
286
287impl Debug for BaseElement {
288 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
289 write!(f, "{self}")
290 }
291}
292
293impl Display for BaseElement {
294 fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
295 write!(f, "{}", self.as_int())
296 }
297}
298
299impl PartialEq for BaseElement {
303 #[inline]
304 fn eq(&self, other: &Self) -> bool {
305 equals(self.0, other.0) == 0xffffffffffffffff
306 }
307}
308
// `PartialEq` compares the internal u64 values bit-for-bit, which is reflexive, so `Eq` holds.
impl Eq for BaseElement {}
310
311impl Add for BaseElement {
315 type Output = Self;
316
317 #[inline]
318 #[allow(clippy::suspicious_arithmetic_impl)]
319 fn add(self, rhs: Self) -> Self {
320 let (x1, c1) = self.0.overflowing_sub(M - rhs.0);
322 let adj = 0u32.wrapping_sub(c1 as u32);
323 Self(x1.wrapping_sub(adj as u64))
324 }
325}
326
327impl AddAssign for BaseElement {
328 #[inline]
329 fn add_assign(&mut self, rhs: Self) {
330 *self = *self + rhs
331 }
332}
333
334impl Sub for BaseElement {
335 type Output = Self;
336
337 #[inline]
338 #[allow(clippy::suspicious_arithmetic_impl)]
339 fn sub(self, rhs: Self) -> Self {
340 let (x1, c1) = self.0.overflowing_sub(rhs.0);
341 let adj = 0u32.wrapping_sub(c1 as u32);
342 Self(x1.wrapping_sub(adj as u64))
343 }
344}
345
346impl SubAssign for BaseElement {
347 #[inline]
348 fn sub_assign(&mut self, rhs: Self) {
349 *self = *self - rhs;
350 }
351}
352
353impl Mul for BaseElement {
354 type Output = Self;
355
356 #[inline]
357 fn mul(self, rhs: Self) -> Self {
358 Self(mont_red_cst((self.0 as u128) * (rhs.0 as u128)))
359 }
360}
361
362impl MulAssign for BaseElement {
363 #[inline]
364 fn mul_assign(&mut self, rhs: Self) {
365 *self = *self * rhs
366 }
367}
368
369impl Div for BaseElement {
370 type Output = Self;
371
372 #[inline]
373 #[allow(clippy::suspicious_arithmetic_impl)]
374 fn div(self, rhs: Self) -> Self {
375 self * rhs.inv()
376 }
377}
378
379impl DivAssign for BaseElement {
380 #[inline]
381 fn div_assign(&mut self, rhs: Self) {
382 *self = *self / rhs
383 }
384}
385
386impl Neg for BaseElement {
387 type Output = Self;
388
389 #[inline]
390 fn neg(self) -> Self {
391 Self::ZERO - self
392 }
393}
394
395impl ExtensibleField<2> for BaseElement {
402 #[inline(always)]
403 fn mul(a: [Self; 2], b: [Self; 2]) -> [Self; 2] {
404 let a0b0 = a[0] * b[0];
408 [a0b0 - (a[1] * b[1]).double(), (a[0] + a[1]) * (b[0] + b[1]) - a0b0]
409 }
410
411 #[inline(always)]
412 fn square(a: [Self; 2]) -> [Self; 2] {
413 let a0 = a[0];
414 let a1 = a[1];
415
416 let a1_sq = a1.square();
417
418 let out0 = a0.square() - a1_sq.double();
419 let out1 = (a0 * a1).double() + a1_sq;
420
421 [out0, out1]
422 }
423
424 #[inline(always)]
425 fn mul_base(a: [Self; 2], b: Self) -> [Self; 2] {
426 [a[0] * b, a[1] * b]
429 }
430
431 #[inline(always)]
432 fn frobenius(x: [Self; 2]) -> [Self; 2] {
433 [x[0] + x[1], -x[1]]
434 }
435}
436
/// Cubic extension of the base field.
///
/// An element `[a0, a1, a2]` represents `a0 + a1·x + a2·x²`; the formulas below reduce with
/// `x³ = x + 1` (hence `x⁴ = x² + x`).
impl ExtensibleField<3> for BaseElement {
    /// Multiplies two cubic-extension elements.
    ///
    /// Karatsuba-style: six base-field multiplications (three direct products and three
    /// products of sums) instead of nine.
    #[inline(always)]
    fn mul(a: [Self; 3], b: [Self; 3]) -> [Self; 3] {
        let a0b0 = a[0] * b[0];
        let a1b1 = a[1] * b[1];
        let a2b2 = a[2] * b[2];

        // Products of sums; each name lists the terms the product expands to.
        let a0b0_a0b1_a1b0_a1b1 = (a[0] + a[1]) * (b[0] + b[1]);
        let a0b0_a0b2_a2b0_a2b2 = (a[0] + a[2]) * (b[0] + b[2]);
        let a1b1_a1b2_a2b1_a2b2 = (a[1] + a[2]) * (b[1] + b[2]);

        let a0b0_minus_a1b1 = a0b0 - a1b1;

        // Constant coefficient: a0b0 + (a1b2 + a2b1), the latter folded in from x³ = x + 1.
        let a0b0_a1b2_a2b1 = a1b1_a1b2_a2b1_a2b2 + a0b0_minus_a1b1 - a2b2;
        // Coefficient of x: a0b1 + a1b0 + (a1b2 + a2b1) from x³, plus a2b2 from x⁴ = x² + x.
        let a0b1_a1b0_a1b2_a2b1_a2b2 =
            a0b0_a0b1_a1b0_a1b1 + a1b1_a1b2_a2b1_a2b2 - a1b1.double() - a0b0;
        // Coefficient of x²: a0b2 + a1b1 + a2b0, plus a2b2 from x⁴ = x² + x.
        let a0b2_a1b1_a2b0_a2b2 = a0b0_a0b2_a2b0_a2b2 - a0b0_minus_a1b1;

        [a0b0_a1b2_a2b1, a0b1_a1b0_a1b2_a2b1_a2b2, a0b2_a1b1_a2b0_a2b2]
    }

    /// Squares a cubic-extension element (same reduction as `mul`, with shared products).
    #[inline(always)]
    fn square(a: [Self; 3]) -> [Self; 3] {
        let a0 = a[0];
        let a1 = a[1];
        let a2 = a[2];

        let a2_sq = a2.square();
        let a1_a2 = a1 * a2;

        // a0² + 2a1a2
        let out0 = a0.square() + a1_a2.double();
        // 2a0a1 + 2a1a2 + a2²
        let out1 = (a0 * a1 + a1_a2).double() + a2_sq;
        // 2a0a2 + a1² + a2²
        let out2 = (a0 * a2).double() + a1.square() + a2_sq;

        [out0, out1, out2]
    }

    /// Multiplies each coefficient by the base-field element `b`.
    #[inline(always)]
    fn mul_base(a: [Self; 3], b: Self) -> [Self; 3] {
        [a[0] * b, a[1] * b, a[2] * b]
    }

    /// Frobenius map on the cubic extension.
    ///
    /// NOTE(review): the constants are presumably precomputed coordinates of x^M (column for
    /// x[1]) and x^(2M) (column for x[2]) in the basis 1, x, x². Verify against the tests.
    #[inline(always)]
    fn frobenius(x: [Self; 3]) -> [Self; 3] {
        [
            x[0] + Self::new(10615703402128488253) * x[1] + Self::new(6700183068485440220) * x[2],
            Self::new(10050274602728160328) * x[1] + Self::new(14531223735771536287) * x[2],
            Self::new(11746561000929144102) * x[1] + Self::new(8396469466686423992) * x[2],
        ]
    }
}
500
501impl From<bool> for BaseElement {
505 fn from(value: bool) -> Self {
506 Self::new(value.into())
507 }
508}
509
510impl From<u8> for BaseElement {
511 fn from(value: u8) -> Self {
512 Self::new(value.into())
513 }
514}
515
516impl From<u16> for BaseElement {
517 fn from(value: u16) -> Self {
518 Self::new(value.into())
519 }
520}
521
522impl From<u32> for BaseElement {
523 fn from(value: u32) -> Self {
524 Self::new(value.into())
525 }
526}
527
528impl TryFrom<u64> for BaseElement {
529 type Error = String;
530
531 fn try_from(value: u64) -> Result<Self, Self::Error> {
532 if value >= M {
533 Err(format!(
534 "invalid field element: value {value} is greater than or equal to the field modulus"
535 ))
536 } else {
537 Ok(Self::new(value))
538 }
539 }
540}
541
542impl TryFrom<u128> for BaseElement {
543 type Error = String;
544
545 fn try_from(value: u128) -> Result<Self, Self::Error> {
546 if value >= M.into() {
547 Err(format!(
548 "invalid field element: value {value} is greater than or equal to the field modulus"
549 ))
550 } else {
551 Ok(Self::new(value as u64))
552 }
553 }
554}
555
556impl TryFrom<usize> for BaseElement {
557 type Error = String;
558
559 fn try_from(value: usize) -> Result<Self, Self::Error> {
560 match u64::try_from(value) {
561 Err(_) => Err(format!("invalid field element: value {value} does not fit in a u64")),
562 Ok(v) => v.try_into(),
563 }
564 }
565}
566
567impl TryFrom<[u8; 8]> for BaseElement {
568 type Error = String;
569
570 fn try_from(bytes: [u8; 8]) -> Result<Self, Self::Error> {
571 let value = u64::from_le_bytes(bytes);
572 Self::try_from(value)
573 }
574}
575
576impl TryFrom<&'_ [u8]> for BaseElement {
577 type Error = DeserializationError;
578
579 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
583 if bytes.len() < ELEMENT_BYTES {
584 return Err(DeserializationError::InvalidValue(format!(
585 "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
586 ELEMENT_BYTES,
587 bytes.len(),
588 )));
589 }
590 if bytes.len() > ELEMENT_BYTES {
591 return Err(DeserializationError::InvalidValue(format!(
592 "too many bytes for a field element; expected {} bytes, but was {} bytes",
593 ELEMENT_BYTES,
594 bytes.len(),
595 )));
596 }
597 let bytes: [u8; 8] = bytes.try_into().expect("slice to array conversion failed");
598 bytes.try_into().map_err(DeserializationError::InvalidValue)
599 }
600}
601
602impl TryFrom<BaseElement> for bool {
603 type Error = String;
604
605 fn try_from(value: BaseElement) -> Result<Self, Self::Error> {
606 match value.as_int() {
607 0 => Ok(false),
608 1 => Ok(true),
609 v => Err(format!("Field element does not represent a boolean, got {v}")),
610 }
611 }
612}
613
614impl TryFrom<BaseElement> for u8 {
615 type Error = String;
616
617 fn try_from(value: BaseElement) -> Result<Self, Self::Error> {
618 value.as_int().try_into().map_err(|e| format!("{e}"))
619 }
620}
621
622impl TryFrom<BaseElement> for u16 {
623 type Error = String;
624
625 fn try_from(value: BaseElement) -> Result<Self, Self::Error> {
626 value.as_int().try_into().map_err(|e| format!("{e}"))
627 }
628}
629
630impl TryFrom<BaseElement> for u32 {
631 type Error = String;
632
633 fn try_from(value: BaseElement) -> Result<Self, Self::Error> {
634 value.as_int().try_into().map_err(|e| format!("{e}"))
635 }
636}
637
638impl From<BaseElement> for u64 {
639 fn from(value: BaseElement) -> Self {
640 value.as_int()
641 }
642}
643
644impl From<BaseElement> for u128 {
645 fn from(value: BaseElement) -> Self {
646 value.as_int().into()
647 }
648}
649
650impl AsBytes for BaseElement {
651 fn as_bytes(&self) -> &[u8] {
652 let self_ptr: *const BaseElement = self;
654 unsafe { slice::from_raw_parts(self_ptr as *const u8, ELEMENT_BYTES) }
655 }
656}
657
658impl Serializable for BaseElement {
662 fn write_into<W: ByteWriter>(&self, target: &mut W) {
663 target.write_bytes(&self.as_int().to_le_bytes());
665 }
666
667 fn get_size_hint(&self) -> usize {
668 self.as_int().get_size_hint()
669 }
670}
671
672impl Deserializable for BaseElement {
673 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
674 let value = source.read_u64()?;
675 if value >= M {
676 return Err(DeserializationError::InvalidValue(format!(
677 "invalid field element: value {value} is greater than or equal to the field modulus"
678 )));
679 }
680 Ok(Self::new(value))
681 }
682}
683
684#[inline(always)]
689fn exp_acc<const N: usize>(base: BaseElement, tail: BaseElement) -> BaseElement {
690 let mut result = base;
691 for _ in 0..N {
692 result = result.square();
693 }
694 result * tail
695}
696
697#[allow(dead_code)]
699#[inline(always)]
700const fn mont_red_var(x: u128) -> u64 {
701 const NPRIME: u64 = 4294967297;
702 let q = (((x as u64) as u128) * (NPRIME as u128)) as u64;
703 let m = (q as u128) * (M as u128);
704 let y = (((x as i128).wrapping_sub(m as i128)) >> 64) as i64;
705 if x < m {
706 (y + (M as i64)) as u64
707 } else {
708 y as u64
709 }
710}
711
/// Montgomery reduction (constant-time version).
///
/// For `x < M * 2^64`, returns `x / 2^64 mod M` in the range `[0, M)` without branching.
#[inline(always)]
const fn mont_red_cst(x: u128) -> u64 {
    // Split x = xl + xh * 2^64.
    let xl = x as u64;
    let xh = (x >> 64) as u64;
    // a = xl * (2^32 + 1) mod 2^64 = xl * M^-1 mod 2^64 (so a * M ≡ xl mod 2^64); e is the carry.
    let (a, e) = xl.overflowing_add(xl << 32);

    // b = (a * M - xl) / 2^64, computed without a multiplication (a * M = a*2^64 - a*2^32 + a),
    // so that xh - b is exactly (x - a * M) / 2^64, which may be negative.
    let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64);

    // If xh - b borrows, add M back; mod 2^64, that is subtracting 2^32 - 1 (= 2^64 - M).
    let (r, c) = xh.overflowing_sub(b);
    r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64)
}
725
/// Converts a Montgomery-form value back into its integer representation.
///
/// Performs the same steps as `mont_red_cst` for a 128-bit input whose high half is zero, and
/// returns `x / 2^64 mod M` in `[0, M)`. Any `u64` input is accepted, including values in
/// `[M, 2^64)`.
#[inline(always)]
const fn mont_to_int(x: u64) -> u64 {
    // Same as `mont_red_cst` with xl = x and xh = 0.
    let (a, e) = x.overflowing_add(x << 32);
    let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64);

    // 0 - b, adding M back (via subtracting 2^32 - 1 mod 2^64) whenever b > 0.
    let (r, c) = 0u64.overflowing_sub(b);
    r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64)
}
738
/// Branch-free equality test on raw `u64` values.
///
/// Returns `u64::MAX` (all bits set) when `lhs == rhs` and `0` otherwise.
#[inline(always)]
fn equals(lhs: u64, rhs: u64) -> u64 {
    let diff = lhs ^ rhs;
    // For any nonzero diff, at least one of diff and -diff has its top bit set.
    let top_bit = diff | diff.wrapping_neg();
    // The arithmetic shift smears the top bit: all ones if lhs != rhs, zero if they are equal.
    let not_equal_mask = ((top_bit as i64) >> 63) as u64;
    !not_equal_mask
}