1use alloc::{
14 string::{String, ToString},
15 vec::Vec,
16};
17use core::{
18 fmt::{Debug, Display, Formatter},
19 mem,
20 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
21 slice,
22};
23
24#[cfg(feature = "serde")]
25use serde::{Deserialize, Serialize};
26use utils::{
27 AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable,
28 Serializable,
29};
30
31use super::{ExtensibleField, FieldElement, StarkField};
32
33#[cfg(test)]
34mod tests;
35
/// Field modulus `M = 2^128 - 45 * 2^40 + 1`
/// (= 340282366920938463463374557953744961537), spelled out from its closed form.
const M: u128 = u128::MAX - 45 * (1u128 << 40) + 2;

/// Value exported as `StarkField::TWO_ADIC_ROOT_OF_UNITY` (presumably a primitive
/// 2^40-th root of unity, matching `TWO_ADICITY = 40` — TODO confirm).
const G: u128 = 23953097886125630542083529559205016746;

/// Size in bytes of the `u128` backing one field element (16).
const ELEMENT_BYTES: usize = mem::size_of::<u128>();
47
/// An element of the prime field with modulus `M = 2^128 - 45 * 2^40 + 1`.
///
/// The wrapped `u128` is expected to be in canonical form, i.e. in the range `[0, M)`.
///
/// `#[repr(transparent)]` guarantees that this type has exactly the size, alignment and ABI
/// of `u128`. The slice/pointer reinterpretations in `elements_as_bytes`,
/// `bytes_as_elements` and `AsBytes::as_bytes` depend on that guarantee. Without it, the
/// layout of a default-repr struct is unspecified by the language.
///
/// NOTE(review): with `serde(transparent)`, the derived `Deserialize` reads the inner `u128`
/// directly and does not reject values `>= M`, unlike `Deserializable::read_from`. Consider a
/// validating deserializer if serde input can be untrusted.
#[derive(Copy, Clone, PartialEq, Eq, Default)]
#[repr(transparent)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct BaseElement(u128);
59
60impl BaseElement {
61 pub const fn new(value: u128) -> Self {
65 BaseElement(if value < M { value } else { value - M })
66 }
67}
68
69impl FieldElement for BaseElement {
70 type PositiveInteger = u128;
71 type BaseField = Self;
72
73 const EXTENSION_DEGREE: usize = 1;
74
75 const ZERO: Self = BaseElement(0);
76 const ONE: Self = BaseElement(1);
77
78 const ELEMENT_BYTES: usize = ELEMENT_BYTES;
79
80 const IS_CANONICAL: bool = true;
81
82 fn inv(self) -> Self {
86 BaseElement(inv(self.0))
87 }
88
89 fn conjugate(&self) -> Self {
90 BaseElement(self.0)
91 }
92
93 fn base_element(&self, i: usize) -> Self::BaseField {
97 match i {
98 0 => *self,
99 _ => panic!("element index must be 0, but was {i}"),
100 }
101 }
102
103 fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
104 elements
105 }
106
107 fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
108 elements
109 }
110
111 fn elements_as_bytes(elements: &[Self]) -> &[u8] {
115 let p = elements.as_ptr();
117 let len = elements.len() * Self::ELEMENT_BYTES;
118 unsafe { slice::from_raw_parts(p as *const u8, len) }
119 }
120
121 unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
122 if !bytes.len().is_multiple_of(Self::ELEMENT_BYTES) {
123 return Err(DeserializationError::InvalidValue(format!(
124 "number of bytes ({}) does not divide into whole number of field elements",
125 bytes.len(),
126 )));
127 }
128
129 let p = bytes.as_ptr();
130 let len = bytes.len() / Self::ELEMENT_BYTES;
131
132 if !(p as usize).is_multiple_of(mem::align_of::<u128>()) {
133 return Err(DeserializationError::InvalidValue(
134 "slice memory alignment is not valid for this field element type".to_string(),
135 ));
136 }
137
138 Ok(slice::from_raw_parts(p as *const Self, len))
139 }
140}
141
142impl StarkField for BaseElement {
143 const MODULUS: Self::PositiveInteger = M;
149 const MODULUS_BITS: u32 = 128;
150
151 const GENERATOR: Self = BaseElement(3);
154
155 const TWO_ADICITY: u32 = 40;
158
159 const TWO_ADIC_ROOT_OF_UNITY: Self = BaseElement(G);
163
164 fn get_modulus_le_bytes() -> Vec<u8> {
165 Self::MODULUS.to_le_bytes().to_vec()
166 }
167
168 #[inline]
169 fn as_int(&self) -> Self::PositiveInteger {
170 self.0
171 }
172}
173
174impl Randomizable for BaseElement {
175 const VALUE_SIZE: usize = Self::ELEMENT_BYTES;
176
177 fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
178 Self::try_from(bytes).ok()
179 }
180}
181
182impl Debug for BaseElement {
183 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
184 write!(f, "{self}")
185 }
186}
187
188impl Display for BaseElement {
189 fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
190 write!(f, "{}", self.0)
191 }
192}
193
194impl Add for BaseElement {
198 type Output = Self;
199
200 fn add(self, rhs: Self) -> Self {
201 Self(add(self.0, rhs.0))
202 }
203}
204
205impl AddAssign for BaseElement {
206 fn add_assign(&mut self, rhs: Self) {
207 *self = *self + rhs
208 }
209}
210
211impl Sub for BaseElement {
212 type Output = Self;
213
214 fn sub(self, rhs: Self) -> Self {
215 Self(sub(self.0, rhs.0))
216 }
217}
218
219impl SubAssign for BaseElement {
220 fn sub_assign(&mut self, rhs: Self) {
221 *self = *self - rhs;
222 }
223}
224
225impl Mul for BaseElement {
226 type Output = Self;
227
228 fn mul(self, rhs: Self) -> Self {
229 Self(mul(self.0, rhs.0))
230 }
231}
232
233impl MulAssign for BaseElement {
234 fn mul_assign(&mut self, rhs: Self) {
235 *self = *self * rhs
236 }
237}
238
239impl Div for BaseElement {
240 type Output = Self;
241
242 fn div(self, rhs: Self) -> Self {
243 Self(mul(self.0, inv(rhs.0)))
244 }
245}
246
247impl DivAssign for BaseElement {
248 fn div_assign(&mut self, rhs: Self) {
249 *self = *self / rhs
250 }
251}
252
253impl Neg for BaseElement {
254 type Output = Self;
255
256 fn neg(self) -> Self {
257 Self(sub(0, self.0))
258 }
259}
260
261impl ExtensibleField<2> for BaseElement {
268 #[inline(always)]
269 fn mul(a: [Self; 2], b: [Self; 2]) -> [Self; 2] {
270 let z = a[0] * b[0];
271 [z + a[1] * b[1], (a[0] + a[1]) * (b[0] + b[1]) - z]
272 }
273
274 #[inline(always)]
275 fn mul_base(a: [Self; 2], b: Self) -> [Self; 2] {
276 [a[0] * b, a[1] * b]
277 }
278
279 #[inline(always)]
280 fn frobenius(x: [Self; 2]) -> [Self; 2] {
281 [x[0] + x[1], Self::ZERO - x[1]]
282 }
283}
284
285impl ExtensibleField<3> for BaseElement {
291 fn mul(_a: [Self; 3], _b: [Self; 3]) -> [Self; 3] {
292 unimplemented!()
293 }
294
295 #[inline(always)]
296 fn mul_base(_a: [Self; 3], _b: Self) -> [Self; 3] {
297 unimplemented!()
298 }
299
300 #[inline(always)]
301 fn frobenius(_x: [Self; 3]) -> [Self; 3] {
302 unimplemented!()
303 }
304
305 fn is_supported() -> bool {
306 false
307 }
308}
309
310impl From<u64> for BaseElement {
314 fn from(value: u64) -> Self {
316 BaseElement(value as u128)
317 }
318}
319
320impl From<u32> for BaseElement {
321 fn from(value: u32) -> Self {
323 BaseElement(value as u128)
324 }
325}
326
327impl From<u16> for BaseElement {
328 fn from(value: u16) -> Self {
330 BaseElement(value as u128)
331 }
332}
333
334impl From<u8> for BaseElement {
335 fn from(value: u8) -> Self {
337 BaseElement(value as u128)
338 }
339}
340
341impl TryFrom<u128> for BaseElement {
342 type Error = String;
343
344 fn try_from(value: u128) -> Result<Self, Self::Error> {
345 if value >= M {
346 Err(format!(
347 "invalid field element: value {value} is greater than or equal to the field modulus"
348 ))
349 } else {
350 Ok(Self::new(value))
351 }
352 }
353}
354
355impl TryFrom<&'_ [u8]> for BaseElement {
356 type Error = String;
357
358 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
361 let value =
362 bytes.try_into().map(u128::from_le_bytes).map_err(|error| format!("{error}"))?;
363 if value >= M {
364 return Err(format!(
365 "cannot convert bytes into a field element: \
366 value {value} is greater or equal to the field modulus"
367 ));
368 }
369 Ok(BaseElement(value))
370 }
371}
372
373impl AsBytes for BaseElement {
374 fn as_bytes(&self) -> &[u8] {
375 let self_ptr: *const BaseElement = self;
377 unsafe { slice::from_raw_parts(self_ptr as *const u8, BaseElement::ELEMENT_BYTES) }
378 }
379}
380
381impl Serializable for BaseElement {
385 fn write_into<W: ByteWriter>(&self, target: &mut W) {
386 target.write_bytes(&self.0.to_le_bytes());
387 }
388
389 fn get_size_hint(&self) -> usize {
390 self.0.get_size_hint()
391 }
392}
393
394impl Deserializable for BaseElement {
395 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
396 let value = source.read_u128()?;
397 if value >= M {
398 return Err(DeserializationError::InvalidValue(format!(
399 "invalid field element: value {value} is greater than or equal to the field modulus"
400 )));
401 }
402 Ok(BaseElement(value))
403 }
404}
405
406fn add(a: u128, b: u128) -> u128 {
411 let z = M - b;
412 if a < z {
413 M - z + a
414 } else {
415 a - z
416 }
417}
418
419fn sub(a: u128, b: u128) -> u128 {
421 if a < b {
422 M - b + a
423 } else {
424 a - b
425 }
426}
427
/// Multiplies two field values `a` and `b` modulo `M`; intended to return the canonical
/// representative in `[0, M)` (inputs presumably canonical — TODO confirm at call sites).
///
/// `b` is split into 64-bit halves, `b = b_hi * 2^64 + b_lo`, and the product is assembled as
/// `(a * b_hi) * 2^64 + a * b_lo`. Multiples of `M` are folded out after each partial product
/// so intermediate values fit in three 64-bit limbs. Limb tuples are little-endian: the first
/// element is the least significant limb.
fn mul(a: u128, b: u128) -> u128 {
    // x = a * b_hi as a 192-bit value; mul_reduce then subtracts x2 * M from it.
    let (x0, x1, x2) = mul_128x64(a, (b >> 64) as u64);
    let (mut x0, mut x1, x2) = mul_reduce(x0, x1, x2);
    // A remaining top limb of 1 means the value is 2^128 + (x1, x0); subtract M once more.
    // sub_modulus computes (x1, x0) - M modulo 2^128, which also drops the 2^128.
    if x2 == 1 {
        let (t0, t1) = sub_modulus(x0, x1);
        x0 = t0;
        x1 = t1;
    }

    // y = a * b_lo + x * 2^64: add x into limbs 1 and 2, capturing the carry out of limb 2
    // in y3.
    let (y0, y1, y2) = mul_128x64(a, b as u64);
    let (mut y1, carry) = add64_with_carry(y1, x0, 0);
    let (mut y2, y3) = add64_with_carry(y2, x1, carry);
    // On a carry out of limb 2 (a 2^192 term), subtract M * 2^64 by subtracting M from
    // limbs 1-2 modulo 2^128.
    if y3 == 1 {
        let (t0, t1) = sub_modulus(y1, y2);
        y1 = t0;
        y2 = t1;
    }

    // Fold the top limb back in, then make the result canonical with one conditional
    // subtraction. M's high limb ((M >> 64) as u64) is u64::MAX, so a two-limb value can
    // only be >= M when its high limb equals it and its low limb is >= M's low limb.
    let (mut z0, mut z1, z2) = mul_reduce(y0, y1, y2);
    if z2 == 1 || (z1 == (M >> 64) as u64 && z0 >= (M as u64)) {
        let (t0, t1) = sub_modulus(z0, z1);
        z0 = t0;
        z1 = t1;
    }

    ((z1 as u128) << 64) + (z0 as u128)
}
467
/// Computes the multiplicative inverse of `x` modulo `M`, returning 0 when `x == 0`.
///
/// Implemented as a binary (shift-and-subtract) extended Euclidean algorithm. `u` and the
/// cofactors `a`/`d` are kept as 192-bit little-endian limb triples, so adding `M` to them
/// cannot overflow. Presumably called with canonical `x < M` — TODO confirm.
fn inv(x: u128) -> u128 {
    if x == 0 {
        return 0;
    };

    // v starts at the modulus; a = 0.
    let mut v = M;
    let (mut a0, mut a1, mut a2) = (0, 0, 0);
    // u starts as x made odd: M is odd, so adding M to an even x gives an odd value
    // (congruent to x modulo M).
    let (mut u0, mut u1, mut u2) = if x & 1 == 1 {
        (x as u64, (x >> 64) as u64, 0)
    } else {
        add_192x192(x as u64, (x >> 64) as u64, 0, M as u64, (M >> 64) as u64, 0)
    };
    // d starts at M - 1.
    let (mut d0, mut d1, mut d2) = ((M as u64) - 1, (M >> 64) as u64, 0);

    while v != 1 {
        // Bring u down to at most v: repeatedly u -= v and d += a.
        while u2 > 0 || ((u0 as u128) + ((u1 as u128) << 64)) > v {
            let (t0, t1, t2) = sub_192x192(u0, u1, u2, v as u64, (v >> 64) as u64, 0);
            u0 = t0;
            u1 = t1;
            u2 = t2;

            let (t0, t1, t2) = add_192x192(d0, d1, d2, a0, a1, a2);
            d0 = t0;
            d1 = t1;
            d2 = t2;

            // Strip factors of two from u and halve d modulo M alongside it. When d is odd,
            // adding M (odd) first makes it even, so the right shift divides exactly by 2.
            while u0 & 1 == 0 {
                if d0 & 1 == 1 {
                    let (t0, t1, t2) = add_192x192(d0, d1, d2, M as u64, (M >> 64) as u64, 0);
                    d0 = t0;
                    d1 = t1;
                    d2 = t2;
                }

                // 192-bit right shift by one of u.
                u0 = (u0 >> 1) | ((u1 & 1) << 63);
                u1 = (u1 >> 1) | ((u2 & 1) << 63);
                u2 >>= 1;

                // 192-bit right shift by one of d.
                d0 = (d0 >> 1) | ((d1 & 1) << 63);
                d1 = (d1 >> 1) | ((d2 & 1) << 63);
                d2 >>= 1;
            }
        }

        // Here u2 == 0 and u <= v (the inner loop's exit condition): v -= u, a += d.
        v -= (u0 as u128) + ((u1 as u128) << 64);

        let (t0, t1, t2) = add_192x192(a0, a1, a2, d0, d1, d2);
        a0 = t0;
        a1 = t1;
        a2 = t2;

        // Make v odd again, halving a modulo M in the same add-M-if-odd way.
        while v & 1 == 0 {
            if a0 & 1 == 1 {
                let (t0, t1, t2) = add_192x192(a0, a1, a2, M as u64, (M >> 64) as u64, 0);
                a0 = t0;
                a1 = t1;
                a2 = t2;
            }

            v >>= 1;

            a0 = (a0 >> 1) | ((a1 & 1) << 63);
            a1 = (a1 >> 1) | ((a2 & 1) << 63);
            a2 >>= 1;
        }
    }

    // Reduce the 192-bit accumulator a into [0, M) by repeated subtraction of M.
    let mut a = (a0 as u128) + ((a1 as u128) << 64);
    while a2 > 0 || a >= M {
        let (t0, t1, t2) = sub_192x192(a0, a1, a2, M as u64, (M >> 64) as u64, 0);
        a0 = t0;
        a1 = t1;
        a2 = t2;
        a = (a0 as u128) + ((a1 as u128) << 64);
    }

    a
}
564
/// Multiplies a 128-bit `a` by a 64-bit `b`, returning the exact 192-bit product as
/// little-endian 64-bit limbs `(low, middle, high)`.
#[inline]
fn mul_128x64(a: u128, b: u64) -> (u64, u64, u64) {
    let b = u128::from(b);
    let (a_lo, a_hi) = (u128::from(a as u64), a >> 64);
    let lo_product = a_lo * b;
    // (2^64 - 1)^2 + (2^64 - 1) < 2^128, so adding the carry from the low product cannot
    // overflow.
    let hi_product = a_hi * b + (lo_product >> 64);
    (lo_product as u64, hi_product as u64, (hi_product >> 64) as u64)
}
575
576#[inline]
577fn mul_reduce(z0: u64, z1: u64, z2: u64) -> (u64, u64, u64) {
578 let (q0, q1, q2) = mul_by_modulus(z2);
579 let (z0, z1, z2) = sub_192x192(z0, z1, z2, q0, q1, q2);
580 (z0, z1, z2)
581}
582
583#[inline]
584fn mul_by_modulus(a: u64) -> (u64, u64, u64) {
585 let a_lo = (a as u128).wrapping_mul(M);
586 let a_hi = if a == 0 { 0 } else { a - 1 };
587 (a_lo as u64, (a_lo >> 64) as u64, a_hi)
588}
589
590#[inline]
591fn sub_modulus(a_lo: u64, a_hi: u64) -> (u64, u64) {
592 let mut z = 0u128.wrapping_sub(M);
593 z = z.wrapping_add(a_lo as u128);
594 z = z.wrapping_add((a_hi as u128) << 64);
595 (z as u64, (z >> 64) as u64)
596}
597
/// Subtracts two 192-bit values given as little-endian limbs.
///
/// Borrows propagate limb by limb; the final borrow is discarded, so the result wraps
/// modulo `2^192`.
#[inline]
fn sub_192x192(a0: u64, a1: u64, a2: u64, b0: u64, b1: u64, b2: u64) -> (u64, u64, u64) {
    let (z0, borrow0) = a0.overflowing_sub(b0);
    let (t1, under1) = a1.overflowing_sub(b1);
    let (z1, under2) = t1.overflowing_sub(u64::from(borrow0));
    let borrow1 = under1 | under2;
    let z2 = a2.wrapping_sub(b2).wrapping_sub(u64::from(borrow1));
    (z0, z1, z2)
}
605
/// Adds two 192-bit values given as little-endian limbs.
///
/// Carries propagate limb by limb; the final carry is discarded, so the result wraps
/// modulo `2^192`.
#[inline]
fn add_192x192(a0: u64, a1: u64, a2: u64, b0: u64, b1: u64, b2: u64) -> (u64, u64, u64) {
    let (z0, carry0) = a0.overflowing_add(b0);
    let (t1, over1) = a1.overflowing_add(b1);
    let (z1, over2) = t1.overflowing_add(u64::from(carry0));
    let carry1 = over1 | over2;
    let z2 = a2.wrapping_add(b2).wrapping_add(u64::from(carry1));
    (z0, z1, z2)
}
613
/// Computes `a + b + carry` exactly, returning the low 64 bits and the high part.
///
/// The high part is at most 2, and it can reach 2 only when `carry` exceeds 1.
#[inline]
const fn add64_with_carry(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let (partial, over_ab) = a.overflowing_add(b);
    let (sum, over_carry) = partial.overflowing_add(carry);
    (sum, over_ab as u64 + over_carry as u64)
}