1use rand::Rng;
2use core::cmp::Ordering;
3
4use byteorder::{BigEndian, ByteOrder};
5
6#[cfg(feature = "borsh")]
7use borsh::{BorshDeserialize, BorshSerialize};
8
/// 256-bit unsigned integer made of two `u128` limbs, least significant limb first
/// (`self.0[0]` holds bits 0..128, `self.0[1]` holds bits 128..256).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[repr(C)]
pub struct U256(pub [u128; 2]);
15
16impl From<[u64; 4]> for U256 {
17 fn from(d: [u64; 4]) -> Self {
18 let mut a = [0u128; 2];
19 a[0] = (d[1] as u128) << 64 | d[0] as u128;
20 a[1] = (d[3] as u128) << 64 | d[2] as u128;
21 U256(a)
22 }
23}
24
25impl From<u64> for U256 {
26 fn from(d: u64) -> Self {
27 U256::from([d, 0, 0, 0])
28 }
29}
30
/// 512-bit unsigned integer made of four `u128` limbs, least significant limb first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub struct U512(pub [u128; 4]);
36
37impl From<[u64; 8]> for U512 {
38 fn from(d: [u64; 8]) -> Self {
39 let mut a = [0u128; 4];
40 a[0] = (d[1] as u128) << 64 | d[0] as u128;
41 a[1] = (d[3] as u128) << 64 | d[2] as u128;
42 a[2] = (d[5] as u128) << 64 | d[4] as u128;
43 a[3] = (d[7] as u128) << 64 | d[6] as u128;
44 U512(a)
45 }
46}
47
48impl U512 {
49 pub fn new(c1: &U256, c0: &U256, modulo: &U256) -> U512 {
51 let mut res = [0; 4];
52
53 debug_assert_eq!(c1.0.len(), 2);
54 unroll! {
55 for i in 0..2 {
56 mac_digit(i, &mut res, &modulo.0, c1.0[i]);
57 }
58 }
59
60 let mut carry = 0;
61
62 debug_assert_eq!(res.len(), 4);
63 unroll! {
64 for i in 0..2 {
65 res[i] = adc(res[i], c0.0[i], &mut carry);
66 }
67 }
68
69 unroll! {
70 for i in 0..2 {
71 let (a1, a0) = split_u128(res[i + 2]);
72 let (c, r0) = split_u128(a0 + carry);
73 let (c, r1) = split_u128(a1 + c);
74 carry = c;
75
76 res[i + 2] = combine_u128(r1, r0);
77 }
78 }
79
80 debug_assert!(0 == carry);
81
82 U512(res)
83 }
84
85 pub fn from_slice(s: &[u8]) -> Result<U512, Error> {
86 if s.len() != 64 {
87 return Err(Error::InvalidLength {
88 expected: 32,
89 actual: s.len(),
90 });
91 }
92
93 let mut n = [0; 4];
94 for (l, i) in (0..4).rev().zip((0..4).map(|i| i * 16)) {
95 n[l] = BigEndian::read_u128(&s[i..]);
96 }
97
98 Ok(U512(n))
99 }
100
101 pub fn random<R: Rng>(rng: &mut R) -> U512 {
103 U512(rng.gen())
104 }
105
106 pub fn get_bit(&self, n: usize) -> Option<bool> {
107 if n >= 512 {
108 None
109 } else {
110 let part = n / 128;
111 let bit = n - (128 * part);
112
113 Some(self.0[part] & (1 << bit) > 0)
114 }
115 }
116
117 pub fn divrem(&self, modulo: &U256) -> (Option<U256>, U256) {
120 let mut q = Some(U256::zero());
121 let mut r = U256::zero();
122
123 for i in (0..512).rev() {
124 mul2(&mut r.0);
127 assert!(r.set_bit(0, self.get_bit(i).unwrap()));
128 if &r >= modulo {
129 sub_noborrow(&mut r.0, &modulo.0);
130 if q.is_some() && !q.as_mut().unwrap().set_bit(i, true) {
131 q = None
132 }
133 }
134 }
135
136 if q.is_some() && (q.as_ref().unwrap() >= modulo) {
137 (None, r)
138 } else {
139 (q, r)
140 }
141 }
142
143 pub fn interpret(buf: &[u8; 64]) -> U512 {
144 let mut n = [0; 4];
145 for (l, i) in (0..4).rev().zip((0..4).map(|i| i * 16)) {
146 n[l] = BigEndian::read_u128(&buf[i..]);
147 }
148
149 U512(n)
150 }
151}
152
153impl Ord for U512 {
154 #[inline]
155 fn cmp(&self, other: &U512) -> Ordering {
156 for (a, b) in self.0.iter().zip(other.0.iter()).rev() {
157 if *a < *b {
158 return Ordering::Less;
159 } else if *a > *b {
160 return Ordering::Greater;
161 }
162 }
163
164 return Ordering::Equal;
165 }
166}
167
168impl PartialOrd for U512 {
169 #[inline]
170 fn partial_cmp(&self, other: &U512) -> Option<Ordering> {
171 Some(self.cmp(other))
172 }
173}
174
175impl Ord for U256 {
176 #[inline]
177 fn cmp(&self, other: &U256) -> Ordering {
178 for (a, b) in self.0.iter().zip(other.0.iter()).rev() {
179 if *a < *b {
180 return Ordering::Less;
181 } else if *a > *b {
182 return Ordering::Greater;
183 }
184 }
185
186 return Ordering::Equal;
187 }
188}
189
190impl PartialOrd for U256 {
191 #[inline]
192 fn partial_cmp(&self, other: &U256) -> Option<Ordering> {
193 Some(self.cmp(other))
194 }
195}
196
/// Errors from converting between byte slices and the fixed-width integers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The slice length did not match the required byte width.
    InvalidLength { expected: usize, actual: usize },
}

// Written against `core::fmt` so it also builds without `std`.
impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        match *self {
            Error::InvalidLength { expected, actual } => write!(
                f,
                "invalid length: expected {} bytes, got {}",
                expected, actual
            ),
        }
    }
}
202
203impl U256 {
204 pub fn from_slice(s: &[u8]) -> Result<U256, Error> {
206 if s.len() != 32 {
207 return Err(Error::InvalidLength {
208 expected: 32,
209 actual: s.len(),
210 });
211 }
212
213 let mut n = [0; 2];
214 for (l, i) in (0..2).rev().zip((0..2).map(|i| i * 16)) {
215 n[l] = BigEndian::read_u128(&s[i..]);
216 }
217
218 Ok(U256(n))
219 }
220
221 pub fn to_big_endian(&self, s: &mut [u8]) -> Result<(), Error> {
222 if s.len() != 32 {
223 return Err(Error::InvalidLength {
224 expected: 32,
225 actual: s.len(),
226 });
227 }
228
229 for (l, i) in (0..2).rev().zip((0..2).map(|i| i * 16)) {
230 BigEndian::write_u128(&mut s[i..], self.0[l]);
231 }
232
233 Ok(())
234 }
235
236 #[inline]
237 pub fn zero() -> U256 {
238 U256([0, 0])
239 }
240
241 #[inline]
242 pub fn one() -> U256 {
243 U256([1, 0])
244 }
245
246 pub fn random<R: Rng>(rng: &mut R, modulo: &U256) -> U256 {
248 U512::random(rng).divrem(modulo).1
249 }
250
251 pub fn is_zero(&self) -> bool {
252 self.0[0] == 0 && self.0[1] == 0
253 }
254
255 pub fn set_bit(&mut self, n: usize, to: bool) -> bool {
256 if n >= 256 {
257 false
258 } else {
259 let part = n / 128;
260 let bit = n - (128 * part);
261
262 if to {
263 self.0[part] |= 1 << bit;
264 } else {
265 self.0[part] &= !(1 << bit);
266 }
267
268 true
269 }
270 }
271
272 pub fn get_bit(&self, n: usize) -> Option<bool> {
273 if n >= 256 {
274 None
275 } else {
276 let part = n / 128;
277 let bit = n - (128 * part);
278
279 Some(self.0[part] & (1 << bit) > 0)
280 }
281 }
282
283 pub fn add(&mut self, other: &U256, modulo: &U256) {
285 add_nocarry(&mut self.0, &other.0);
286
287 if *self >= *modulo {
288 sub_noborrow(&mut self.0, &modulo.0);
289 }
290 }
291
292 pub fn sub(&mut self, other: &U256, modulo: &U256) {
294 if *self < *other {
295 add_nocarry(&mut self.0, &modulo.0);
296 }
297
298 sub_noborrow(&mut self.0, &other.0);
299 }
300
301 pub fn mul(&mut self, other: &U256, modulo: &U256, inv: u128) {
304 mul_reduce(&mut self.0, &other.0, &modulo.0, inv);
305
306 if *self >= *modulo {
307 sub_noborrow(&mut self.0, &modulo.0);
308 }
309 }
310
311 pub fn neg(&mut self, modulo: &U256) {
313 if *self > Self::zero() {
314 let mut tmp = modulo.0;
315 sub_noborrow(&mut tmp, &self.0);
316
317 self.0 = tmp;
318 }
319 }
320
321 #[inline]
322 pub fn is_even(&self) -> bool {
323 self.0[0] & 1 == 0
324 }
325
326 pub fn invert(&mut self, modulo: &U256) {
328 let mut u = *self;
333 let mut v = *modulo;
334 let mut b = U256::one();
335 let mut c = U256::zero();
336
337 while u != U256::one() && v != U256::one() {
338 while u.is_even() {
339 div2(&mut u.0);
340
341 if b.is_even() {
342 div2(&mut b.0);
343 } else {
344 add_nocarry(&mut b.0, &modulo.0);
345 div2(&mut b.0);
346 }
347 }
348 while v.is_even() {
349 div2(&mut v.0);
350
351 if c.is_even() {
352 div2(&mut c.0);
353 } else {
354 add_nocarry(&mut c.0, &modulo.0);
355 div2(&mut c.0);
356 }
357 }
358
359 if u >= v {
360 sub_noborrow(&mut u.0, &v.0);
361 b.sub(&c, modulo);
362 } else {
363 sub_noborrow(&mut v.0, &u.0);
364 c.sub(&b, modulo);
365 }
366 }
367
368 if u == U256::one() {
369 self.0 = b.0;
370 } else {
371 self.0 = c.0;
372 }
373 }
374
375 pub fn bits(&self) -> BitIterator {
378 BitIterator { int: &self, n: 256 }
379 }
380}
381
382pub struct BitIterator<'a> {
383 int: &'a U256,
384 n: usize,
385}
386
387impl<'a> Iterator for BitIterator<'a> {
388 type Item = bool;
389
390 fn next(&mut self) -> Option<bool> {
391 if self.n == 0 {
392 None
393 } else {
394 self.n -= 1;
395
396 self.int.get_bit(self.n)
397 }
398 }
399}
400
/// Shifts a two-limb value right by one bit; the low bit of the high limb moves into
/// the top of the low limb, and bit 0 is discarded.
#[inline]
fn div2(a: &mut [u128; 2]) {
    let carried = a[1] << 127;
    a[0] = (a[0] >> 1) | carried;
    a[1] >>= 1;
}
409
/// Shifts a two-limb value left by one bit. The top bit of the low limb moves into the
/// high limb; the top bit of the high limb is discarded.
#[inline]
fn mul2(a: &mut [u128; 2]) {
    let spill = a[0] >> 127;
    a[1] = (a[1] << 1) | spill;
    a[0] <<= 1;
}
418
/// Splits `i` into `(high 64 bits, low 64 bits)`, each widened back to `u128`.
#[inline(always)]
fn split_u128(i: u128) -> (u128, u128) {
    let low = u128::from(i as u64);
    let high = i >> 64;
    (high, low)
}
423
/// Joins two halves: `hi` shifted into the upper 64 bits, OR-ed with `lo`.
#[inline(always)]
fn combine_u128(hi: u128, lo: u128) -> u128 {
    lo | (hi << 64)
}
428
429#[inline]
430fn adc(a: u128, b: u128, carry: &mut u128) -> u128 {
431 let (a1, a0) = split_u128(a);
432 let (b1, b0) = split_u128(b);
433 let (c, r0) = split_u128(a0 + b0 + *carry);
434 let (c, r1) = split_u128(a1 + b1 + c);
435 *carry = c;
436
437 combine_u128(r1, r0)
438}
439
440#[inline]
441fn add_nocarry(a: &mut [u128; 2], b: &[u128; 2]) {
442 let mut carry = 0;
443
444 for (a, b) in a.into_iter().zip(b.iter()) {
445 *a = adc(*a, *b, &mut carry);
446 }
447
448 debug_assert!(0 == carry);
449}
450
/// Computes `a -= b` over two limbs.
///
/// The caller must ensure `a >= b`. A final borrow is only caught by `debug_assert!`;
/// release builds wrap modulo 2^256.
#[inline]
fn sub_noborrow(a: &mut [u128; 2], b: &[u128; 2]) {
    // Borrow is always 0 or 1 between limbs.
    let mut borrow: u128 = 0;

    for (lhs, rhs) in a.iter_mut().zip(b.iter()) {
        let (diff, under_b) = lhs.overflowing_sub(*rhs);
        let (diff, under_borrow) = diff.overflowing_sub(borrow);
        *lhs = diff;
        borrow = (under_b || under_borrow) as u128;
    }

    debug_assert!(0 == borrow);
}
473
/// Multiply-accumulate one 128-bit digit into a four-limb accumulator (limbs least
/// significant first): `acc += (b * c) << (128 * from_index)`.
///
/// Limbs past `acc[3]` are not written. Any carry that would go there is only checked
/// by the final `debug_assert!`.
#[inline(always)]
fn mac_digit(from_index: usize, acc: &mut [u128; 4], b: &[u128; 2], c: u128) {
    /// Returns the low 128 bits of `a + b * c + *carry` and stores the high 128 bits in
    /// `*carry`.
    ///
    /// Every operand is split into 64-bit halves so each partial product fits in a `u128`.
    #[inline]
    fn mac_with_carry(a: u128, b: u128, c: u128, carry: &mut u128) -> u128 {
        let (b_hi, b_lo) = split_u128(b);
        let (c_hi, c_lo) = split_u128(c);

        let (a_hi, a_lo) = split_u128(a);
        let (carry_hi, carry_lo) = split_u128(*carry);
        // Bits 0..128: low x low partial product plus the low halves of a and carry.
        let (x_hi, x_lo) = split_u128(b_lo * c_lo + a_lo + carry_lo);
        // Cross partial products, weighted by 2^64.
        let (y_hi, y_lo) = split_u128(b_lo * c_hi);
        let (z_hi, z_lo) = split_u128(b_hi * c_lo);
        // Everything landing at weight 2^64 (including the carry out of x).
        let (r_hi, r_lo) = split_u128((x_hi + y_lo) + (z_lo + a_hi) + carry_hi);

        // Everything at weight 2^128 becomes the outgoing carry.
        *carry = (b_hi * c_hi) + r_hi + y_hi + z_hi;

        combine_u128(r_lo, x_lo)
    }

    // Multiplying by zero adds nothing.
    if c == 0 {
        return;
    }

    let mut carry = 0;

    debug_assert_eq!(acc.len(), 4);
    // acc[from_index..from_index + 2] += b * c, threading the carry.
    unroll! {
        for i in 0..2 {
            let a_index = i + from_index;
            acc[a_index] = mac_with_carry(acc[a_index], b[i], c, &mut carry);
        }
    }
    // Ripple the remaining carry into the higher limbs that exist.
    unroll! {
        for i in 0..2 {
            let a_index = i + from_index + 2;
            if a_index < 4 {
                let (a_hi, a_lo) = split_u128(acc[a_index]);
                let (carry_hi, carry_lo) = split_u128(carry);
                let (x_hi, x_lo) = split_u128(a_lo + carry_lo);
                let (r_hi, r_lo) = split_u128(x_hi + a_hi + carry_hi);

                carry = r_hi;

                acc[a_index] = combine_u128(r_lo, x_lo);
            }
        }
    }

    debug_assert!(carry == 0);
}
526
527#[inline]
528fn mul_reduce(this: &mut [u128; 2], by: &[u128; 2], modulus: &[u128; 2], inv: u128) {
529 let mut res = [0; 2 * 2];
534 unroll! {
535 for i in 0..2 {
536 mac_digit(i, &mut res, by, this[i]);
537 }
538 }
539
540 unroll! {
541 for i in 0..2 {
542 let k = inv.wrapping_mul(res[i]);
543 mac_digit(i, &mut res, modulus, k);
544 }
545 }
546
547 this.copy_from_slice(&res[2..]);
548}
549
#[test]
fn setting_bits() {
    let rng = &mut ::rand::thread_rng();
    let modulo = U256::from([0xffffffffffffffff; 4]);

    let a = U256::random(rng, &modulo);
    let mut rebuilt = U256::zero();
    // bits() yields the most significant bit first, so index i maps to bit 255 - i.
    a.bits()
        .enumerate()
        .for_each(|(idx, bit)| assert!(rebuilt.set_bit(255 - idx, bit)));

    assert_eq!(a, rebuilt);
}
563
#[test]
fn from_slice() {
    // Big-endian encoding of 1: only the last byte is set.
    let mut bytes = [0u8; 32];
    bytes[31] = 1;

    let parsed =
        U256::from_slice(&bytes).expect("U256 should initialize ok from slice in `from_slice` test");
    assert_eq!(parsed, U256::one());
}
574
#[test]
fn to_big_endian() {
    let mut out = [0u8; 32];
    U256::one()
        .to_big_endian(&mut out)
        .expect("U256 should convert to bytes ok in `to_big_endian` test");

    // Big-endian 1: 31 zero bytes followed by 0x01.
    let mut expected = [0u8; 32];
    expected[31] = 1;
    assert_eq!(out, expected);
}
590
#[test]
fn testing_divrem() {
    let rng = &mut ::rand::thread_rng();

    // A 254-bit odd modulus (u64 limbs, least significant first).
    let modulo = U256::from([
        0x3c208c16d87cfd47,
        0x97816a916871ca8d,
        0xb85045b68181585d,
        0x30644e72e131a029,
    ]);
    let modulo_minus_one = U256::from([
        0x3c208c16d87cfd46,
        0x97816a916871ca8d,
        0xb85045b68181585d,
        0x30644e72e131a029,
    ]);
    let modulo_minus_two = U256::from([
        0x3c208c16d87cfd45,
        0x97816a916871ca8d,
        0xb85045b68181585d,
        0x30644e72e131a029,
    ]);

    // Round trip: divrem undoes U512::new, which builds c1 * modulo + c0.
    for _ in 0..100 {
        let c0 = U256::random(rng, &modulo);
        let c1 = U256::random(rng, &modulo);

        let (new_c1, new_c0) = U512::new(&c1, &c0, &modulo).divrem(&modulo);

        assert!(c1 == new_c1.unwrap());
        assert!(c0 == new_c0);
    }

    // modulo itself: quotient 1, remainder 0.
    let (c1, c0) = U512::from([
        0x3c208c16d87cfd47,
        0x97816a916871ca8d,
        0xb85045b68181585d,
        0x30644e72e131a029,
        0,
        0,
        0,
        0,
    ])
    .divrem(&modulo);
    assert_eq!(c1.unwrap(), U256::one());
    assert_eq!(c0, U256::zero());

    // 512-bit inputs that differ only in the lowest u64 limb; per the expectations
    // below, a low limb of 0x3b5458a2275d69b1 gives exactly modulo^2.
    let near_square = |low: u64| {
        U512::from([
            low,
            0xa602072d09eac101,
            0x4a50189c6d96cadc,
            0x04689e957a1242c8,
            0x26edfa5c34c6b38d,
            0xb00b855116375606,
            0x599a6f7c0348d21c,
            0x0925c4b8763cbf9c,
        ])
    };

    // modulo^2 - 1 = (modulo - 1) * modulo + (modulo - 1).
    let (c1, c0) = near_square(0x3b5458a2275d69b0).divrem(&modulo);
    assert_eq!(c1.unwrap(), modulo_minus_one);
    assert_eq!(c0, modulo_minus_one);

    // modulo^2 - 2 = (modulo - 1) * modulo + (modulo - 2).
    let (c1, c0) = near_square(0x3b5458a2275d69af).divrem(&modulo);
    assert_eq!(c1.unwrap(), modulo_minus_one);
    assert_eq!(c0, modulo_minus_two);

    // All ones: the quotient is too large and is reported as None.
    let (c1, c0) = U512::from([0xffffffffffffffff; 8]).divrem(&modulo);
    assert!(c1.is_none());
    assert_eq!(
        c0,
        U256::from([
            0xf32cfc5b538afa88,
            0xb5e71911d44501fb,
            0x47ab1eff0a417ff6,
            0x06d89f71cab8351f
        ])
    );

    // modulo^2: the quotient would equal modulo, so it is reported as None.
    let (c1, c0) = near_square(0x3b5458a2275d69b1).divrem(&modulo);
    assert!(c1.is_none());
    assert_eq!(c0, U256::zero());

    // modulo^2 + 1.
    let (c1, c0) = near_square(0x3b5458a2275d69b2).divrem(&modulo);
    assert!(c1.is_none());
    assert_eq!(c0, U256::one());

    // A second modulus: only check that both outputs come back reduced.
    {
        let modulo = U256::from([
            0x43e1f593f0000001,
            0x2833e84879b97091,
            0xb85045b68181585d,
            0x30644e72e131a029,
        ]);

        let a = U512::from([
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0x07ffffffffffffff,
        ]);

        let (c1, c0) = a.divrem(&modulo);

        assert!(c1.unwrap() < modulo);
        assert!(c0 < modulo);
    }
}
788}